From 2b2b994f7d110efd0ce5923fb154fc0290e00063 Mon Sep 17 00:00:00 2001 From: Pro7ech Date: Sun, 12 Oct 2025 21:34:10 +0200 Subject: [PATCH] wip --- poulpy-backend/src/cpu_fft64_avx/scratch.rs | 181 +---- poulpy-backend/src/cpu_fft64_avx/svp.rs | 2 +- .../src/cpu_fft64_avx/vec_znx_big.rs | 2 +- .../src/cpu_fft64_avx/vec_znx_dft.rs | 2 +- poulpy-backend/src/cpu_fft64_avx/vmp.rs | 2 +- poulpy-backend/src/cpu_fft64_ref/scratch.rs | 181 +---- poulpy-backend/src/cpu_fft64_ref/svp.rs | 2 +- .../src/cpu_fft64_ref/vec_znx_big.rs | 2 +- .../src/cpu_fft64_ref/vec_znx_dft.rs | 2 +- poulpy-backend/src/cpu_fft64_ref/vmp.rs | 2 +- .../src/cpu_spqlios/fft64/scratch.rs | 181 +---- .../src/cpu_spqlios/fft64/svp_ppol.rs | 2 +- .../src/cpu_spqlios/fft64/vec_znx_big.rs | 2 +- .../src/cpu_spqlios/fft64/vec_znx_dft.rs | 2 +- .../src/cpu_spqlios/fft64/vmp_pmat.rs | 2 +- .../src/cpu_spqlios/ntt120/svp_ppol.rs | 2 +- .../src/cpu_spqlios/ntt120/vec_znx_big.rs | 2 +- .../src/cpu_spqlios/ntt120/vec_znx_dft.rs | 2 +- poulpy-core/README.md | 4 +- .../benches/external_product_glwe_fft64.rs | 49 +- poulpy-core/benches/keyswitch_glwe_fft64.rs | 52 +- poulpy-core/examples/encryption.rs | 18 +- poulpy-core/src/automorphism/gglwe_atk.rs | 44 +- poulpy-core/src/automorphism/ggsw_ct.rs | 66 +- poulpy-core/src/automorphism/glwe_ct.rs | 82 +-- poulpy-core/src/conversion/gglwe_to_ggsw.rs | 279 ++++++++ poulpy-core/src/conversion/glwe_to_lwe.rs | 38 +- poulpy-core/src/conversion/lwe_to_glwe.rs | 34 +- poulpy-core/src/conversion/mod.rs | 3 + poulpy-core/src/decryption/glwe_ct.rs | 18 +- poulpy-core/src/decryption/lwe_ct.rs | 4 +- .../src/encryption/compressed/gglwe_atk.rs | 143 ++-- .../src/encryption/compressed/gglwe_ct.rs | 161 +++-- .../src/encryption/compressed/gglwe_ksk.rs | 119 ++-- .../src/encryption/compressed/gglwe_tsk.rs | 132 ++-- .../src/encryption/compressed/ggsw_ct.rs | 181 ++--- .../src/encryption/compressed/glwe_ct.rs | 137 ++-- poulpy-core/src/encryption/gglwe_atk.rs | 143 ++-- poulpy-core/src/encryption/gglwe_ct.rs | 171 +++-- poulpy-core/src/encryption/gglwe_ksk.rs | 46 +- poulpy-core/src/encryption/gglwe_tsk.rs | 46 +- poulpy-core/src/encryption/ggsw_ct.rs | 163 +++-- poulpy-core/src/encryption/glwe_ct.rs | 653 ++++++++++------- poulpy-core/src/encryption/glwe_pk.rs | 88 +-- poulpy-core/src/encryption/glwe_to_lwe_ksk.rs | 32 +- poulpy-core/src/encryption/lwe_ct.rs | 4 +- poulpy-core/src/encryption/lwe_ksk.rs | 33 +- poulpy-core/src/encryption/lwe_to_glwe_ksk.rs | 27 +- poulpy-core/src/encryption/mod.rs | 2 - poulpy-core/src/external_product/gglwe_atk.rs | 101 +-- poulpy-core/src/external_product/gglwe_ksk.rs | 248 ++++--- poulpy-core/src/external_product/ggsw_ct.rs | 247 ++++--- poulpy-core/src/external_product/glwe_ct.rs | 213 +++--- poulpy-core/src/external_product/mod.rs | 21 +- poulpy-core/src/glwe_packing.rs | 103 ++- poulpy-core/src/glwe_trace.rs | 50 +- poulpy-core/src/keyswitching/gglwe_ct.rs | 373 +++++----- poulpy-core/src/keyswitching/ggsw_ct.rs | 449 +++--------- poulpy-core/src/keyswitching/glwe_ct.rs | 660 ++++++++---------- poulpy-core/src/keyswitching/lwe_ct.rs | 190 ++--- .../src/layouts/compressed/gglwe_atk.rs | 191 +++-- .../src/layouts/compressed/gglwe_ct.rs | 257 ++++--- .../src/layouts/compressed/gglwe_ksk.rs | 186 +++-- .../src/layouts/compressed/gglwe_tsk.rs | 225 +++--- poulpy-core/src/layouts/compressed/ggsw_ct.rs | 216 ++++-- poulpy-core/src/layouts/compressed/glwe_ct.rs | 230 +++--- .../src/layouts/compressed/glwe_to_lwe_ksk.rs | 151 +++- poulpy-core/src/layouts/compressed/lwe_ct.rs | 159 ++++- poulpy-core/src/layouts/compressed/lwe_ksk.rs | 153 +++- .../src/layouts/compressed/lwe_to_glwe_ksk.rs | 152 +++- poulpy-core/src/layouts/compressed/mod.rs | 6 - poulpy-core/src/layouts/gglwe_atk.rs | 189 +++-- poulpy-core/src/layouts/gglwe_ct.rs | 322 ++++----- poulpy-core/src/layouts/gglwe_ksk.rs | 218 ++++-- poulpy-core/src/layouts/gglwe_tsk.rs | 180 +++-- poulpy-core/src/layouts/ggsw_ct.rs | 303 ++++---- poulpy-core/src/layouts/glwe_ct.rs | 264 +++---- poulpy-core/src/layouts/glwe_pk.rs | 198 +++--- poulpy-core/src/layouts/glwe_pt.rs | 215 +++--- poulpy-core/src/layouts/glwe_sk.rs | 108 ++- poulpy-core/src/layouts/glwe_to_lwe_ksk.rs | 140 +++- poulpy-core/src/layouts/lwe_ct.rs | 222 +++--- poulpy-core/src/layouts/lwe_ksk.rs | 162 +++-- poulpy-core/src/layouts/lwe_pt.rs | 51 +- poulpy-core/src/layouts/lwe_sk.rs | 51 +- poulpy-core/src/layouts/lwe_to_glwe_ksk.rs | 123 +++- poulpy-core/src/layouts/mod.rs | 28 +- poulpy-core/src/layouts/prepared/gglwe_atk.rs | 232 ++++-- poulpy-core/src/layouts/prepared/gglwe_ct.rs | 363 +++++----- poulpy-core/src/layouts/prepared/gglwe_ksk.rs | 225 ++++-- poulpy-core/src/layouts/prepared/gglwe_tsk.rs | 244 ++++--- poulpy-core/src/layouts/prepared/ggsw_ct.rs | 340 ++++----- poulpy-core/src/layouts/prepared/glwe_pk.rs | 252 ++++--- poulpy-core/src/layouts/prepared/glwe_sk.rs | 167 +++-- .../src/layouts/prepared/glwe_to_lwe_ksk.rs | 197 ++++-- poulpy-core/src/layouts/prepared/lwe_ksk.rs | 216 ++++-- .../src/layouts/prepared/lwe_to_glwe_ksk.rs | 205 ++++-- poulpy-core/src/layouts/prepared/mod.rs | 13 - poulpy-core/src/lib.rs | 3 +- poulpy-core/src/noise/gglwe_ct.rs | 20 +- poulpy-core/src/noise/ggsw_ct.rs | 40 +- poulpy-core/src/noise/glwe_ct.rs | 24 +- poulpy-core/src/operations/glwe.rs | 100 ++- poulpy-core/src/scratch.rs | 461 +++++------- poulpy-core/src/tests/serialization.rs | 52 +- .../test_suite/automorphism/gglwe_atk.rs | 78 +-- .../tests/test_suite/automorphism/ggsw_ct.rs | 76 +- .../tests/test_suite/automorphism/glwe_ct.rs | 74 +- .../src/tests/test_suite/conversion.rs | 63 +- .../tests/test_suite/encryption/gglwe_atk.rs | 40 +- .../tests/test_suite/encryption/gglwe_ct.rs | 39 +- .../tests/test_suite/encryption/ggsw_ct.rs | 78 +-- .../tests/test_suite/encryption/glwe_ct.rs | 94 ++- .../tests/test_suite/encryption/glwe_tsk.rs | 46 +- .../test_suite/external_product/gglwe_ksk.rs | 60 +- .../test_suite/external_product/ggsw_ct.rs | 56 +- .../test_suite/external_product/glwe_ct.rs | 60 +- .../tests/test_suite/keyswitch/gglwe_ct.rs | 78 +-- .../src/tests/test_suite/keyswitch/ggsw_ct.rs | 77 +- .../src/tests/test_suite/keyswitch/glwe_ct.rs | 73 +- .../src/tests/test_suite/keyswitch/lwe_ct.rs | 28 +- poulpy-core/src/tests/test_suite/packing.rs | 49 +- poulpy-core/src/tests/test_suite/trace.rs | 51 +- poulpy-hal/src/api/module.rs | 4 + poulpy-hal/src/api/scratch.rs | 159 +++-- poulpy-hal/src/api/svp_ppol.rs | 4 +- poulpy-hal/src/api/vec_znx_big.rs | 4 +- poulpy-hal/src/api/vec_znx_dft.rs | 4 +- poulpy-hal/src/api/vmp_pmat.rs | 4 +- poulpy-hal/src/delegates/scratch.rs | 114 +-- poulpy-hal/src/delegates/svp_ppol.rs | 8 +- poulpy-hal/src/delegates/vec_znx_big.rs | 8 +- poulpy-hal/src/delegates/vec_znx_dft.rs | 12 +- poulpy-hal/src/delegates/vmp_pmat.rs | 8 +- poulpy-hal/src/layouts/mat_znx.rs | 12 +- poulpy-hal/src/layouts/scalar_znx.rs | 6 +- poulpy-hal/src/layouts/svp_ppol.rs | 4 +- poulpy-hal/src/layouts/vec_znx.rs | 8 +- poulpy-hal/src/layouts/vec_znx_big.rs | 4 +- poulpy-hal/src/layouts/vec_znx_dft.rs | 4 +- poulpy-hal/src/layouts/vmp_pmat.rs | 6 +- poulpy-hal/src/layouts/zn.rs | 8 +- poulpy-hal/src/oep/scratch.rs | 110 +-- poulpy-hal/src/oep/svp_ppol.rs | 2 +- poulpy-hal/src/oep/vec_znx_big.rs | 2 +- poulpy-hal/src/oep/vec_znx_dft.rs | 2 +- poulpy-hal/src/oep/vmp_pmat.rs | 2 +- poulpy-hal/src/reference/fft64/vmp.rs | 2 +- .../benches/circuit_bootstrapping.rs | 50 +- .../examples/circuit_bootstrapping.rs | 32 +- .../src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs | 12 +- .../tfhe/bdd_arithmetic/ciphertexts/block.rs | 28 +- .../ciphertexts/block_prepared.rs | 28 +- .../tfhe/bdd_arithmetic/ciphertexts/word.rs | 28 +- .../src/tfhe/bdd_arithmetic/eval.rs | 36 +- poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs | 30 +- .../src/tfhe/bdd_arithmetic/parameters.rs | 26 +- .../src/tfhe/bdd_arithmetic/test.rs | 26 +- .../src/tfhe/blind_rotation/cggi_algo.rs | 81 ++- .../src/tfhe/blind_rotation/cggi_key.rs | 39 +- poulpy-schemes/src/tfhe/blind_rotation/key.rs | 26 +- .../src/tfhe/blind_rotation/key_compressed.rs | 12 +- .../src/tfhe/blind_rotation/key_prepared.rs | 14 +- poulpy-schemes/src/tfhe/blind_rotation/mod.rs | 6 +- .../tests/generic_blind_rotation.rs | 35 +- .../src/tfhe/circuit_bootstrapping/circuit.rs | 48 +- .../src/tfhe/circuit_bootstrapping/key.rs | 62 +- .../src/tfhe/circuit_bootstrapping/mod.rs | 10 +- .../tests/circuit_bootstrapping.rs | 78 +-- 169 files changed, 8705 insertions(+), 7677 deletions(-) create mode 100644 poulpy-core/src/conversion/gglwe_to_ggsw.rs diff --git a/poulpy-backend/src/cpu_fft64_avx/scratch.rs b/poulpy-backend/src/cpu_fft64_avx/scratch.rs index 922166b..f6595f9 100644 --- a/poulpy-backend/src/cpu_fft64_avx/scratch.rs +++ b/poulpy-backend/src/cpu_fft64_avx/scratch.rs @@ -3,13 +3,8 @@ use std::marker::PhantomData; use poulpy_hal::{ DEFAULTALIGN, alloc_aligned, api::ScratchFromBytes, - layouts::{Backend, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}, - oep::{ - ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, SvpPPolAllocBytesImpl, - TakeMatZnxImpl, TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl, - TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl, VecZnxBigAllocBytesImpl, - VecZnxDftAllocBytesImpl, VmpPMatAllocBytesImpl, - }, + layouts::{Backend, Scratch, ScratchOwned}, + oep::{ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeSliceImpl}, }; use crate::cpu_fft64_avx::FFT64Avx; @@ -64,178 +59,6 @@ where } } -unsafe impl TakeScalarZnxImpl for FFT64Avx -where - B: ScratchFromBytesImpl, -{ - fn take_scalar_znx_impl(scratch: &mut Scratch, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, ScalarZnx::alloc_bytes(n, cols)); - ( - ScalarZnx::from_data(take_slice, n, cols), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeSvpPPolImpl for FFT64Avx -where - B: SvpPPolAllocBytesImpl + ScratchFromBytesImpl, -{ - fn take_svp_ppol_impl(scratch: &mut Scratch, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, B::svp_ppol_alloc_bytes_impl(n, cols)); - ( - SvpPPol::from_data(take_slice, n, cols), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeVecZnxImpl for FFT64Avx -where - B: ScratchFromBytesImpl, -{ - fn take_vec_znx_impl(scratch: &mut Scratch, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, VecZnx::alloc_bytes(n, cols, size)); - ( - VecZnx::from_data(take_slice, n, cols, size), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeVecZnxBigImpl for FFT64Avx -where - B: VecZnxBigAllocBytesImpl + ScratchFromBytesImpl, -{ - fn take_vec_znx_big_impl( - scratch: &mut Scratch, - n: usize, - cols: usize, - size: usize, - ) -> (VecZnxBig<&mut [u8], B>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned( - &mut scratch.data, - B::vec_znx_big_alloc_bytes_impl(n, cols, size), - ); - ( - VecZnxBig::from_data(take_slice, n, cols, size), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeVecZnxDftImpl for FFT64Avx -where - B: VecZnxDftAllocBytesImpl + ScratchFromBytesImpl, -{ - fn take_vec_znx_dft_impl( - scratch: &mut Scratch, - n: usize, - cols: usize, - size: usize, - ) -> (VecZnxDft<&mut [u8], B>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned( - &mut scratch.data, - B::vec_znx_dft_alloc_bytes_impl(n, cols, size), - ); - - ( - VecZnxDft::from_data(take_slice, n, cols, size), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeVecZnxDftSliceImpl for FFT64Avx -where - B: VecZnxDftAllocBytesImpl + ScratchFromBytesImpl + TakeVecZnxDftImpl, -{ - fn take_vec_znx_dft_slice_impl( - scratch: &mut Scratch, - len: usize, - n: usize, - cols: usize, - size: usize, - ) -> (Vec>, &mut Scratch) { - let mut scratch: &mut Scratch = scratch; - let mut slice: Vec> = Vec::with_capacity(len); - for _ in 0..len { - let (znx, new_scratch) = B::take_vec_znx_dft_impl(scratch, n, cols, size); - scratch = new_scratch; - slice.push(znx); - } - (slice, scratch) - } -} - -unsafe impl TakeVecZnxSliceImpl for FFT64Avx -where - B: ScratchFromBytesImpl + TakeVecZnxImpl, -{ - fn take_vec_znx_slice_impl( - scratch: &mut Scratch, - len: usize, - n: usize, - cols: usize, - size: usize, - ) -> (Vec>, &mut Scratch) { - let mut scratch: &mut Scratch = scratch; - let mut slice: Vec> = Vec::with_capacity(len); - for _ in 0..len { - let (znx, new_scratch) = B::take_vec_znx_impl(scratch, n, cols, size); - scratch = new_scratch; - slice.push(znx); - } - (slice, scratch) - } -} - -unsafe impl TakeVmpPMatImpl for FFT64Avx -where - B: VmpPMatAllocBytesImpl + ScratchFromBytesImpl, -{ - fn take_vmp_pmat_impl( - scratch: &mut Scratch, - n: usize, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - ) -> (VmpPMat<&mut [u8], B>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned( - &mut scratch.data, - B::vmp_pmat_alloc_bytes_impl(n, rows, cols_in, cols_out, size), - ); - ( - VmpPMat::from_data(take_slice, n, rows, cols_in, cols_out, size), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeMatZnxImpl for FFT64Avx -where - B: ScratchFromBytesImpl, -{ - fn take_mat_znx_impl( - scratch: &mut Scratch, - n: usize, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - ) -> (MatZnx<&mut [u8]>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned( - &mut scratch.data, - MatZnx::alloc_bytes(n, rows, cols_in, cols_out, size), - ); - ( - MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size), - Scratch::from_bytes(rem_slice), - ) - } -} - fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) { let ptr: *mut u8 = data.as_mut_ptr(); let self_len: usize = data.len(); diff --git a/poulpy-backend/src/cpu_fft64_avx/svp.rs b/poulpy-backend/src/cpu_fft64_avx/svp.rs index f505597..1c2c999 100644 --- a/poulpy-backend/src/cpu_fft64_avx/svp.rs +++ b/poulpy-backend/src/cpu_fft64_avx/svp.rs @@ -22,7 +22,7 @@ unsafe impl SvpPPolAllocImpl for FFT64Avx { } unsafe impl SvpPPolAllocBytesImpl for FFT64Avx { - fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize { + fn svp_ppol_bytes_of_impl(n: usize, cols: usize) -> usize { Self::layout_prep_word_count() * n * cols * size_of::() } } diff --git a/poulpy-backend/src/cpu_fft64_avx/vec_znx_big.rs b/poulpy-backend/src/cpu_fft64_avx/vec_znx_big.rs index 99a39fd..08ec98a 100644 --- a/poulpy-backend/src/cpu_fft64_avx/vec_znx_big.rs +++ b/poulpy-backend/src/cpu_fft64_avx/vec_znx_big.rs @@ -27,7 +27,7 @@ use poulpy_hal::{ }; unsafe impl VecZnxBigAllocBytesImpl for FFT64Avx { - fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { + fn vec_znx_big_bytes_of_impl(n: usize, cols: usize, size: usize) -> usize { Self::layout_big_word_count() * n * cols * size * size_of::() } } diff --git a/poulpy-backend/src/cpu_fft64_avx/vec_znx_dft.rs b/poulpy-backend/src/cpu_fft64_avx/vec_znx_dft.rs index 862f623..063ee26 100644 --- a/poulpy-backend/src/cpu_fft64_avx/vec_znx_dft.rs +++ b/poulpy-backend/src/cpu_fft64_avx/vec_znx_dft.rs @@ -24,7 +24,7 @@ unsafe impl VecZnxDftFromBytesImpl for FFT64Avx { } unsafe impl VecZnxDftAllocBytesImpl for FFT64Avx { - fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { + fn vec_znx_dft_bytes_of_impl(n: usize, cols: usize, size: usize) -> usize { Self::layout_prep_word_count() * n * cols * size * size_of::<::ScalarPrep>() } } diff --git a/poulpy-backend/src/cpu_fft64_avx/vmp.rs b/poulpy-backend/src/cpu_fft64_avx/vmp.rs index 6b87ce1..fcb6236 100644 --- a/poulpy-backend/src/cpu_fft64_avx/vmp.rs +++ b/poulpy-backend/src/cpu_fft64_avx/vmp.rs @@ -16,7 +16,7 @@ use poulpy_hal::{ use crate::cpu_fft64_avx::{FFT64Avx, module::FFT64ModuleHandle}; unsafe impl VmpPMatAllocBytesImpl for FFT64Avx { - fn vmp_pmat_alloc_bytes_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + fn vmp_pmat_bytes_of_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { Self::layout_prep_word_count() * n * rows * cols_in * cols_out * size * size_of::() } } diff --git a/poulpy-backend/src/cpu_fft64_ref/scratch.rs b/poulpy-backend/src/cpu_fft64_ref/scratch.rs index 80b228d..1593370 100644 --- a/poulpy-backend/src/cpu_fft64_ref/scratch.rs +++ b/poulpy-backend/src/cpu_fft64_ref/scratch.rs @@ -3,13 +3,8 @@ use std::marker::PhantomData; use poulpy_hal::{ DEFAULTALIGN, alloc_aligned, api::ScratchFromBytes, - layouts::{Backend, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}, - oep::{ - ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, SvpPPolAllocBytesImpl, - TakeMatZnxImpl, TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl, - TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl, VecZnxBigAllocBytesImpl, - VecZnxDftAllocBytesImpl, VmpPMatAllocBytesImpl, - }, + layouts::{Backend, Scratch, ScratchOwned}, + oep::{ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeSliceImpl}, }; use crate::cpu_fft64_ref::FFT64Ref; @@ -64,178 +59,6 @@ where } } -unsafe impl TakeScalarZnxImpl for FFT64Ref -where - B: ScratchFromBytesImpl, -{ - fn take_scalar_znx_impl(scratch: &mut Scratch, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, ScalarZnx::alloc_bytes(n, cols)); - ( - ScalarZnx::from_data(take_slice, n, cols), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeSvpPPolImpl for FFT64Ref -where - B: SvpPPolAllocBytesImpl + ScratchFromBytesImpl, -{ - fn take_svp_ppol_impl(scratch: &mut Scratch, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, B::svp_ppol_alloc_bytes_impl(n, cols)); - ( - SvpPPol::from_data(take_slice, n, cols), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeVecZnxImpl for FFT64Ref -where - B: ScratchFromBytesImpl, -{ - fn take_vec_znx_impl(scratch: &mut Scratch, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, VecZnx::alloc_bytes(n, cols, size)); - ( - VecZnx::from_data(take_slice, n, cols, size), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeVecZnxBigImpl for FFT64Ref -where - B: VecZnxBigAllocBytesImpl + ScratchFromBytesImpl, -{ - fn take_vec_znx_big_impl( - scratch: &mut Scratch, - n: usize, - cols: usize, - size: usize, - ) -> (VecZnxBig<&mut [u8], B>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned( - &mut scratch.data, - B::vec_znx_big_alloc_bytes_impl(n, cols, size), - ); - ( - VecZnxBig::from_data(take_slice, n, cols, size), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeVecZnxDftImpl for FFT64Ref -where - B: VecZnxDftAllocBytesImpl + ScratchFromBytesImpl, -{ - fn take_vec_znx_dft_impl( - scratch: &mut Scratch, - n: usize, - cols: usize, - size: usize, - ) -> (VecZnxDft<&mut [u8], B>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned( - &mut scratch.data, - B::vec_znx_dft_alloc_bytes_impl(n, cols, size), - ); - - ( - VecZnxDft::from_data(take_slice, n, cols, size), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeVecZnxDftSliceImpl for FFT64Ref -where - B: VecZnxDftAllocBytesImpl + ScratchFromBytesImpl + TakeVecZnxDftImpl, -{ - fn take_vec_znx_dft_slice_impl( - scratch: &mut Scratch, - len: usize, - n: usize, - cols: usize, - size: usize, - ) -> (Vec>, &mut Scratch) { - let mut scratch: &mut Scratch = scratch; - let mut slice: Vec> = Vec::with_capacity(len); - for _ in 0..len { - let (znx, new_scratch) = B::take_vec_znx_dft_impl(scratch, n, cols, size); - scratch = new_scratch; - slice.push(znx); - } - (slice, scratch) - } -} - -unsafe impl TakeVecZnxSliceImpl for FFT64Ref -where - B: ScratchFromBytesImpl + TakeVecZnxImpl, -{ - fn take_vec_znx_slice_impl( - scratch: &mut Scratch, - len: usize, - n: usize, - cols: usize, - size: usize, - ) -> (Vec>, &mut Scratch) { - let mut scratch: &mut Scratch = scratch; - let mut slice: Vec> = Vec::with_capacity(len); - for _ in 0..len { - let (znx, new_scratch) = B::take_vec_znx_impl(scratch, n, cols, size); - scratch = new_scratch; - slice.push(znx); - } - (slice, scratch) - } -} - -unsafe impl TakeVmpPMatImpl for FFT64Ref -where - B: VmpPMatAllocBytesImpl + ScratchFromBytesImpl, -{ - fn take_vmp_pmat_impl( - scratch: &mut Scratch, - n: usize, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - ) -> (VmpPMat<&mut [u8], B>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned( - &mut scratch.data, - B::vmp_pmat_alloc_bytes_impl(n, rows, cols_in, cols_out, size), - ); - ( - VmpPMat::from_data(take_slice, n, rows, cols_in, cols_out, size), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeMatZnxImpl for FFT64Ref -where - B: ScratchFromBytesImpl, -{ - fn take_mat_znx_impl( - scratch: &mut Scratch, - n: usize, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - ) -> (MatZnx<&mut [u8]>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned( - &mut scratch.data, - MatZnx::alloc_bytes(n, rows, cols_in, cols_out, size), - ); - ( - MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size), - Scratch::from_bytes(rem_slice), - ) - } -} - fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) { let ptr: *mut u8 = data.as_mut_ptr(); let self_len: usize = data.len(); diff --git a/poulpy-backend/src/cpu_fft64_ref/svp.rs b/poulpy-backend/src/cpu_fft64_ref/svp.rs index 06dad9e..37ca14a 100644 --- a/poulpy-backend/src/cpu_fft64_ref/svp.rs +++ b/poulpy-backend/src/cpu_fft64_ref/svp.rs @@ -22,7 +22,7 @@ unsafe impl SvpPPolAllocImpl for FFT64Ref { } unsafe impl SvpPPolAllocBytesImpl for FFT64Ref { - fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize { + fn svp_ppol_bytes_of_impl(n: usize, cols: usize) -> usize { Self::layout_prep_word_count() * n * cols * size_of::() } } diff --git a/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs b/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs index bb75c8f..348c9b6 100644 --- a/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs +++ b/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs @@ -27,7 +27,7 @@ use poulpy_hal::{ }; unsafe impl VecZnxBigAllocBytesImpl for FFT64Ref { - fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { + fn vec_znx_big_bytes_of_impl(n: usize, cols: usize, size: usize) -> usize { Self::layout_big_word_count() * n * cols * size * size_of::() } } diff --git a/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs b/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs index a08b728..a2a743d 100644 --- a/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs +++ b/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs @@ -24,7 +24,7 @@ unsafe impl VecZnxDftFromBytesImpl for FFT64Ref { } unsafe impl VecZnxDftAllocBytesImpl for FFT64Ref { - fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { + fn vec_znx_dft_bytes_of_impl(n: usize, cols: usize, size: usize) -> usize { Self::layout_prep_word_count() * n * cols * size * size_of::<::ScalarPrep>() } } diff --git a/poulpy-backend/src/cpu_fft64_ref/vmp.rs b/poulpy-backend/src/cpu_fft64_ref/vmp.rs index 2286de5..34cbf07 100644 --- a/poulpy-backend/src/cpu_fft64_ref/vmp.rs +++ b/poulpy-backend/src/cpu_fft64_ref/vmp.rs @@ -16,7 +16,7 @@ use poulpy_hal::{ use crate::cpu_fft64_ref::{FFT64Ref, module::FFT64ModuleHandle}; unsafe impl VmpPMatAllocBytesImpl for FFT64Ref { - fn vmp_pmat_alloc_bytes_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + fn vmp_pmat_bytes_of_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { Self::layout_prep_word_count() * n * rows * cols_in * cols_out * size * size_of::() } } diff --git a/poulpy-backend/src/cpu_spqlios/fft64/scratch.rs b/poulpy-backend/src/cpu_spqlios/fft64/scratch.rs index 9bddcb3..d32b9f4 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/scratch.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/scratch.rs @@ -3,13 +3,8 @@ use std::marker::PhantomData; use poulpy_hal::{ DEFAULTALIGN, alloc_aligned, api::ScratchFromBytes, - layouts::{Backend, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}, - oep::{ - ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, SvpPPolAllocBytesImpl, - TakeMatZnxImpl, TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl, - TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl, VecZnxBigAllocBytesImpl, - VecZnxDftAllocBytesImpl, VmpPMatAllocBytesImpl, - }, + layouts::{Backend, Scratch, ScratchOwned}, + oep::{ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeSliceImpl}, }; use crate::cpu_spqlios::FFT64Spqlios; @@ -64,178 +59,6 @@ where } } -unsafe impl TakeScalarZnxImpl for FFT64Spqlios -where - B: ScratchFromBytesImpl, -{ - fn take_scalar_znx_impl(scratch: &mut Scratch, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, ScalarZnx::alloc_bytes(n, cols)); - ( - ScalarZnx::from_data(take_slice, n, cols), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeSvpPPolImpl for FFT64Spqlios -where - B: SvpPPolAllocBytesImpl + ScratchFromBytesImpl, -{ - fn take_svp_ppol_impl(scratch: &mut Scratch, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, B::svp_ppol_alloc_bytes_impl(n, cols)); - ( - SvpPPol::from_data(take_slice, n, cols), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeVecZnxImpl for FFT64Spqlios -where - B: ScratchFromBytesImpl, -{ - fn take_vec_znx_impl(scratch: &mut Scratch, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, VecZnx::alloc_bytes(n, cols, size)); - ( - VecZnx::from_data(take_slice, n, cols, size), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeVecZnxBigImpl for FFT64Spqlios -where - B: VecZnxBigAllocBytesImpl + ScratchFromBytesImpl, -{ - fn take_vec_znx_big_impl( - scratch: &mut Scratch, - n: usize, - cols: usize, - size: usize, - ) -> (VecZnxBig<&mut [u8], B>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned( - &mut scratch.data, - B::vec_znx_big_alloc_bytes_impl(n, cols, size), - ); - ( - VecZnxBig::from_data(take_slice, n, cols, size), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeVecZnxDftImpl for FFT64Spqlios -where - B: VecZnxDftAllocBytesImpl + ScratchFromBytesImpl, -{ - fn take_vec_znx_dft_impl( - scratch: &mut Scratch, - n: usize, - cols: usize, - size: usize, - ) -> (VecZnxDft<&mut [u8], B>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned( - &mut scratch.data, - B::vec_znx_dft_alloc_bytes_impl(n, cols, size), - ); - - ( - VecZnxDft::from_data(take_slice, n, cols, size), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeVecZnxDftSliceImpl for FFT64Spqlios -where - B: VecZnxDftAllocBytesImpl + ScratchFromBytesImpl + TakeVecZnxDftImpl, -{ - fn take_vec_znx_dft_slice_impl( - scratch: &mut Scratch, - len: usize, - n: usize, - cols: usize, - size: usize, - ) -> (Vec>, &mut Scratch) { - let mut scratch: &mut Scratch = scratch; - let mut slice: Vec> = Vec::with_capacity(len); - for _ in 0..len { - let (znx, new_scratch) = B::take_vec_znx_dft_impl(scratch, n, cols, size); - scratch = new_scratch; - slice.push(znx); - } - (slice, scratch) - } -} - -unsafe impl TakeVecZnxSliceImpl for FFT64Spqlios -where - B: ScratchFromBytesImpl + TakeVecZnxImpl, -{ - fn take_vec_znx_slice_impl( - scratch: &mut Scratch, - len: usize, - n: usize, - cols: usize, - size: usize, - ) -> (Vec>, &mut Scratch) { - let mut scratch: &mut Scratch = scratch; - let mut slice: Vec> = Vec::with_capacity(len); - for _ in 0..len { - let (znx, new_scratch) = B::take_vec_znx_impl(scratch, n, cols, size); - scratch = new_scratch; - slice.push(znx); - } - (slice, scratch) - } -} - -unsafe impl TakeVmpPMatImpl for FFT64Spqlios -where - B: VmpPMatAllocBytesImpl + ScratchFromBytesImpl, -{ - fn take_vmp_pmat_impl( - scratch: &mut Scratch, - n: usize, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - ) -> (VmpPMat<&mut [u8], B>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned( - &mut scratch.data, - B::vmp_pmat_alloc_bytes_impl(n, rows, cols_in, cols_out, size), - ); - ( - VmpPMat::from_data(take_slice, n, rows, cols_in, cols_out, size), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeMatZnxImpl for FFT64Spqlios -where - B: ScratchFromBytesImpl, -{ - fn take_mat_znx_impl( - scratch: &mut Scratch, - n: usize, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - ) -> (MatZnx<&mut [u8]>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned( - &mut scratch.data, - MatZnx::alloc_bytes(n, rows, cols_in, cols_out, size), - ); - ( - MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size), - Scratch::from_bytes(rem_slice), - ) - } -} - fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) { let ptr: *mut u8 = data.as_mut_ptr(); let self_len: usize = data.len(); diff --git a/poulpy-backend/src/cpu_spqlios/fft64/svp_ppol.rs b/poulpy-backend/src/cpu_spqlios/fft64/svp_ppol.rs index b917400..f46b795 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/svp_ppol.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/svp_ppol.rs @@ -27,7 +27,7 @@ unsafe impl SvpPPolAllocImpl for FFT64Spqlios { } unsafe impl SvpPPolAllocBytesImpl for FFT64Spqlios { - fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize { + fn svp_ppol_bytes_of_impl(n: usize, cols: usize) -> usize { FFT64Spqlios::layout_prep_word_count() * n * cols * size_of::() } } diff --git a/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_big.rs b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_big.rs index 8becaf6..5021f6b 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_big.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_big.rs @@ -22,7 +22,7 @@ use poulpy_hal::{ }; unsafe impl VecZnxBigAllocBytesImpl for FFT64Spqlios { - fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { + fn vec_znx_big_bytes_of_impl(n: usize, cols: usize, size: usize) -> usize { Self::layout_big_word_count() * n * cols * size * size_of::() } } diff --git a/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_dft.rs b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_dft.rs index 461d327..3b67089 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_dft.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_dft.rs @@ -30,7 +30,7 @@ unsafe impl VecZnxDftFromBytesImpl for FFT64Spqlios { } unsafe impl VecZnxDftAllocBytesImpl for FFT64Spqlios { - fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { + fn vec_znx_dft_bytes_of_impl(n: usize, cols: usize, size: usize) -> usize { Self::layout_prep_word_count() * n * cols * size * size_of::<::ScalarPrep>() } } diff --git a/poulpy-backend/src/cpu_spqlios/fft64/vmp_pmat.rs b/poulpy-backend/src/cpu_spqlios/fft64/vmp_pmat.rs index ca64992..ff1eaa2 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/vmp_pmat.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/vmp_pmat.rs @@ -16,7 +16,7 @@ use crate::cpu_spqlios::{ }; unsafe impl VmpPMatAllocBytesImpl for FFT64Spqlios { - fn vmp_pmat_alloc_bytes_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + fn vmp_pmat_bytes_of_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { Self::layout_prep_word_count() * n * rows * cols_in * cols_out * size * size_of::() } } diff --git a/poulpy-backend/src/cpu_spqlios/ntt120/svp_ppol.rs b/poulpy-backend/src/cpu_spqlios/ntt120/svp_ppol.rs index c98237a..8c7fdcc 100644 --- a/poulpy-backend/src/cpu_spqlios/ntt120/svp_ppol.rs +++ b/poulpy-backend/src/cpu_spqlios/ntt120/svp_ppol.rs @@ -18,7 +18,7 @@ unsafe impl SvpPPolAllocImpl for NTT120 { } unsafe impl SvpPPolAllocBytesImpl for NTT120 { - fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize { + fn svp_ppol_bytes_of_impl(n: usize, cols: usize) -> usize { NTT120::layout_prep_word_count() * n * cols * size_of::() } } diff --git a/poulpy-backend/src/cpu_spqlios/ntt120/vec_znx_big.rs b/poulpy-backend/src/cpu_spqlios/ntt120/vec_znx_big.rs index 715b432..58ddf78 100644 --- a/poulpy-backend/src/cpu_spqlios/ntt120/vec_znx_big.rs +++ b/poulpy-backend/src/cpu_spqlios/ntt120/vec_znx_big.rs @@ -3,7 +3,7 @@ use poulpy_hal::{layouts::Backend, oep::VecZnxBigAllocBytesImpl}; use crate::cpu_spqlios::NTT120; unsafe impl VecZnxBigAllocBytesImpl for NTT120 { - fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { + fn vec_znx_big_bytes_of_impl(n: usize, cols: usize, size: usize) -> usize { NTT120::layout_big_word_count() * n * cols * size * size_of::() } } diff --git a/poulpy-backend/src/cpu_spqlios/ntt120/vec_znx_dft.rs b/poulpy-backend/src/cpu_spqlios/ntt120/vec_znx_dft.rs index 53dd24f..9e1666b 100644 --- a/poulpy-backend/src/cpu_spqlios/ntt120/vec_znx_dft.rs +++ b/poulpy-backend/src/cpu_spqlios/ntt120/vec_znx_dft.rs @@ -6,7 +6,7 @@ use poulpy_hal::{ use crate::cpu_spqlios::NTT120; unsafe impl VecZnxDftAllocBytesImpl for NTT120 { - fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { + fn vec_znx_dft_bytes_of_impl(n: usize, cols: usize, size: usize) -> usize { NTT120::layout_prep_word_count() * n * cols * size * size_of::() } } diff --git a/poulpy-core/README.md b/poulpy-core/README.md index 07d5304..259988e 100644 --- a/poulpy-core/README.md +++ b/poulpy-core/README.md @@ -52,8 +52,8 @@ fn main() { // Scratch space let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::encrypt_sk_scratch_space(&module, n, base2k, ct.k()) - | GLWECiphertext::decrypt_scratch_space(&module, n, base2k, ct.k()), + GLWECiphertext::encrypt_sk_tmp_bytes(&module, n, base2k, ct.k()) + | GLWECiphertext::decrypt_tmp_bytes(&module, n, base2k, ct.k()), ); // Generate secret-key diff --git a/poulpy-core/benches/external_product_glwe_fft64.rs b/poulpy-core/benches/external_product_glwe_fft64.rs index 360a9f4..2849e3d 100644 --- a/poulpy-core/benches/external_product_glwe_fft64.rs +++ b/poulpy-core/benches/external_product_glwe_fft64.rs @@ -1,7 +1,6 @@ use poulpy_core::layouts::{ - Base2K, Degree, Dnum, Dsize, GGSWCiphertext, GGSWCiphertextLayout, GLWECiphertext, GLWECiphertextLayout, GLWESecret, Rank, - TorusPrecision, - prepared::{GGSWCiphertextPrepared, GLWESecretPrepared, PrepareAlloc}, + Base2K, Dnum, Dsize, GGSW, GGSWLayout, GLWE, GLWELayout, GLWESecret, Rank, RingDegree, TorusPrecision, + prepared::{GGSWPrepared, GLWESecretPrepared, PrepareAlloc}, }; use std::hint::black_box; @@ -29,7 +28,7 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { fn runner(p: Params) -> impl FnMut() { let module: Module = Module::::new(1 << p.log_n); - let n: Degree = Degree(module.n() as u32); + let n: RingDegree = RingDegree(module.n() as u32); let base2k: Base2K = p.base2k; let k_ct_in: TorusPrecision = p.k_ct_in; let k_ct_out: TorusPrecision = p.k_ct_out; @@ -39,7 +38,7 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { let dnum: Dnum = Dnum(1); //(p.k_ct_in.div_ceil(p.base2k); - let ggsw_layout: GGSWCiphertextLayout = GGSWCiphertextLayout { + let ggsw_layout: GGSWLayout = GGSWLayout { n, base2k, k: k_ggsw, @@ -48,36 +47,36 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { rank, }; - let glwe_out_layout: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_out_layout: GLWELayout = GLWELayout { n, base2k, k: k_ct_out, rank, }; - let glwe_in_layout: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_in_layout: GLWELayout = GLWELayout { n, base2k, k: k_ct_in, rank, }; - let mut ct_ggsw: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_layout); - let mut ct_glwe_in: GLWECiphertext> = GLWECiphertext::alloc(&glwe_in_layout); - let mut ct_glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&glwe_out_layout); + let mut ct_ggsw: GGSW> = GGSW::alloc_from_infos(&ggsw_layout); + let mut ct_glwe_in: GLWE> = GLWE::alloc_from_infos(&glwe_in_layout); + let mut ct_glwe_out: GLWE> = GLWE::alloc_from_infos(&glwe_out_layout); let pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n.into(), 1); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(&module, &ggsw_layout) - | GLWECiphertext::encrypt_sk_scratch_space(&module, &glwe_in_layout) - | GLWECiphertext::external_product_scratch_space(&module, &glwe_out_layout, &glwe_in_layout, &ggsw_layout), + GGSW::encrypt_sk_tmp_bytes(&module, &ggsw_layout) + | GLWE::encrypt_sk_tmp_bytes(&module, &glwe_in_layout) + | GLWE::external_product_tmp_bytes(&module, &glwe_out_layout, &glwe_in_layout, &ggsw_layout), ); let mut source_xs = Source::new([0u8; 32]); let mut source_xe = Source::new([0u8; 32]); let mut source_xa = Source::new([0u8; 32]); - let mut sk: GLWESecret> = GLWESecret::alloc(&glwe_in_layout); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_in_layout); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_dft: GLWESecretPrepared, FFT64Spqlios> = sk.prepare_alloc(&module, scratch.borrow()); @@ -98,7 +97,7 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { scratch.borrow(), ); - let ggsw_prepared: GGSWCiphertextPrepared, FFT64Spqlios> = ct_ggsw.prepare_alloc(&module, scratch.borrow()); + let ggsw_prepared: GGSWPrepared, FFT64Spqlios> = ct_ggsw.prepare_alloc(&module, scratch.borrow()); move || { ct_glwe_out.external_product(&module, &ct_glwe_in, &ggsw_prepared, scratch.borrow()); @@ -138,7 +137,7 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { fn runner(p: Params) -> impl FnMut() { let module: Module = Module::::new(1 << p.log_n); - let n: Degree = Degree(module.n() as u32); + let n: RingDegree = RingDegree(module.n() as u32); let base2k: Base2K = p.base2k; let k_glwe: TorusPrecision = p.k_ct; let k_ggsw: TorusPrecision = p.k_ggsw; @@ -147,7 +146,7 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { let dnum: Dnum = p.k_ct.div_ceil(p.base2k).into(); - let ggsw_layout: GGSWCiphertextLayout = GGSWCiphertextLayout { + let ggsw_layout: GGSWLayout = GGSWLayout { n, base2k, k: k_ggsw, @@ -156,28 +155,28 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { rank, }; - let glwe_layout: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_layout: GLWELayout = GLWELayout { n, base2k, k: k_glwe, rank, }; - let mut ct_ggsw: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_layout); - let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&glwe_layout); + let mut ct_ggsw: GGSW> = GGSW::alloc_from_infos(&ggsw_layout); + let mut ct_glwe: GLWE> = GLWE::alloc_from_infos(&glwe_layout); let pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n.into(), 1); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(&module, &ggsw_layout) - | GLWECiphertext::encrypt_sk_scratch_space(&module, &glwe_layout) - | GLWECiphertext::external_product_inplace_scratch_space(&module, &glwe_layout, &ggsw_layout), + GGSW::encrypt_sk_tmp_bytes(&module, &ggsw_layout) + | GLWE::encrypt_sk_tmp_bytes(&module, &glwe_layout) + | GLWE::external_product_inplace_tmp_bytes(&module, &glwe_layout, &ggsw_layout), ); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut sk: GLWESecret> = GLWESecret::alloc(&glwe_layout); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_layout); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_dft: GLWESecretPrepared, FFT64Spqlios> = sk.prepare_alloc(&module, scratch.borrow()); @@ -198,7 +197,7 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { scratch.borrow(), ); - let ggsw_prepared: GGSWCiphertextPrepared, FFT64Spqlios> = ct_ggsw.prepare_alloc(&module, scratch.borrow()); + let ggsw_prepared: GGSWPrepared, FFT64Spqlios> = ct_ggsw.prepare_alloc(&module, scratch.borrow()); move || { let scratch_borrow = scratch.borrow(); diff --git a/poulpy-core/benches/keyswitch_glwe_fft64.rs b/poulpy-core/benches/keyswitch_glwe_fft64.rs index 2da2032..fc1dc69 100644 --- a/poulpy-core/benches/keyswitch_glwe_fft64.rs +++ b/poulpy-core/benches/keyswitch_glwe_fft64.rs @@ -1,7 +1,7 @@ use poulpy_core::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEAutomorphismKey, GGLWEAutomorphismKeyLayout, GGLWESwitchingKey, GGLWESwitchingKeyLayout, - GLWECiphertext, GLWECiphertextLayout, GLWESecret, Rank, TorusPrecision, - prepared::{GGLWEAutomorphismKeyPrepared, GGLWESwitchingKeyPrepared, GLWESecretPrepared, PrepareAlloc}, + AutomorphismKey, AutomorphismKeyLayout, Base2K, Dnum, Dsize, GLWE, GLWELayout, GLWESecret, GLWESwitchingKey, + GLWESwitchingKeyLayout, Rank, RingDegree, TorusPrecision, + prepared::{AutomorphismKeyPrepared, GLWESecretPrepared, GLWESwitchingKeyPrepared, PrepareAlloc}, }; use std::{hint::black_box, time::Duration}; @@ -29,7 +29,7 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { fn runner(p: Params) -> impl FnMut() { let module: Module = Module::::new(1 << p.log_n); - let n: Degree = Degree(module.n() as u32); + let n: RingDegree = RingDegree(module.n() as u32); let base2k: Base2K = p.base2k; let k_glwe_in: TorusPrecision = p.k_ct_in; let k_glwe_out: TorusPrecision = p.k_ct_out; @@ -39,7 +39,7 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { let dnum: Dnum = p.k_ct_in.div_ceil(p.base2k.0 * dsize.0).into(); - let gglwe_atk_layout: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + let gglwe_atk_layout: AutomorphismKeyLayout = AutomorphismKeyLayout { n, base2k, k: k_gglwe, @@ -48,28 +48,28 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { dsize, }; - let glwe_in_layout: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_in_layout: GLWELayout = GLWELayout { n, base2k, k: k_glwe_in, rank, }; - let glwe_out_layout: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_out_layout: GLWELayout = GLWELayout { n, base2k, k: k_glwe_out, rank, }; - let mut ksk: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&gglwe_atk_layout); - let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&glwe_in_layout); - let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&glwe_out_layout); + let mut ksk: AutomorphismKey> = AutomorphismKey::alloc_from_infos(&gglwe_atk_layout); + let mut ct_in: GLWE> = GLWE::alloc_from_infos(&glwe_in_layout); + let mut ct_out: GLWE> = GLWE::alloc_from_infos(&glwe_out_layout); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWESwitchingKey::encrypt_sk_scratch_space(&module, &gglwe_atk_layout) - | GLWECiphertext::encrypt_sk_scratch_space(&module, &glwe_in_layout) - | GLWECiphertext::keyswitch_scratch_space( + GLWESwitchingKey::encrypt_sk_tmp_bytes(&module, &gglwe_atk_layout) + | GLWE::encrypt_sk_tmp_bytes(&module, &glwe_in_layout) + | GLWE::keyswitch_tmp_bytes( &module, &glwe_out_layout, &glwe_in_layout, @@ -81,7 +81,7 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut sk_in: GLWESecret> = GLWESecret::alloc(&glwe_in_layout); + let mut sk_in: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_in_layout); sk_in.fill_ternary_prob(0.5, &mut source_xs); let sk_in_dft: GLWESecretPrepared, FFT64Spqlios> = sk_in.prepare_alloc(&module, scratch.borrow()); @@ -102,7 +102,7 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { scratch.borrow(), ); - let ksk_prepared: GGLWEAutomorphismKeyPrepared, _> = ksk.prepare_alloc(&module, scratch.borrow()); + let ksk_prepared: AutomorphismKeyPrepared, _> = ksk.prepare_alloc(&module, scratch.borrow()); move || { ct_out.automorphism(&module, &ct_in, &ksk_prepared, scratch.borrow()); @@ -148,7 +148,7 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { fn runner(p: Params) -> impl FnMut() { let module: Module = Module::::new(1 << p.log_n); - let n: Degree = Degree(module.n() as u32); + let n: RingDegree = RingDegree(module.n() as u32); let base2k: Base2K = p.base2k; let k_ct: TorusPrecision = p.k_ct; let k_ksk: TorusPrecision = p.k_ksk; @@ -157,7 +157,7 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { let dnum: Dnum = p.k_ct.div_ceil(p.base2k).into(); - let gglwe_layout: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + let gglwe_layout: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n, base2k, k: k_ksk, @@ -167,31 +167,31 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { rank_out: rank, }; - let glwe_layout: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_layout: GLWELayout = GLWELayout { n, base2k, k: k_ct, rank, }; - let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_layout); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&glwe_layout); + let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&gglwe_layout); + let mut ct: GLWE> = GLWE::alloc_from_infos(&glwe_layout); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWESwitchingKey::encrypt_sk_scratch_space(&module, &gglwe_layout) - | GLWECiphertext::encrypt_sk_scratch_space(&module, &glwe_layout) - | GLWECiphertext::keyswitch_inplace_scratch_space(&module, &glwe_layout, &gglwe_layout), + GLWESwitchingKey::encrypt_sk_tmp_bytes(&module, &gglwe_layout) + | GLWE::encrypt_sk_tmp_bytes(&module, &glwe_layout) + | GLWE::keyswitch_inplace_tmp_bytes(&module, &glwe_layout, &gglwe_layout), ); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut sk_in: GLWESecret> = GLWESecret::alloc(&glwe_layout); + let mut sk_in: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_layout); sk_in.fill_ternary_prob(0.5, &mut source_xs); let sk_in_dft: GLWESecretPrepared, FFT64Spqlios> = sk_in.prepare_alloc(&module, scratch.borrow()); - let mut sk_out: GLWESecret> = GLWESecret::alloc(&glwe_layout); + let mut sk_out: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_layout); sk_out.fill_ternary_prob(0.5, &mut source_xs); ksk.encrypt_sk( @@ -211,7 +211,7 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { scratch.borrow(), ); - let ksk_prepared: GGLWESwitchingKeyPrepared, FFT64Spqlios> = ksk.prepare_alloc(&module, scratch.borrow()); + let ksk_prepared: GLWESwitchingKeyPrepared, FFT64Spqlios> = ksk.prepare_alloc(&module, scratch.borrow()); move || { ct.keyswitch_inplace(&module, &ksk_prepared, scratch.borrow()); diff --git a/poulpy-core/examples/encryption.rs b/poulpy-core/examples/encryption.rs index a65b473..efd838e 100644 --- a/poulpy-core/examples/encryption.rs +++ b/poulpy-core/examples/encryption.rs @@ -2,8 +2,7 @@ use poulpy_backend::cpu_spqlios::FFT64Spqlios; use poulpy_core::{ GLWEOperations, SIGMA, layouts::{ - Base2K, Degree, GLWECiphertext, GLWECiphertextLayout, GLWEPlaintext, GLWEPlaintextLayout, GLWESecret, LWEInfos, Rank, - TorusPrecision, + Base2K, GLWE, GLWELayout, GLWEPlaintext, GLWEPlaintextLayout, GLWESecret, LWEInfos, Rank, RingDegree, TorusPrecision, prepared::{GLWESecretPrepared, PrepareAlloc}, }, }; @@ -17,7 +16,7 @@ fn main() { // Ring degree let log_n: usize = 10; - let n: Degree = Degree(1 << log_n); + let n: RingDegree = RingDegree(1 << log_n); // Base-2-k (implicit digit decomposition) let base2k: Base2K = Base2K(14); @@ -34,7 +33,7 @@ fn main() { // Instantiate Module (DFT Tables) let module: Module = Module::::new(n.0 as u64); - let glwe_ct_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_ct_infos: GLWELayout = GLWELayout { n, base2k, k: k_ct, @@ -44,9 +43,9 @@ fn main() { let glwe_pt_infos: GLWEPlaintextLayout = GLWEPlaintextLayout { n, base2k, k: k_pt }; // Allocates ciphertext & plaintexts - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&glwe_ct_infos); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_pt_infos); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_pt_infos); + let mut ct: GLWE> = GLWE::alloc_from_infos(&glwe_ct_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_pt_infos); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_pt_infos); // CPRNG let mut source_xs: Source = Source::new([0u8; 32]); @@ -55,12 +54,11 @@ fn main() { // Scratch space let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::encrypt_sk_scratch_space(&module, &glwe_ct_infos) - | GLWECiphertext::decrypt_scratch_space(&module, &glwe_ct_infos), + GLWE::encrypt_sk_tmp_bytes(&module, &glwe_ct_infos) | GLWE::decrypt_tmp_bytes(&module, &glwe_ct_infos), ); // Generate secret-key - let mut sk: GLWESecret> = GLWESecret::alloc(&glwe_ct_infos); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_ct_infos); sk.fill_ternary_prob(0.5, &mut source_xs); // Backend-prepared secret diff --git a/poulpy-core/src/automorphism/gglwe_atk.rs b/poulpy-core/src/automorphism/gglwe_atk.rs index 9b08e68..9650aa2 100644 --- a/poulpy-core/src/automorphism/gglwe_atk.rs +++ b/poulpy-core/src/automorphism/gglwe_atk.rs @@ -1,16 +1,16 @@ use poulpy_hal::{ api::{ - ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, - VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, - VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + ScratchAvailable, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize, + VecZnxBigNormalizeTmpBytes, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, + VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero}, }; -use crate::layouts::{GGLWEAutomorphismKey, GGLWEInfos, GLWECiphertext, prepared::GGLWEAutomorphismKeyPrepared}; +use crate::layouts::{AutomorphismKey, GGLWEInfos, GLWE, prepared::AutomorphismKeyPrepared}; -impl GGLWEAutomorphismKey> { - pub fn automorphism_scratch_space( +impl AutomorphismKey> { + pub fn automorphism_tmp_bytes( module: &Module, out_infos: &OUT, in_infos: &IN, @@ -20,9 +20,9 @@ impl GGLWEAutomorphismKey> { OUT: GGLWEInfos, IN: GGLWEInfos, KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - GLWECiphertext::keyswitch_scratch_space( + GLWE::keyswitch_tmp_bytes( module, &out_infos.glwe_layout(), &in_infos.glwe_layout(), @@ -30,25 +30,25 @@ impl GGLWEAutomorphismKey> { ) } - pub fn automorphism_inplace_scratch_space(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize + pub fn automorphism_inplace_tmp_bytes(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize where OUT: GGLWEInfos, KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - GGLWEAutomorphismKey::automorphism_scratch_space(module, out_infos, out_infos, key_infos) + AutomorphismKey::automorphism_tmp_bytes(module, out_infos, out_infos, key_infos) } } -impl GGLWEAutomorphismKey { +impl AutomorphismKey { pub fn automorphism( &mut self, module: &Module, - lhs: &GGLWEAutomorphismKey, - rhs: &GGLWEAutomorphismKeyPrepared, + lhs: &AutomorphismKey, + rhs: &AutomorphismKeyPrepared, scratch: &mut Scratch, ) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft @@ -61,7 +61,7 @@ impl GGLWEAutomorphismKey { + VecZnxAutomorphismInplace + VecZnxNormalize + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnx, + Scratch: ScratchAvailable, { #[cfg(debug_assertions)] { @@ -103,8 +103,8 @@ impl GGLWEAutomorphismKey { (0..self.rank_in().into()).for_each(|col_i| { (0..self.dnum().into()).for_each(|row_j| { - let mut res_ct: GLWECiphertext<&mut [u8]> = self.at_mut(row_j, col_i); - let lhs_ct: GLWECiphertext<&[u8]> = lhs.at(row_j, col_i); + let mut res_ct: GLWE<&mut [u8]> = self.at_mut(row_j, col_i); + let lhs_ct: GLWE<&[u8]> = lhs.at(row_j, col_i); // Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) (0..cols_out).for_each(|i| { @@ -133,10 +133,10 @@ impl GGLWEAutomorphismKey { pub fn automorphism_inplace( &mut self, module: &Module, - rhs: &GGLWEAutomorphismKeyPrepared, + rhs: &AutomorphismKeyPrepared, scratch: &mut Scratch, ) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft @@ -149,7 +149,7 @@ impl GGLWEAutomorphismKey { + VecZnxAutomorphismInplace + VecZnxNormalize + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnx, + Scratch: ScratchAvailable, { #[cfg(debug_assertions)] { @@ -176,7 +176,7 @@ impl GGLWEAutomorphismKey { (0..self.rank_in().into()).for_each(|col_i| { (0..self.dnum().into()).for_each(|row_j| { - let mut res_ct: GLWECiphertext<&mut [u8]> = self.at_mut(row_j, col_i); + let mut res_ct: GLWE<&mut [u8]> = self.at_mut(row_j, col_i); // Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) (0..cols_out).for_each(|i| { diff --git a/poulpy-core/src/automorphism/ggsw_ct.rs b/poulpy-core/src/automorphism/ggsw_ct.rs index eef3082..a3cef86 100644 --- a/poulpy-core/src/automorphism/ggsw_ct.rs +++ b/poulpy-core/src/automorphism/ggsw_ct.rs @@ -1,20 +1,20 @@ use poulpy_hal::{ api::{ - ScratchAvailable, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAddInplace, VecZnxDftAllocBytes, - VecZnxDftApply, VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeTmpBytes, - VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + ScratchAvailable, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize, + VecZnxBigNormalizeTmpBytes, VecZnxDftAddInplace, VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftCopy, VecZnxIdftApplyConsume, + VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, + VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch}, }; use crate::layouts::{ - GGLWEInfos, GGSWCiphertext, GGSWInfos, GLWECiphertext, - prepared::{GGLWEAutomorphismKeyPrepared, GGLWETensorKeyPrepared}, + GGLWEInfos, GGSW, GGSWInfos, GLWE, + prepared::{AutomorphismKeyPrepared, TensorKeyPrepared}, }; -impl GGSWCiphertext> { - pub fn automorphism_scratch_space( +impl GGSW> { + pub fn automorphism_tmp_bytes( module: &Module, out_infos: &OUT, in_infos: &IN, @@ -26,25 +26,22 @@ impl GGSWCiphertext> { IN: GGSWInfos, KEY: GGLWEInfos, TSK: GGLWEInfos, - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigAllocBytes - + VecZnxNormalizeTmpBytes - + VecZnxBigNormalizeTmpBytes, + Module: + VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigBytesOf + VecZnxNormalizeTmpBytes + VecZnxBigNormalizeTmpBytes, { let out_size: usize = out_infos.size(); - let ci_dft: usize = module.vec_znx_dft_alloc_bytes((key_infos.rank_out() + 1).into(), out_size); - let ks_internal: usize = GLWECiphertext::keyswitch_scratch_space( + let ci_dft: usize = module.bytes_of_vec_znx_dft((key_infos.rank_out() + 1).into(), out_size); + let ks_internal: usize = GLWE::keyswitch_tmp_bytes( module, &out_infos.glwe_layout(), &in_infos.glwe_layout(), key_infos, ); - let expand: usize = GGSWCiphertext::expand_row_scratch_space(module, out_infos, tsk_infos); + let expand: usize = GGSW::expand_row_tmp_bytes(module, out_infos, tsk_infos); ci_dft + (ks_internal | expand) } - pub fn automorphism_inplace_scratch_space( + pub fn automorphism_inplace_tmp_bytes( module: &Module, out_infos: &OUT, key_infos: &KEY, @@ -54,26 +51,23 @@ impl GGSWCiphertext> { OUT: GGSWInfos, KEY: GGLWEInfos, TSK: GGLWEInfos, - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigAllocBytes - + VecZnxNormalizeTmpBytes - + VecZnxBigNormalizeTmpBytes, + Module: + VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigBytesOf + VecZnxNormalizeTmpBytes + VecZnxBigNormalizeTmpBytes, { - GGSWCiphertext::automorphism_scratch_space(module, out_infos, out_infos, key_infos, tsk_infos) + GGSW::automorphism_tmp_bytes(module, out_infos, out_infos, key_infos, tsk_infos) } } -impl GGSWCiphertext { +impl GGSW { pub fn automorphism( &mut self, module: &Module, - lhs: &GGSWCiphertext, - auto_key: &GGLWEAutomorphismKeyPrepared, - tensor_key: &GGLWETensorKeyPrepared, + lhs: &GGSW, + auto_key: &AutomorphismKeyPrepared, + tensor_key: &TensorKeyPrepared, scratch: &mut Scratch, ) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft @@ -83,13 +77,13 @@ impl GGSWCiphertext { + VecZnxBigAddSmallInplace + VecZnxBigNormalize + VecZnxAutomorphismInplace - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxNormalizeTmpBytes + VecZnxDftCopy + VecZnxDftAddInplace + VecZnxIdftApplyTmpA + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnxBig + TakeVecZnx, + Scratch: ScratchAvailable, { #[cfg(debug_assertions)] { @@ -121,7 +115,7 @@ impl GGSWCiphertext { self.rank(), tensor_key.rank_out() ); - assert!(scratch.available() >= GGSWCiphertext::automorphism_scratch_space(module, self, lhs, auto_key, tensor_key)) + assert!(scratch.available() >= GGSW::automorphism_tmp_bytes(module, self, lhs, auto_key, tensor_key)) }; // Keyswitch the j-th row of the col 0 @@ -137,11 +131,11 @@ impl GGSWCiphertext { pub fn automorphism_inplace( &mut self, module: &Module, - auto_key: &GGLWEAutomorphismKeyPrepared, - tensor_key: &GGLWETensorKeyPrepared, + auto_key: &AutomorphismKeyPrepared, + tensor_key: &TensorKeyPrepared, scratch: &mut Scratch, ) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft @@ -151,13 +145,13 @@ impl GGSWCiphertext { + VecZnxBigAddSmallInplace + VecZnxBigNormalize + VecZnxAutomorphismInplace - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxNormalizeTmpBytes + VecZnxDftCopy + VecZnxDftAddInplace + VecZnxIdftApplyTmpA + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnxBig + TakeVecZnx, + Scratch: ScratchAvailable, { // Keyswitch the j-th row of the col 0 (0..self.dnum().into()).for_each(|row_i| { diff --git a/poulpy-core/src/automorphism/glwe_ct.rs b/poulpy-core/src/automorphism/glwe_ct.rs index 79fcb12..0c8b581 100644 --- a/poulpy-core/src/automorphism/glwe_ct.rs +++ b/poulpy-core/src/automorphism/glwe_ct.rs @@ -1,17 +1,17 @@ use poulpy_hal::{ api::{ - ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, - VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallInplace, - VecZnxBigSubSmallNegateInplace, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize, - VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + ScratchAvailable, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, VecZnxBigNormalize, + VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallInplace, VecZnxBigSubSmallNegateInplace, VecZnxDftApply, VecZnxDftBytesOf, + VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, + VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnxBig}, }; -use crate::layouts::{GGLWEInfos, GLWECiphertext, GLWEInfos, LWEInfos, prepared::GGLWEAutomorphismKeyPrepared}; +use crate::layouts::{GGLWEInfos, GLWE, GLWEInfos, LWEInfos, prepared::AutomorphismKeyPrepared}; -impl GLWECiphertext> { - pub fn automorphism_scratch_space( +impl GLWE> { + pub fn automorphism_tmp_bytes( module: &Module, out_infos: &OUT, in_infos: &IN, @@ -21,30 +21,30 @@ impl GLWECiphertext> { OUT: GLWEInfos, IN: GLWEInfos, KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - Self::keyswitch_scratch_space(module, out_infos, in_infos, key_infos) + Self::keyswitch_tmp_bytes(module, out_infos, in_infos, key_infos) } - pub fn automorphism_inplace_scratch_space(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize + pub fn automorphism_inplace_tmp_bytes(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize where OUT: GLWEInfos, KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - Self::keyswitch_inplace_scratch_space(module, out_infos, key_infos) + Self::keyswitch_inplace_tmp_bytes(module, out_infos, key_infos) } } -impl GLWECiphertext { +impl GLWE { pub fn automorphism( &mut self, module: &Module, - lhs: &GLWECiphertext, - rhs: &GGLWEAutomorphismKeyPrepared, + lhs: &GLWE, + rhs: &AutomorphismKeyPrepared, scratch: &mut Scratch, ) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft @@ -56,7 +56,7 @@ impl GLWECiphertext { + VecZnxAutomorphismInplace + VecZnxNormalize + VecZnxNormalizeTmpBytes, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Scratch: ScratchAvailable, { self.keyswitch(module, lhs, &rhs.key, scratch); (0..(self.rank() + 1).into()).for_each(|i| { @@ -67,10 +67,10 @@ impl GLWECiphertext { pub fn automorphism_inplace( &mut self, module: &Module, - rhs: &GGLWEAutomorphismKeyPrepared, + rhs: &AutomorphismKeyPrepared, scratch: &mut Scratch, ) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft @@ -82,7 +82,7 @@ impl GLWECiphertext { + VecZnxAutomorphismInplace + VecZnxNormalize + VecZnxNormalizeTmpBytes, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Scratch: ScratchAvailable, { self.keyswitch_inplace(module, &rhs.key, scratch); (0..(self.rank() + 1).into()).for_each(|i| { @@ -93,11 +93,11 @@ impl GLWECiphertext { pub fn automorphism_add( &mut self, module: &Module, - lhs: &GLWECiphertext, - rhs: &GGLWEAutomorphismKeyPrepared, + lhs: &GLWE, + rhs: &AutomorphismKeyPrepared, scratch: &mut Scratch, ) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft @@ -109,7 +109,7 @@ impl GLWECiphertext { + VecZnxBigAutomorphismInplace + VecZnxNormalizeTmpBytes + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Scratch: ScratchAvailable, { #[cfg(debug_assertions)] { @@ -135,10 +135,10 @@ impl GLWECiphertext { pub fn automorphism_add_inplace( &mut self, module: &Module, - rhs: &GGLWEAutomorphismKeyPrepared, + rhs: &AutomorphismKeyPrepared, scratch: &mut Scratch, ) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft @@ -150,7 +150,7 @@ impl GLWECiphertext { + VecZnxBigAutomorphismInplace + VecZnxNormalizeTmpBytes + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Scratch: ScratchAvailable, { #[cfg(debug_assertions)] { @@ -176,11 +176,11 @@ impl GLWECiphertext { pub fn automorphism_sub_ab( &mut self, module: &Module, - lhs: &GLWECiphertext, - rhs: &GGLWEAutomorphismKeyPrepared, + lhs: &GLWE, + rhs: &AutomorphismKeyPrepared, scratch: &mut Scratch, ) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft @@ -193,7 +193,7 @@ impl GLWECiphertext { + VecZnxBigSubSmallInplace + VecZnxNormalizeTmpBytes + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Scratch: ScratchAvailable, { #[cfg(debug_assertions)] { @@ -219,10 +219,10 @@ impl GLWECiphertext { pub fn automorphism_sub_inplace( &mut self, module: &Module, - rhs: &GGLWEAutomorphismKeyPrepared, + rhs: &AutomorphismKeyPrepared, scratch: &mut Scratch, ) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft @@ -235,7 +235,7 @@ impl GLWECiphertext { + VecZnxBigSubSmallInplace + VecZnxNormalizeTmpBytes + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Scratch: ScratchAvailable, { #[cfg(debug_assertions)] { @@ -261,11 +261,11 @@ impl GLWECiphertext { pub fn automorphism_sub_negate( &mut self, module: &Module, - lhs: &GLWECiphertext, - rhs: &GGLWEAutomorphismKeyPrepared, + lhs: &GLWE, + rhs: &AutomorphismKeyPrepared, scratch: &mut Scratch, ) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft @@ -278,7 +278,7 @@ impl GLWECiphertext { + VecZnxBigSubSmallNegateInplace + VecZnxNormalizeTmpBytes + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Scratch: ScratchAvailable, { #[cfg(debug_assertions)] { @@ -304,10 +304,10 @@ impl GLWECiphertext { pub fn automorphism_sub_negate_inplace( &mut self, module: &Module, - rhs: &GGLWEAutomorphismKeyPrepared, + rhs: &AutomorphismKeyPrepared, scratch: &mut Scratch, ) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft @@ -320,7 +320,7 @@ impl GLWECiphertext { + VecZnxBigSubSmallNegateInplace + VecZnxNormalizeTmpBytes + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Scratch: ScratchAvailable, { #[cfg(debug_assertions)] { diff --git a/poulpy-core/src/conversion/gglwe_to_ggsw.rs b/poulpy-core/src/conversion/gglwe_to_ggsw.rs new file mode 100644 index 0000000..a7b86fa --- /dev/null +++ b/poulpy-core/src/conversion/gglwe_to_ggsw.rs @@ -0,0 +1,279 @@ +use poulpy_hal::{ + api::{ + ModuleN, ScratchAvailable, ScratchTakeBasic, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxCopy, VecZnxDftAddInplace, + VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftCopy, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeTmpBytes, + VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + }, + layouts::{Backend, DataMut, Module, Scratch, VmpPMat, ZnxInfos}, +}; + +use crate::{ + ScratchTakeCore, + layouts::{ + GGLWE, GGLWEInfos, GGLWEToRef, GGSW, GGSWInfos, GGSWToMut, GLWEInfos, LWEInfos, + prepared::{TensorKeyPrepared, TensorKeyPreparedToRef}, + }, + operations::GLWEOperations, +}; + +impl GGLWE> { + pub fn from_gglw_tmp_bytes(module: &M, res_infos: &R, tsk_infos: &A) -> usize + where + M: GGSWFromGGLWE, + R: GGSWInfos, + A: GGLWEInfos, + { + module.ggsw_from_gglwe_tmp_bytes(res_infos, tsk_infos) + } +} + +impl GGSW { + pub fn from_gglwe(&mut self, module: &M, gglwe: &G, tsk: &T, scratch: &mut Scratch) + where + M: GGSWFromGGLWE, + G: GGLWEToRef, + T: TensorKeyPreparedToRef, + Scratch: ScratchTakeCore, + { + module.ggsw_from_gglwe(self, gglwe, tsk, scratch); + } +} + +impl GGSWFromGGLWE for Module where Self: GGSWExpandRows + VecZnxCopy {} + +pub trait GGSWFromGGLWE +where + Self: GGSWExpandRows + VecZnxCopy, +{ + fn ggsw_from_gglwe_tmp_bytes(&self, res_infos: &R, tsk_infos: &A) -> usize + where + R: GGSWInfos, + A: GGLWEInfos, + { + self.ggsw_expand_rows_tmp_bytes(res_infos, tsk_infos) + } + + fn ggsw_from_gglwe(&self, res: &mut R, a: &A, tsk: &T, scratch: &mut Scratch) + where + R: GGSWToMut, + A: GGLWEToRef, + T: TensorKeyPreparedToRef, + Scratch: ScratchTakeCore, + { + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + let a: &GGLWE<&[u8]> = &a.to_ref(); + let tsk: &TensorKeyPrepared<&[u8], BE> = &tsk.to_ref(); + + assert_eq!(res.rank(), a.rank_out()); + assert_eq!(res.dnum(), a.dnum()); + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.n(), self.n() as u32); + assert_eq!(tsk.n(), self.n() as u32); + + for row in 0..res.dnum().into() { + res.at_mut(row, 0).copy(self, &a.at(row, 0)); + } + + self.ggsw_expand_row(res, tsk, scratch); + } +} + +impl GGSWExpandRows for Module where + Self: Sized + + ModuleN + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VecZnxBigBytesOf + + VecZnxNormalizeTmpBytes + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VecZnxBigBytesOf + + VecZnxNormalizeTmpBytes + + VecZnxDftApply + + VecZnxDftCopy + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxDftAddInplace + + VecZnxBigNormalize + + VecZnxIdftApplyTmpA + + VecZnxNormalize +{ +} + +pub(crate) trait GGSWExpandRows +where + Self: Sized + + ModuleN + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VecZnxBigBytesOf + + VecZnxNormalizeTmpBytes + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VecZnxBigBytesOf + + VecZnxNormalizeTmpBytes + + VecZnxDftApply + + VecZnxDftCopy + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxDftAddInplace + + VecZnxBigNormalize + + VecZnxIdftApplyTmpA + + VecZnxNormalize, +{ + fn ggsw_expand_rows_tmp_bytes(&self, res_infos: &R, tsk_infos: &A) -> usize + where + R: GGSWInfos, + A: GGLWEInfos, + { + let tsk_size: usize = tsk_infos.k().div_ceil(tsk_infos.base2k()) as usize; + let size_in: usize = res_infos + .k() + .div_ceil(tsk_infos.base2k()) + .div_ceil(tsk_infos.dsize().into()) as usize; + + let tmp_dft_i: usize = self.bytes_of_vec_znx_dft((tsk_infos.rank_out() + 1).into(), tsk_size); + let tmp_a: usize = self.bytes_of_vec_znx_dft(1, size_in); + let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes( + tsk_size, + size_in, + size_in, + (tsk_infos.rank_in()).into(), // Verify if rank+1 + (tsk_infos.rank_out()).into(), // Verify if rank+1 + tsk_size, + ); + let tmp_idft: usize = self.bytes_of_vec_znx_big(1, tsk_size); + let norm: usize = self.vec_znx_normalize_tmp_bytes(); + + tmp_dft_i + ((tmp_a + vmp) | (tmp_idft + norm)) + } + + fn ggsw_expand_row(&self, res: &mut R, tsk: &T, scratch: &mut Scratch) + where + R: GGSWToMut, + T: TensorKeyPreparedToRef, + Scratch: ScratchTakeCore, + { + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + let tsk: &TensorKeyPrepared<&[u8], BE> = &tsk.to_ref(); + + let basek_in: usize = res.base2k().into(); + let basek_tsk: usize = tsk.base2k().into(); + + assert!(scratch.available() >= self.ggsw_expand_rows_tmp_bytes(res, tsk)); + + let rank: usize = res.rank().into(); + let cols: usize = rank + 1; + + let a_size: usize = (res.size() * basek_in).div_ceil(basek_tsk); + + // Keyswitch the j-th row of the col 0 + for row_i in 0..res.dnum().into() { + let a = &res.at(row_i, 0).data; + + // Pre-compute DFT of (a0, a1, a2) + let (mut ci_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, a_size); + + if basek_in == basek_tsk { + for i in 0..cols { + self.vec_znx_dft_apply(1, 0, &mut ci_dft, i, a, i); + } + } else { + let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(self, 1, a_size); + for i in 0..cols { + self.vec_znx_normalize(basek_tsk, &mut a_conv, 0, basek_in, a, i, scratch_2); + self.vec_znx_dft_apply(1, 0, &mut ci_dft, i, &a_conv, 0); + } + } + + for col_j in 1..cols { + // Example for rank 3: + // + // Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is + // actually composed of that many dnum and we focus on a specific row here + // implicitely given ci_dft. + // + // # Input + // + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) + // col 1: (0, 0, 0, 0) + // col 2: (0, 0, 0, 0) + // col 3: (0, 0, 0, 0) + // + // # Output + // + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) + // col 1: (-(b0s0 + b1s1 + b2s2) , b0 + M[i], b1 , b2 ) + // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M[i], c2 ) + // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M[i]) + + let dsize: usize = tsk.dsize().into(); + + let (mut tmp_dft_i, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, tsk.size()); + let (mut tmp_a, scratch_3) = scratch_2.take_vec_znx_dft(self, 1, ci_dft.size().div_ceil(dsize)); + + { + // Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2 + // + // # Example for col=1 + // + // a0 * (-(f0s0 + f1s1 + f1s2) + s0^2, f0, f1, f2) = (-(a0f0s0 + a0f1s1 + a0f1s2) + a0s0^2, a0f0, a0f1, a0f2) + // + + // a1 * (-(g0s0 + g1s1 + g1s2) + s0s1, g0, g1, g2) = (-(a1g0s0 + a1g1s1 + a1g1s2) + a1s0s1, a1g0, a1g1, a1g2) + // + + // a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2) + // = + // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2) + for col_i in 1..cols { + let pmat: &VmpPMat<&[u8], BE> = &tsk.at(col_i - 1, col_j - 1).key.data; // Selects Enc(s[i]s[j]) + + // Extracts a[i] and multipies with Enc(s[i]s[j]) + for di in 0..dsize { + tmp_a.set_size((ci_dft.size() + di) / dsize); + + // Small optimization for dsize > 2 + // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then + // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}. + // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. + // It is possible to further ignore the last dsize-1 limbs, but this introduce + // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same + // noise is kept with respect to the ideal functionality. + tmp_dft_i.set_size(tsk.size() - ((dsize - di) as isize - 2).max(0) as usize); + + self.vec_znx_dft_copy(dsize, dsize - 1 - di, &mut tmp_a, 0, &ci_dft, col_i); + if di == 0 && col_i == 1 { + self.vmp_apply_dft_to_dft(&mut tmp_dft_i, &tmp_a, pmat, scratch_3); + } else { + self.vmp_apply_dft_to_dft_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch_3); + } + } + } + } + + // Adds -(sum a[i] * s[i]) + m) on the i-th column of tmp_idft_i + // + // (-(x0s0 + x1s1 + x2s2) + a0s0s0 + a1s0s1 + a2s0s2, x0, x1, x2) + // + + // (0, -(a0s0 + a1s1 + a2s2) + M[i], 0, 0) + // = + // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0 -(a0s0 + a1s1 + a2s2) + M[i], x1, x2) + // = + // (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2) + self.vec_znx_dft_add_inplace(&mut tmp_dft_i, col_j, &ci_dft, 0); + let (mut tmp_idft, scratch_3) = scratch_2.take_vec_znx_big(self, 1, tsk.size()); + for i in 0..cols { + self.vec_znx_idft_apply_tmpa(&mut tmp_idft, 0, &mut tmp_dft_i, i); + self.vec_znx_big_normalize( + basek_in, + &mut res.at_mut(row_i, col_j).data, + i, + basek_tsk, + &tmp_idft, + 0, + scratch_3, + ); + } + } + } + } +} diff --git a/poulpy-core/src/conversion/glwe_to_lwe.rs b/poulpy-core/src/conversion/glwe_to_lwe.rs index 517da90..b6c6ed1 100644 --- a/poulpy-core/src/conversion/glwe_to_lwe.rs +++ b/poulpy-core/src/conversion/glwe_to_lwe.rs @@ -1,22 +1,16 @@ use poulpy_hal::{ api::{ - ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, - VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftApply, VecZnxDftBytesOf, + VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, + VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero}, }; -use crate::{ - TakeGLWECt, - layouts::{ - GGLWEInfos, GLWECiphertext, GLWECiphertextLayout, GLWEInfos, LWECiphertext, LWEInfos, Rank, - prepared::GLWEToLWESwitchingKeyPrepared, - }, -}; +use crate::layouts::{GGLWEInfos, GLWE, GLWEInfos, GLWELayout, LWE, LWEInfos, Rank, prepared::GLWEToLWESwitchingKeyPrepared}; -impl LWECiphertext> { - pub fn from_glwe_scratch_space( +impl LWE> { + pub fn from_glwe_tmp_bytes( module: &Module, lwe_infos: &OUT, glwe_infos: &IN, @@ -26,26 +20,26 @@ impl LWECiphertext> { OUT: LWEInfos, IN: GLWEInfos, KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - let glwe_layout: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_layout: GLWELayout = GLWELayout { n: module.n().into(), base2k: lwe_infos.base2k(), k: lwe_infos.k(), rank: Rank(1), }; - GLWECiphertext::alloc_bytes_with( + GLWE::bytes_of( module.n().into(), lwe_infos.base2k(), lwe_infos.k(), 1u32.into(), - ) + GLWECiphertext::keyswitch_scratch_space(module, &glwe_layout, glwe_infos, key_infos) + ) + GLWE::keyswitch_tmp_bytes(module, &glwe_layout, glwe_infos, key_infos) } } -impl LWECiphertext { - pub fn sample_extract(&mut self, a: &GLWECiphertext) { +impl LWE { + pub fn sample_extract(&mut self, a: &GLWE) { #[cfg(debug_assertions)] { assert!(self.n() <= a.n()); @@ -66,13 +60,13 @@ impl LWECiphertext { pub fn from_glwe( &mut self, module: &Module, - a: &GLWECiphertext, + a: &GLWE, ks: &GLWEToLWESwitchingKeyPrepared, scratch: &mut Scratch, ) where DGlwe: DataRef, DKs: DataRef, - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft @@ -83,7 +77,7 @@ impl LWECiphertext { + VecZnxBigNormalize + VecZnxNormalize + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeGLWECt + TakeVecZnx, + Scratch:, { #[cfg(debug_assertions)] { @@ -92,7 +86,7 @@ impl LWECiphertext { assert!(self.n() <= module.n() as u32); } - let glwe_layout: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_layout: GLWELayout = GLWELayout { n: module.n().into(), base2k: self.base2k(), k: self.k(), diff --git a/poulpy-core/src/conversion/lwe_to_glwe.rs b/poulpy-core/src/conversion/lwe_to_glwe.rs index d3ae616..c4a3b88 100644 --- a/poulpy-core/src/conversion/lwe_to_glwe.rs +++ b/poulpy-core/src/conversion/lwe_to_glwe.rs @@ -1,22 +1,16 @@ use poulpy_hal::{ api::{ - ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, + ScratchAvailable, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftApply, + VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx, ZnxView, ZnxViewMut, ZnxZero}, }; -use crate::{ - TakeGLWECt, - layouts::{ - GGLWEInfos, GLWECiphertext, GLWECiphertextLayout, GLWEInfos, LWECiphertext, LWEInfos, - prepared::LWEToGLWESwitchingKeyPrepared, - }, -}; +use crate::layouts::{GGLWEInfos, GLWE, GLWEInfos, GLWELayout, LWE, LWEInfos, prepared::LWEToGLWESwitchingKeyPrepared}; -impl GLWECiphertext> { - pub fn from_lwe_scratch_space( +impl GLWE> { + pub fn from_lwe_tmp_bytes( module: &Module, glwe_infos: &OUT, lwe_infos: &IN, @@ -26,35 +20,35 @@ impl GLWECiphertext> { OUT: GLWEInfos, IN: LWEInfos, KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - let ct: usize = GLWECiphertext::alloc_bytes_with( + let ct: usize = GLWE::bytes_of( module.n().into(), key_infos.base2k(), lwe_infos.k().max(glwe_infos.k()), 1u32.into(), ); - let ks: usize = GLWECiphertext::keyswitch_inplace_scratch_space(module, glwe_infos, key_infos); + let ks: usize = GLWE::keyswitch_inplace_tmp_bytes(module, glwe_infos, key_infos); if lwe_infos.base2k() == key_infos.base2k() { ct + ks } else { - let a_conv = VecZnx::alloc_bytes(module.n(), 1, lwe_infos.size()) + module.vec_znx_normalize_tmp_bytes(); + let a_conv = VecZnx::bytes_of(module.n(), 1, lwe_infos.size()) + module.vec_znx_normalize_tmp_bytes(); ct + a_conv + ks } } } -impl GLWECiphertext { +impl GLWE { pub fn from_lwe( &mut self, module: &Module, - lwe: &LWECiphertext, + lwe: &LWE, ksk: &LWEToGLWESwitchingKeyPrepared, scratch: &mut Scratch, ) where DLwe: DataRef, DKsk: DataRef, - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft @@ -65,7 +59,7 @@ impl GLWECiphertext { + VecZnxBigNormalize + VecZnxNormalize + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeGLWECt + TakeVecZnx, + Scratch: ScratchAvailable, { #[cfg(debug_assertions)] { @@ -74,7 +68,7 @@ impl GLWECiphertext { assert!(lwe.n() <= module.n() as u32); } - let (mut glwe, scratch_1) = scratch.take_glwe_ct(&GLWECiphertextLayout { + let (mut glwe, scratch_1) = scratch.take_glwe_ct(&GLWELayout { n: ksk.n(), base2k: ksk.base2k(), k: lwe.k(), diff --git a/poulpy-core/src/conversion/mod.rs b/poulpy-core/src/conversion/mod.rs index 090208b..9771531 100644 --- a/poulpy-core/src/conversion/mod.rs +++ b/poulpy-core/src/conversion/mod.rs @@ -1,2 +1,5 @@ +mod gglwe_to_ggsw; mod glwe_to_lwe; mod lwe_to_glwe; + +pub use gglwe_to_ggsw::*; diff --git a/poulpy-core/src/decryption/glwe_ct.rs b/poulpy-core/src/decryption/glwe_ct.rs index 19b4a82..4306d33 100644 --- a/poulpy-core/src/decryption/glwe_ct.rs +++ b/poulpy-core/src/decryption/glwe_ct.rs @@ -1,25 +1,25 @@ use poulpy_hal::{ api::{ - SvpApplyDftToDftInplace, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalizeTmpBytes, + SvpApplyDftToDftInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxDftApply, + VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalizeTmpBytes, }, layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch}, }; -use crate::layouts::{GLWECiphertext, GLWEInfos, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared}; +use crate::layouts::{GLWE, GLWEInfos, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared}; -impl GLWECiphertext> { - pub fn decrypt_scratch_space(module: &Module, infos: &A) -> usize +impl GLWE> { + pub fn decrypt_tmp_bytes(module: &Module, infos: &A) -> usize where A: GLWEInfos, - Module: VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes, + Module: VecZnxDftBytesOf + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf, { let size: usize = infos.size(); - (module.vec_znx_normalize_tmp_bytes() | module.vec_znx_dft_alloc_bytes(1, size)) + module.vec_znx_dft_alloc_bytes(1, size) + (module.vec_znx_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_dft(1, size) } } -impl GLWECiphertext { +impl GLWE { pub fn decrypt( &self, module: &Module, @@ -33,7 +33,7 @@ impl GLWECiphertext { + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxBigNormalize, - Scratch: TakeVecZnxDft + TakeVecZnxBig, + Scratch:, { #[cfg(debug_assertions)] { diff --git a/poulpy-core/src/decryption/lwe_ct.rs b/poulpy-core/src/decryption/lwe_ct.rs index 57abdc6..ade21e3 100644 --- a/poulpy-core/src/decryption/lwe_ct.rs +++ b/poulpy-core/src/decryption/lwe_ct.rs @@ -4,9 +4,9 @@ use poulpy_hal::{ oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl}, }; -use crate::layouts::{LWECiphertext, LWEInfos, LWEPlaintext, LWESecret}; +use crate::layouts::{LWE, LWEInfos, LWEPlaintext, LWESecret}; -impl LWECiphertext +impl LWE where DataSelf: DataRef, { diff --git a/poulpy-core/src/encryption/compressed/gglwe_atk.rs b/poulpy-core/src/encryption/compressed/gglwe_atk.rs index 95dcf20..f0afcae 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_atk.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_atk.rs @@ -1,35 +1,96 @@ use poulpy_hal::{ - api::{ - ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxBigNormalize, VecZnxDftAllocBytes, - VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, - }, + api::{ScratchAvailable, SvpPPolBytesOf, VecZnxAutomorphism, VecZnxDftBytesOf, VecZnxNormalizeTmpBytes}, layouts::{Backend, DataMut, DataRef, Module, Scratch}, source::Source, }; use crate::{ - TakeGLWESecret, TakeGLWESecretPrepared, + encryption::compressed::gglwe_ksk::GGLWEKeyCompressedEncryptSk, layouts::{ - GGLWEInfos, GLWEInfos, GLWESecret, LWEInfos, - compressed::{GGLWEAutomorphismKeyCompressed, GGLWESwitchingKeyCompressed}, + GGLWEInfos, GLWEInfos, GLWESecret, GLWESecretToRef, LWEInfos, + compressed::{AutomorphismKeyCompressed, AutomorphismKeyCompressedToMut, GLWESwitchingKeyCompressed}, }, }; -impl GGLWEAutomorphismKeyCompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize +impl AutomorphismKeyCompressed> { + pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGLWEInfos, - Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + SvpPPolAllocBytes, + Module: VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes + SvpPPolBytesOf, { assert_eq!(module.n() as u32, infos.n()); - GGLWESwitchingKeyCompressed::encrypt_sk_scratch_space(module, infos) - + GLWESecret::alloc_bytes_with(infos.n(), infos.rank_out()) + GLWESwitchingKeyCompressed::encrypt_sk_tmp_bytes(module, infos) + GLWESecret::bytes_of(infos.n(), infos.rank_out()) } } -impl GGLWEAutomorphismKeyCompressed { +pub trait GGLWEAutomorphismKeyCompressedEncryptSk { + fn gglwe_automorphism_key_compressed_encrypt_sk( + &self, + res: &mut R, + p: i64, + sk: &S, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: AutomorphismKeyCompressedToMut, + S: GLWESecretToRef; +} + +impl GGLWEAutomorphismKeyCompressedEncryptSk for Module +where + Module: GGLWEKeyCompressedEncryptSk + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + SvpPPolBytesOf + VecZnxAutomorphism, + Scratch: ScratchAvailable, +{ + fn gglwe_automorphism_key_compressed_encrypt_sk( + &self, + res: &mut R, + p: i64, + sk: &S, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: AutomorphismKeyCompressedToMut, + S: GLWESecretToRef, + { + let res: &mut AutomorphismKeyCompressed<&mut [u8]> = &mut res.to_mut(); + let sk: &GLWESecret<&[u8]> = &sk.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.n(), sk.n()); + assert_eq!(res.rank_out(), res.rank_in()); + assert_eq!(sk.rank(), res.rank_out()); + assert!( + scratch.available() >= AutomorphismKeyCompressed::encrypt_sk_tmp_bytes(self, res), + "scratch.available(): {} < AutomorphismKey::encrypt_sk_tmp_bytes: {}", + scratch.available(), + AutomorphismKeyCompressed::encrypt_sk_tmp_bytes(self, res) + ) + } + + let (mut sk_out, scratch_1) = scratch.take_glwe_secret(sk.n(), sk.rank()); + + { + (0..res.rank_out().into()).for_each(|i| { + self.vec_znx_automorphism( + self.galois_element_inv(p), + &mut sk_out.data.as_vec_znx_mut(), + i, + &sk.data.as_vec_znx(), + i, + ); + }); + } + + self.gglwe_key_compressed_encrypt_sk(&mut res.key, sk, &sk_out, seed_xa, source_xe, scratch_1); + + res.p = p; + } +} + +impl AutomorphismKeyCompressed { #[allow(clippy::too_many_arguments)] pub fn encrypt_sk( &mut self, @@ -40,56 +101,8 @@ impl GGLWEAutomorphismKeyCompressed { source_xe: &mut Source, scratch: &mut Scratch, ) where - Module: VecZnxAutomorphism - + SvpPrepare - + SvpPPolAllocBytes - + VecZnxSwitchRing - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + VecZnxAddScalarInplace, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared, + Module: GGLWEAutomorphismKeyCompressedEncryptSk, { - #[cfg(debug_assertions)] - { - assert_eq!(self.n(), sk.n()); - assert_eq!(self.rank_out(), self.rank_in()); - assert_eq!(sk.rank(), self.rank_out()); - assert!( - scratch.available() >= GGLWEAutomorphismKeyCompressed::encrypt_sk_scratch_space(module, self), - "scratch.available(): {} < AutomorphismKey::encrypt_sk_scratch_space: {}", - scratch.available(), - GGLWEAutomorphismKeyCompressed::encrypt_sk_scratch_space(module, self) - ) - } - - let (mut sk_out, scratch_1) = scratch.take_glwe_secret(sk.n(), sk.rank()); - - { - (0..self.rank_out().into()).for_each(|i| { - module.vec_znx_automorphism( - module.galois_element_inv(p), - &mut sk_out.data.as_vec_znx_mut(), - i, - &sk.data.as_vec_znx(), - i, - ); - }); - } - - self.key - .encrypt_sk(module, sk, &sk_out, seed_xa, source_xe, scratch_1); - - self.p = p; + module.gglwe_automorphism_key_compressed_encrypt_sk(self, p, sk, seed_xa, source_xe, scratch); } } diff --git a/poulpy-core/src/encryption/compressed/gglwe_ct.rs b/poulpy-core/src/encryption/compressed/gglwe_ct.rs index 76871da..b67dc88 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_ct.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_ct.rs @@ -1,30 +1,22 @@ use poulpy_hal::{ api::{ - ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, - VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, + ScratchAvailable, VecZnxAddScalarInplace, VecZnxDftBytesOf, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, + ZnNormalizeInplace, }, - layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, ZnxZero}, + layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxZero}, source::Source, }; use crate::{ - TakeGLWEPt, - encryption::{SIGMA, glwe_encrypt_sk_internal}, - layouts::{GGLWECiphertext, GGLWEInfos, LWEInfos, compressed::GGLWECiphertextCompressed, prepared::GLWESecretPrepared}, + encryption::{SIGMA, glwe_ct::GLWEEncryptSkInternal}, + layouts::{ + GGLWE, GGLWEInfos, LWEInfos, + compressed::{GGLWECompressed, GGLWECompressedToMut}, + prepared::{GLWESecretPrepared, GLWESecretPreparedToRef}, + }, }; -impl GGLWECiphertextCompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize - where - A: GGLWEInfos, - Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes, - { - GGLWECiphertext::encrypt_sk_scratch_space(module, infos) - } -} - -impl GGLWECiphertextCompressed { +impl GGLWECompressed { #[allow(clippy::too_many_arguments)] pub fn encrypt_sk( &mut self, @@ -35,83 +27,124 @@ impl GGLWECiphertextCompressed { source_xe: &mut Source, scratch: &mut Scratch, ) where - Module: VecZnxAddScalarInplace - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Module: GGLWECompressedEncryptSk, { + module.gglwe_compressed_encrypt_sk(self, pt, sk, seed, source_xe, scratch); + } +} + +impl GGLWECompressed> { + pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize + where + A: GGLWEInfos, + Module: VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes, + { + GGLWE::encrypt_sk_tmp_bytes(module, infos) + } +} + +pub trait GGLWECompressedEncryptSk { + fn gglwe_compressed_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + seed: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWECompressedToMut, + P: ScalarZnxToRef, + S: GLWESecretPreparedToRef; +} + +impl GGLWECompressedEncryptSk for Module +where + Module: GLWEEncryptSkInternal + + VecZnxNormalizeInplace + + VecZnxNormalizeTmpBytes + + VecZnxDftBytesOf + + VecZnxAddScalarInplace + + ZnNormalizeInplace, + Scratch: ScratchAvailable, +{ + fn gglwe_compressed_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + seed: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWECompressedToMut, + P: ScalarZnxToRef, + S: GLWESecretPreparedToRef, + { + let res: &mut GGLWECompressed<&mut [u8]> = &mut res.to_mut(); + let pt: &ScalarZnx<&[u8]> = &pt.to_ref(); + #[cfg(debug_assertions)] { use poulpy_hal::layouts::ZnxInfos; + let sk = &sk.to_ref(); assert_eq!( - self.rank_in(), + res.rank_in(), pt.cols() as u32, - "self.rank_in(): {} != pt.cols(): {}", - self.rank_in(), + "res.rank_in(): {} != pt.cols(): {}", + res.rank_in(), pt.cols() ); assert_eq!( - self.rank_out(), + res.rank_out(), sk.rank(), - "self.rank_out(): {} != sk.rank(): {}", - self.rank_out(), + "res.rank_out(): {} != sk.rank(): {}", + res.rank_out(), sk.rank() ); - assert_eq!(self.n(), sk.n()); + assert_eq!(res.n(), sk.n()); assert_eq!(pt.n() as u32, sk.n()); assert!( - scratch.available() >= GGLWECiphertextCompressed::encrypt_sk_scratch_space(module, self), - "scratch.available: {} < GGLWECiphertext::encrypt_sk_scratch_space: {}", + scratch.available() >= GGLWECompressed::encrypt_sk_tmp_bytes(self, res), + "scratch.available: {} < GGLWECiphertext::encrypt_sk_tmp_bytes: {}", scratch.available(), - GGLWECiphertextCompressed::encrypt_sk_scratch_space(module, self) + GGLWECompressed::encrypt_sk_tmp_bytes(self, res) ); assert!( - self.dnum().0 * self.dsize().0 * self.base2k().0 <= self.k().0, - "self.dnum() : {} * self.dsize() : {} * self.base2k() : {} = {} >= self.k() = {}", - self.dnum(), - self.dsize(), - self.base2k(), - self.dnum().0 * self.dsize().0 * self.base2k().0, - self.k() + res.dnum().0 * res.dsize().0 * res.base2k().0 <= res.k().0, + "res.dnum() : {} * res.dsize() : {} * res.base2k() : {} = {} >= res.k() = {}", + res.dnum(), + res.dsize(), + res.base2k(), + res.dnum().0 * res.dsize().0 * res.base2k().0, + res.k() ); } - let dnum: usize = self.dnum().into(); - let dsize: usize = self.dsize().into(); - let base2k: usize = self.base2k().into(); - let rank_in: usize = self.rank_in().into(); - let cols: usize = (self.rank_out() + 1).into(); + let dnum: usize = res.dnum().into(); + let dsize: usize = res.dsize().into(); + let base2k: usize = res.base2k().into(); + let rank_in: usize = res.rank_in().into(); + let cols: usize = (res.rank_out() + 1).into(); let mut source_xa = Source::new(seed); - let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(self); + let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(res); (0..rank_in).for_each(|col_i| { (0..dnum).for_each(|d_i| { // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt tmp_pt.data.zero(); // zeroes for next iteration - module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + d_i * dsize, pt, col_i); - module.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scrach_1); + self.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + d_i * dsize, pt, col_i); + self.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scrach_1); let (seed, mut source_xa_tmp) = source_xa.branch(); - self.seed[col_i * dnum + d_i] = seed; + res.seed[col_i * dnum + d_i] = seed; - glwe_encrypt_sk_internal( - module, - self.base2k().into(), - self.k().into(), - &mut self.at_mut(d_i, col_i).data, + self.glwe_encrypt_sk_internal( + res.base2k().into(), + res.k().into(), + &mut res.at_mut(d_i, col_i).data, cols, true, Some((&tmp_pt, 0)), diff --git a/poulpy-core/src/encryption/compressed/gglwe_ksk.rs b/poulpy-core/src/encryption/compressed/gglwe_ksk.rs index 8dd177f..93519b9 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_ksk.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_ksk.rs @@ -1,35 +1,31 @@ use poulpy_hal::{ - api::{ - ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, - VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, - VecZnxSubInplace, VecZnxSwitchRing, - }, + api::{ScratchAvailable, SvpPPolBytesOf, SvpPrepare, VecZnxDftBytesOf, VecZnxNormalizeTmpBytes, VecZnxSwitchRing}, layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch}, source::Source, }; use crate::{ - TakeGLWESecretPrepared, + encryption::compressed::gglwe_ct::GGLWECompressedEncryptSk, layouts::{ - Degree, GGLWECiphertext, GGLWEInfos, GLWEInfos, GLWESecret, LWEInfos, compressed::GGLWESwitchingKeyCompressed, + GGLWE, GGLWEInfos, GLWEInfos, GLWESecret, GLWESecretToRef, LWEInfos, RingDegree, + compressed::{GLWESwitchingKeyCompressed, GLWESwitchingKeyCompressedToMut}, prepared::GLWESecretPrepared, }, }; -impl GGLWESwitchingKeyCompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize +impl GLWESwitchingKeyCompressed> { + pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGLWEInfos, - Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + SvpPPolAllocBytes, + Module: VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes + SvpPPolBytesOf, { - (GGLWECiphertext::encrypt_sk_scratch_space(module, infos) | ScalarZnx::alloc_bytes(module.n(), 1)) - + ScalarZnx::alloc_bytes(module.n(), infos.rank_in().into()) - + GLWESecretPrepared::alloc_bytes_with(module, infos.rank_out()) + (GGLWE::encrypt_sk_tmp_bytes(module, infos) | ScalarZnx::bytes_of(module.n(), 1)) + + ScalarZnx::bytes_of(module.n(), infos.rank_in().into()) + + GLWESecretPrepared::bytes_of(module, infos.rank_out()) } } -impl GGLWESwitchingKeyCompressed { +impl GLWESwitchingKeyCompressed { #[allow(clippy::too_many_arguments)] pub fn encrypt_sk( &mut self, @@ -40,36 +36,65 @@ impl GGLWESwitchingKeyCompressed { source_xe: &mut Source, scratch: &mut Scratch, ) where - Module: SvpPrepare - + SvpPPolAllocBytes - + VecZnxSwitchRing - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + VecZnxAddScalarInplace, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared, + Module: GGLWEKeyCompressedEncryptSk, { + module.gglwe_key_compressed_encrypt_sk(self, sk_in, sk_out, seed_xa, source_xe, scratch); + } +} + +pub trait GGLWEKeyCompressedEncryptSk { + fn gglwe_key_compressed_encrypt_sk( + &self, + res: &mut R, + sk_in: &SI, + sk_out: &SO, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GLWESwitchingKeyCompressedToMut, + SI: GLWESecretToRef, + SO: GLWESecretToRef; +} + +impl GGLWEKeyCompressedEncryptSk for Module +where + Module: GGLWECompressedEncryptSk + + SvpPPolBytesOf + + VecZnxNormalizeTmpBytes + + VecZnxDftBytesOf + + VecZnxSwitchRing + + SvpPrepare, + Scratch: ScratchAvailable, +{ + fn gglwe_key_compressed_encrypt_sk( + &self, + res: &mut R, + sk_in: &SI, + sk_out: &SO, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GLWESwitchingKeyCompressedToMut, + SI: GLWESecretToRef, + SO: GLWESecretToRef, + { + let res: &mut GLWESwitchingKeyCompressed<&mut [u8]> = &mut res.to_mut(); + let sk_in: &GLWESecret<&[u8]> = &sk_in.to_ref(); + let sk_out: &GLWESecret<&[u8]> = &sk_out.to_ref(); + #[cfg(debug_assertions)] { - use crate::layouts::GGLWESwitchingKey; + use crate::layouts::GLWESwitchingKey; - assert!(sk_in.n().0 <= module.n() as u32); - assert!(sk_out.n().0 <= module.n() as u32); + assert!(sk_in.n().0 <= self.n() as u32); + assert!(sk_out.n().0 <= self.n() as u32); assert!( - scratch.available() >= GGLWESwitchingKey::encrypt_sk_scratch_space(module, self), - "scratch.available()={} < GLWESwitchingKey::encrypt_sk_scratch_space={}", + scratch.available() >= GLWESwitchingKey::encrypt_sk_tmp_bytes(self, res), + "scratch.available()={} < GLWESwitchingKey::encrypt_sk_tmp_bytes={}", scratch.available(), - GGLWESwitchingKey::encrypt_sk_scratch_space(module, self) + GLWESwitchingKey::encrypt_sk_tmp_bytes(self, res) ) } @@ -77,7 +102,7 @@ impl GGLWESwitchingKeyCompressed { let (mut sk_in_tmp, scratch_1) = scratch.take_scalar_znx(n, sk_in.rank().into()); (0..sk_in.rank().into()).for_each(|i| { - module.vec_znx_switch_ring( + self.vec_znx_switch_ring( &mut sk_in_tmp.as_vec_znx_mut(), i, &sk_in.data.as_vec_znx(), @@ -85,24 +110,24 @@ impl GGLWESwitchingKeyCompressed { ); }); - let (mut sk_out_tmp, scratch_2) = scratch_1.take_glwe_secret_prepared(Degree(n as u32), sk_out.rank()); + let (mut sk_out_tmp, scratch_2) = scratch_1.take_glwe_secret_prepared(RingDegree(n as u32), sk_out.rank()); { let (mut tmp, _) = scratch_2.take_scalar_znx(n, 1); (0..sk_out.rank().into()).for_each(|i| { - module.vec_znx_switch_ring(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i); - module.svp_prepare(&mut sk_out_tmp.data, i, &tmp, 0); + self.vec_znx_switch_ring(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i); + self.svp_prepare(&mut sk_out_tmp.data, i, &tmp, 0); }); } - self.key.encrypt_sk( - module, + self.gglwe_compressed_encrypt_sk( + &mut res.key, &sk_in_tmp, &sk_out_tmp, seed_xa, source_xe, scratch_2, ); - self.sk_in_n = sk_in.n().into(); - self.sk_out_n = sk_out.n().into(); + res.sk_in_n = sk_in.n().into(); + res.sk_out_n = sk_out.n().into(); } } diff --git a/poulpy-core/src/encryption/compressed/gglwe_tsk.rs b/poulpy-core/src/encryption/compressed/gglwe_tsk.rs index 6a75a57..2beaa4b 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_tsk.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_tsk.rs @@ -1,86 +1,83 @@ use poulpy_hal::{ api::{ - ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, - TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAllocBytes, - VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, - VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, + SvpApplyDftToDft, SvpPPolBytesOf, SvpPrepare, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, + VecZnxIdftApplyTmpA, VecZnxNormalizeTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch}, source::Source, }; use crate::{ - TakeGLWESecret, TakeGLWESecretPrepared, + encryption::compressed::gglwe_ksk::GGLWEKeyCompressedEncryptSk, layouts::{ - GGLWEInfos, GGLWETensorKey, GLWEInfos, GLWESecret, LWEInfos, Rank, compressed::GGLWETensorKeyCompressed, - prepared::Prepare, + GGLWEInfos, GLWEInfos, GLWESecret, GLWESecretToRef, LWEInfos, Rank, TensorKey, + compressed::{TensorKeyCompressed, TensorKeyCompressedToMut}, }, }; -impl GGLWETensorKeyCompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize +impl TensorKeyCompressed> { + pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGLWEInfos, - Module: - SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigAllocBytes, + Module: SvpPPolBytesOf + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes + VecZnxBigBytesOf, { - GGLWETensorKey::encrypt_sk_scratch_space(module, infos) + TensorKey::encrypt_sk_tmp_bytes(module, infos) } } -impl GGLWETensorKeyCompressed { - pub fn encrypt_sk( - &mut self, - module: &Module, - sk: &GLWESecret, +pub trait GGLWETensorKeyCompressedEncryptSk { + fn gglwe_tensor_key_encrypt_sk( + &self, + res: &mut R, + sk: &S, seed_xa: [u8; 32], source_xe: &mut Source, scratch: &mut Scratch, ) where - Module: SvpApplyDftToDft - + VecZnxIdftApplyTmpA - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + VecZnxSwitchRing - + VecZnxAddScalarInplace - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc, - Scratch: ScratchAvailable - + TakeScalarZnx - + TakeVecZnxDft - + TakeGLWESecretPrepared - + ScratchAvailable - + TakeVecZnx - + TakeVecZnxBig, + R: TensorKeyCompressedToMut, + S: GLWESecretToRef; +} + +impl GGLWETensorKeyCompressedEncryptSk for Module +where + Module: GGLWEKeyCompressedEncryptSk + + VecZnxDftApply + + SvpApplyDftToDft + + VecZnxIdftApplyTmpA + + VecZnxBigNormalize + + SvpPrepare, + Scratch:, +{ + fn gglwe_tensor_key_encrypt_sk( + &self, + res: &mut R, + sk: &S, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: TensorKeyCompressedToMut, + S: GLWESecretToRef, { + let res: &mut TensorKeyCompressed<&mut [u8]> = &mut res.to_mut(); + let sk: &GLWESecret<&[u8]> = &sk.to_ref(); + #[cfg(debug_assertions)] { - assert_eq!(self.rank_out(), sk.rank()); - assert_eq!(self.n(), sk.n()); + assert_eq!(res.rank_out(), sk.rank()); + assert_eq!(res.n(), sk.n()); } let n: usize = sk.n().into(); - let rank: usize = self.rank_out().into(); + let rank: usize = res.rank_out().into(); - let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(sk.n(), self.rank_out()); - sk_dft_prep.prepare(module, sk, scratch_1); + let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(sk.n(), res.rank_out()); + sk_dft_prep.prepare(self, sk, scratch_1); let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(n, rank, 1); for i in 0..rank { - module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); + self.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); } let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(n, 1, 1); @@ -91,14 +88,14 @@ impl GGLWETensorKeyCompressed { for i in 0..rank { for j in i..rank { - module.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i); + self.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i); - module.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0); - module.vec_znx_big_normalize( - self.base2k().into(), + self.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0); + self.vec_znx_big_normalize( + res.base2k().into(), &mut sk_ij.data.as_vec_znx_mut(), 0, - self.base2k().into(), + res.base2k().into(), &sk_ij_big, 0, scratch_5, @@ -106,9 +103,30 @@ impl GGLWETensorKeyCompressed { let (seed_xa_tmp, _) = source_xa.branch(); - self.at_mut(i, j) - .encrypt_sk(module, &sk_ij, sk, seed_xa_tmp, source_xe, scratch_5); + self.gglwe_key_compressed_encrypt_sk( + res.at_mut(i, j), + &sk_ij, + sk, + seed_xa_tmp, + source_xe, + scratch_5, + ); } } } } + +impl TensorKeyCompressed { + pub fn encrypt_sk( + &mut self, + module: &Module, + sk: &GLWESecret, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + Module: GGLWETensorKeyCompressedEncryptSk, + { + module.gglwe_tensor_key_encrypt_sk(self, sk, seed_xa, source_xe, scratch); + } +} diff --git a/poulpy-core/src/encryption/compressed/ggsw_ct.rs b/poulpy-core/src/encryption/compressed/ggsw_ct.rs index e49f246..567f04f 100644 --- a/poulpy-core/src/encryption/compressed/ggsw_ct.rs +++ b/poulpy-core/src/encryption/compressed/ggsw_ct.rs @@ -1,32 +1,118 @@ use poulpy_hal::{ - api::{ - ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, - VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, - }, - layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, ZnxZero}, + api::{VecZnxAddScalarInplace, VecZnxDftBytesOf, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes}, + layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxZero}, source::Source, }; use crate::{ - TakeGLWEPt, - encryption::{SIGMA, glwe_encrypt_sk_internal}, + encryption::{SIGMA, glwe_ct::GLWEEncryptSkInternal}, layouts::{ - GGSWCiphertext, GGSWInfos, GLWEInfos, LWEInfos, compressed::GGSWCiphertextCompressed, prepared::GLWESecretPrepared, + GGSW, GGSWInfos, GLWEInfos, LWEInfos, + compressed::{GGSWCompressed, GGSWCompressedToMut}, + prepared::{GLWESecretPrepared, GLWESecretPreparedToRef}, }, }; -impl GGSWCiphertextCompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize +impl GGSWCompressed> { + pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGSWInfos, - Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes, + Module: VecZnxNormalizeTmpBytes + VecZnxDftBytesOf, { - GGSWCiphertext::encrypt_sk_scratch_space(module, infos) + GGSW::encrypt_sk_tmp_bytes(module, infos) } } -impl GGSWCiphertextCompressed { +pub trait GGSWCompressedEncryptSk { + fn ggsw_compressed_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGSWCompressedToMut, + P: ScalarZnxToRef, + S: GLWESecretPreparedToRef; +} + +impl GGSWCompressedEncryptSk for Module +where + Module: GLWEEncryptSkInternal + VecZnxAddScalarInplace + VecZnxNormalizeInplace, + Scratch:, +{ + fn ggsw_compressed_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGSWCompressedToMut, + P: ScalarZnxToRef, + S: GLWESecretPreparedToRef, + { + let res: &mut GGSWCompressed<&mut [u8]> = &mut res.to_mut(); + let sk: &GLWESecretPrepared<&[u8], B> = &sk.to_ref(); + let pt: &ScalarZnx<&[u8]> = &pt.to_ref(); + + #[cfg(debug_assertions)] + { + use poulpy_hal::layouts::ZnxInfos; + + assert_eq!(res.rank(), sk.rank()); + assert_eq!(res.n(), sk.n()); + assert_eq!(pt.n() as u32, sk.n()); + } + + let base2k: usize = res.base2k().into(); + let rank: usize = res.rank().into(); + let cols: usize = rank + 1; + let dsize: usize = res.dsize().into(); + + let (mut tmp_pt, scratch_1) = scratch.take_glwe_pt(&res.glwe_layout()); + + let mut source = Source::new(seed_xa); + + res.seed = vec![[0u8; 32]; res.dnum().0 as usize * cols]; + + for row_i in 0..res.dnum().into() { + tmp_pt.data.zero(); + + // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt + self.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + row_i * dsize, pt, 0); + self.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scratch_1); + + for col_j in 0..rank + 1 { + // rlwe encrypt of vec_znx_pt into vec_znx_ct + + let (seed, mut source_xa_tmp) = source.branch(); + + res.seed[row_i * cols + col_j] = seed; + + self.glwe_encrypt_sk_internal( + res.base2k().into(), + res.k().into(), + &mut res.at_mut(row_i, col_j).data, + cols, + true, + Some((&tmp_pt, col_j)), + sk, + &mut source_xa_tmp, + source_xe, + SIGMA, + scratch_1, + ); + } + } + } +} + +impl GGSWCompressed { #[allow(clippy::too_many_arguments)] pub fn encrypt_sk( &mut self, @@ -37,71 +123,8 @@ impl GGSWCiphertextCompressed { source_xe: &mut Source, scratch: &mut Scratch, ) where - Module: VecZnxAddScalarInplace - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Module: GGSWCompressedEncryptSk, { - #[cfg(debug_assertions)] - { - use poulpy_hal::layouts::ZnxInfos; - - assert_eq!(self.rank(), sk.rank()); - assert_eq!(self.n(), sk.n()); - assert_eq!(pt.n() as u32, sk.n()); - } - - let base2k: usize = self.base2k().into(); - let rank: usize = self.rank().into(); - let cols: usize = rank + 1; - let dsize: usize = self.dsize().into(); - - let (mut tmp_pt, scratch_1) = scratch.take_glwe_pt(&self.glwe_layout()); - - let mut source = Source::new(seed_xa); - - self.seed = vec![[0u8; 32]; self.dnum().0 as usize * cols]; - - (0..self.dnum().into()).for_each(|row_i| { - tmp_pt.data.zero(); - - // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt - module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + row_i * dsize, pt, 0); - module.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scratch_1); - - (0..rank + 1).for_each(|col_j| { - // rlwe encrypt of vec_znx_pt into vec_znx_ct - - let (seed, mut source_xa_tmp) = source.branch(); - - self.seed[row_i * cols + col_j] = seed; - - glwe_encrypt_sk_internal( - module, - self.base2k().into(), - self.k().into(), - &mut self.at_mut(row_i, col_j).data, - cols, - true, - Some((&tmp_pt, col_j)), - sk, - &mut source_xa_tmp, - source_xe, - SIGMA, - scratch_1, - ); - }); - }); + module.ggsw_compressed_encrypt_sk(self, pt, sk, seed_xa, source_xe, scratch); } } diff --git a/poulpy-core/src/encryption/compressed/glwe_ct.rs b/poulpy-core/src/encryption/compressed/glwe_ct.rs index 834f968..f04a07a 100644 --- a/poulpy-core/src/encryption/compressed/glwe_ct.rs +++ b/poulpy-core/src/encryption/compressed/glwe_ct.rs @@ -1,31 +1,83 @@ use poulpy_hal::{ - api::{ - ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, - VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, - VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, - }, + api::{VecZnxDftBytesOf, VecZnxNormalizeTmpBytes}, layouts::{Backend, DataMut, DataRef, Module, Scratch}, source::Source, }; use crate::{ - encryption::{SIGMA, glwe_ct::glwe_encrypt_sk_internal}, + encryption::{SIGMA, glwe_ct::GLWEEncryptSkInternal}, layouts::{ - GLWECiphertext, GLWEInfos, GLWEPlaintext, LWEInfos, compressed::GLWECiphertextCompressed, prepared::GLWESecretPrepared, + GLWE, GLWEInfos, GLWEPlaintext, GLWEPlaintextToRef, LWEInfos, + compressed::{GLWECompressed, GLWECompressedToMut}, + prepared::{GLWESecretPrepared, GLWESecretPreparedToRef}, }, }; -impl GLWECiphertextCompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize +impl GLWECompressed> { + pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GLWEInfos, - Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes, + Module: VecZnxNormalizeTmpBytes + VecZnxDftBytesOf, { - GLWECiphertext::encrypt_sk_scratch_space(module, infos) + GLWE::encrypt_sk_tmp_bytes(module, infos) } } -impl GLWECiphertextCompressed { +pub trait GLWECompressedEncryptSk { + fn glwe_compressed_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GLWECompressedToMut, + P: GLWEPlaintextToRef, + S: GLWESecretPreparedToRef; +} + +impl GLWECompressedEncryptSk for Module +where + Module: GLWEEncryptSkInternal, +{ + fn glwe_compressed_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GLWECompressedToMut, + P: GLWEPlaintextToRef, + S: GLWESecretPreparedToRef, + { + let res: &mut GLWECompressed<&mut [u8]> = &mut res.to_mut(); + let mut source_xa: Source = Source::new(seed_xa); + let cols: usize = (res.rank() + 1).into(); + + self.glwe_encrypt_sk_internal( + res.base2k().into(), + res.k().into(), + &mut res.data, + cols, + true, + Some((pt, 0)), + sk, + &mut source_xa, + source_xe, + SIGMA, + scratch, + ); + + res.seed = seed_xa; + } +} + +impl GLWECompressed { #[allow(clippy::too_many_arguments)] pub fn encrypt_sk( &mut self, @@ -36,65 +88,8 @@ impl GLWECiphertextCompressed { source_xe: &mut Source, scratch: &mut Scratch, ) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Module: GLWECompressedEncryptSk, { - self.encrypt_sk_internal(module, Some((pt, 0)), sk, seed_xa, source_xe, scratch); - } - - #[allow(clippy::too_many_arguments)] - pub(crate) fn encrypt_sk_internal( - &mut self, - module: &Module, - pt: Option<(&GLWEPlaintext, usize)>, - sk: &GLWESecretPrepared, - seed_xa: [u8; 32], - source_xe: &mut Source, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, - { - let mut source_xa = Source::new(seed_xa); - let cols: usize = (self.rank() + 1).into(); - glwe_encrypt_sk_internal( - module, - self.base2k().into(), - self.k().into(), - &mut self.data, - cols, - true, - pt, - sk, - &mut source_xa, - source_xe, - SIGMA, - scratch, - ); - self.seed = seed_xa; + module.glwe_compressed_encrypt_sk(self, pt, sk, seed_xa, source_xe, scratch); } } diff --git a/poulpy-core/src/encryption/gglwe_atk.rs b/poulpy-core/src/encryption/gglwe_atk.rs index 6d45b37..6536c7e 100644 --- a/poulpy-core/src/encryption/gglwe_atk.rs +++ b/poulpy-core/src/encryption/gglwe_atk.rs @@ -1,34 +1,33 @@ use poulpy_hal::{ api::{ - ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxBigNormalize, VecZnxDftAllocBytes, - VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, + ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, + VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, + VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, + VecZnxSwitchRing, }, - layouts::{Backend, DataMut, DataRef, Module, Scratch}, + layouts::{Backend, DataMut, Module, Scratch}, source::Source, }; -use crate::{ - TakeGLWESecret, TakeGLWESecretPrepared, - layouts::{GGLWEAutomorphismKey, GGLWEInfos, GGLWESwitchingKey, GLWEInfos, GLWESecret, LWEInfos}, +use crate::layouts::{ + AutomorphismKey, AutomorphismKeyToMut, GGLWEInfos, GLWEInfos, GLWESecret, GLWESecretToRef, GLWESwitchingKey, LWEInfos, }; -impl GGLWEAutomorphismKey> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize +impl AutomorphismKey> { + pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGLWEInfos, - Module: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes, + Module: SvpPPolBytesOf + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes, { assert_eq!( infos.rank_in(), infos.rank_out(), "rank_in != rank_out is not supported for GGLWEAutomorphismKey" ); - GGLWESwitchingKey::encrypt_sk_scratch_space(module, infos) + GLWESecret::alloc_bytes(&infos.glwe_layout()) + GLWESwitchingKey::encrypt_sk_tmp_bytes(module, infos) + GLWESecret::bytes_of_from_infos(module, &infos.glwe_layout()) } - pub fn encrypt_pk_scratch_space(module: &Module, _infos: &A) -> usize + pub fn encrypt_pk_tmp_bytes(module: &Module, _infos: &A) -> usize where A: GGLWEInfos, { @@ -37,62 +36,102 @@ impl GGLWEAutomorphismKey> { _infos.rank_out(), "rank_in != rank_out is not supported for GGLWEAutomorphismKey" ); - GGLWESwitchingKey::encrypt_pk_scratch_space(module, _infos) + GLWESwitchingKey::encrypt_pk_tmp_bytes(module, _infos) } } -impl GGLWEAutomorphismKey { - #[allow(clippy::too_many_arguments)] - pub fn encrypt_sk( - &mut self, - module: &Module, +pub trait GGLWEAutomorphismKeyEncryptSk { + fn gglwe_automorphism_key_encrypt_sk( + &self, + res: &mut A, p: i64, - sk: &GLWESecret, + sk: &B, source_xa: &mut Source, source_xe: &mut Source, - scratch: &mut Scratch, + scratch: &mut Scratch, ) where - Module: VecZnxAddScalarInplace - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + VecZnxSwitchRing - + SvpPPolAllocBytes - + VecZnxAutomorphism, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared, + A: AutomorphismKeyToMut, + B: GLWESecretToRef; +} + +impl AutomorphismKey +where + Self: AutomorphismKeyToMut, +{ + pub fn encrypt_sk( + &mut self, + module: &Module, + p: i64, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + S: GLWESecretToRef, + Module: GGLWEAutomorphismKeyEncryptSk, { + module.gglwe_automorphism_key_encrypt_sk(self, p, sk, source_xa, source_xe, scratch); + } +} + +impl GGLWEAutomorphismKeyEncryptSk for Module +where + Module: VecZnxAddScalarInplace + + VecZnxDftBytesOf + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxNormalizeTmpBytes + + VecZnxFillUniform + + VecZnxSubInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub + + SvpPrepare + + VecZnxSwitchRing + + SvpPPolBytesOf + + VecZnxAutomorphism, + Scratch: ScratchAvailable, +{ + fn gglwe_automorphism_key_encrypt_sk( + &self, + res: &mut A, + p: i64, + sk: &B, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + A: AutomorphismKeyToMut, + B: GLWESecretToRef, + { + let res: &mut AutomorphismKey<&mut [u8]> = &mut res.to_mut(); + let sk: &GLWESecret<&[u8]> = &sk.to_ref(); + #[cfg(debug_assertions)] { use crate::layouts::{GLWEInfos, LWEInfos}; - assert_eq!(self.n(), sk.n()); - assert_eq!(self.rank_out(), self.rank_in()); - assert_eq!(sk.rank(), self.rank_out()); + assert_eq!(res.n(), sk.n()); + assert_eq!(res.rank_out(), res.rank_in()); + assert_eq!(sk.rank(), res.rank_out()); assert!( - scratch.available() >= GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, self), - "scratch.available(): {} < AutomorphismKey::encrypt_sk_scratch_space: {:?}", + scratch.available() >= AutomorphismKey::encrypt_sk_tmp_bytes(self, res), + "scratch.available(): {} < AutomorphismKey::encrypt_sk_tmp_bytes: {:?}", scratch.available(), - GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, self) + AutomorphismKey::encrypt_sk_tmp_bytes(self, res) ) } let (mut sk_out, scratch_1) = scratch.take_glwe_secret(sk.n(), sk.rank()); { - (0..self.rank_out().into()).for_each(|i| { - module.vec_znx_automorphism( - module.galois_element_inv(p), + (0..res.rank_out().into()).for_each(|i| { + self.vec_znx_automorphism( + self.galois_element_inv(p), &mut sk_out.data.as_vec_znx_mut(), i, &sk.data.as_vec_znx(), @@ -101,9 +140,9 @@ impl GGLWEAutomorphismKey { }); } - self.key - .encrypt_sk(module, sk, &sk_out, source_xa, source_xe, scratch_1); + res.key + .encrypt_sk(self, sk, &sk_out, source_xa, source_xe, scratch_1); - self.p = p; + res.p = p; } } diff --git a/poulpy-core/src/encryption/gglwe_ct.rs b/poulpy-core/src/encryption/gglwe_ct.rs index 51054cb..d333892 100644 --- a/poulpy-core/src/encryption/gglwe_ct.rs +++ b/poulpy-core/src/encryption/gglwe_ct.rs @@ -1,29 +1,28 @@ use poulpy_hal::{ - api::{ - ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, - VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, - }, - layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, ZnxZero}, + api::{ScratchAvailable, VecZnxAddScalarInplace, VecZnxDftBytesOf, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes}, + layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxZero}, source::Source, }; use crate::{ - TakeGLWEPt, - layouts::{GGLWECiphertext, GGLWEInfos, GLWECiphertext, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared}, + encryption::glwe_ct::GLWEEncryptSk, + layouts::{ + GGLWE, GGLWEInfos, GGLWEToMut, GLWE, GLWEPlaintext, LWEInfos, + prepared::{GLWESecretPrepared, GLWESecretPreparedToRef}, + }, }; -impl GGLWECiphertext> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize +impl GGLWE> { + pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGLWEInfos, - Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes, + Module: VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes, { - GLWECiphertext::encrypt_sk_scratch_space(module, &infos.glwe_layout()) - + (GLWEPlaintext::alloc_bytes(&infos.glwe_layout()) | module.vec_znx_normalize_tmp_bytes()) + GLWE::encrypt_sk_tmp_bytes(module, &infos.glwe_layout()) + + (GLWEPlaintext::bytes_of_from_infos(module, &infos.glwe_layout()) | module.vec_znx_normalize_tmp_bytes()) } - pub fn encrypt_pk_scratch_space(_module: &Module, _infos: &A) -> usize + pub fn encrypt_pk_tmp_bytes(_module: &Module, _infos: &A) -> usize where A: GGLWEInfos, { @@ -31,78 +30,88 @@ impl GGLWECiphertext> { } } -impl GGLWECiphertext { - #[allow(clippy::too_many_arguments)] - pub fn encrypt_sk( - &mut self, - module: &Module, - pt: &ScalarZnx, - sk: &GLWESecretPrepared, +pub trait GGLWEEncryptSk { + fn gglwe_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, source_xa: &mut Source, source_xe: &mut Source, scratch: &mut Scratch, ) where - Module: VecZnxAddScalarInplace - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + R: GGLWEToMut, + P: ScalarZnxToRef, + S: GLWESecretPreparedToRef; +} + +impl GGLWEEncryptSk for Module +where + Module: GLWEEncryptSk + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxAddScalarInplace + VecZnxNormalizeInplace, + Scratch: ScratchAvailable, +{ + fn gglwe_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWEToMut, + P: ScalarZnxToRef, + S: GLWESecretPreparedToRef, { + let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut(); + let pt: &ScalarZnx<&[u8]> = &pt.to_ref(); + #[cfg(debug_assertions)] { use poulpy_hal::layouts::ZnxInfos; + let sk: GLWESecretPrepared<&[u8], B> = sk.to_ref(); assert_eq!( - self.rank_in(), + res.rank_in(), pt.cols() as u32, - "self.rank_in(): {} != pt.cols(): {}", - self.rank_in(), + "res.rank_in(): {} != pt.cols(): {}", + res.rank_in(), pt.cols() ); assert_eq!( - self.rank_out(), + res.rank_out(), sk.rank(), - "self.rank_out(): {} != sk.rank(): {}", - self.rank_out(), + "res.rank_out(): {} != sk.rank(): {}", + res.rank_out(), sk.rank() ); - assert_eq!(self.n(), sk.n()); + assert_eq!(res.n(), sk.n()); assert_eq!(pt.n() as u32, sk.n()); assert!( - scratch.available() >= GGLWECiphertext::encrypt_sk_scratch_space(module, self), - "scratch.available: {} < GGLWECiphertext::encrypt_sk_scratch_space(module, self.rank()={}, self.size()={}): {}", + scratch.available() >= GGLWE::encrypt_sk_tmp_bytes(self, res), + "scratch.available: {} < GGLWECiphertext::encrypt_sk_tmp_bytes(self, res.rank()={}, res.size()={}): {}", scratch.available(), - self.rank_out(), - self.size(), - GGLWECiphertext::encrypt_sk_scratch_space(module, self) + res.rank_out(), + res.size(), + GGLWE::encrypt_sk_tmp_bytes(self, res) ); assert!( - self.dnum().0 * self.dsize().0 * self.base2k().0 <= self.k().0, - "self.dnum() : {} * self.dsize() : {} * self.base2k() : {} = {} >= self.k() = {}", - self.dnum(), - self.dsize(), - self.base2k(), - self.dnum().0 * self.dsize().0 * self.base2k().0, - self.k() + res.dnum().0 * res.dsize().0 * res.base2k().0 <= res.k().0, + "res.dnum() : {} * res.dsize() : {} * res.base2k() : {} = {} >= res.k() = {}", + res.dnum(), + res.dsize(), + res.base2k(), + res.dnum().0 * res.dsize().0 * res.base2k().0, + res.k() ); } - let dnum: usize = self.dnum().into(); - let dsize: usize = self.dsize().into(); - let base2k: usize = self.base2k().into(); - let rank_in: usize = self.rank_in().into(); + let dnum: usize = res.dnum().into(); + let dsize: usize = res.dsize().into(); + let base2k: usize = res.base2k().into(); + let rank_in: usize = res.rank_in().into(); - let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(self); + let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(res); // For each input column (i.e. rank) produces a GGLWE ciphertext of rank_out+1 columns // // Example for ksk rank 2 to rank 3: @@ -114,17 +123,39 @@ impl GGLWECiphertext { // // (-(a*s) + s0, a) // (-(b*s) + s1, b) - (0..rank_in).for_each(|col_i| { - (0..dnum).for_each(|row_i| { + + for col_i in 0..rank_in { + for row_i in 0..dnum { // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt tmp_pt.data.zero(); // zeroes for next iteration - module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + row_i * dsize, pt, col_i); - module.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scrach_1); - - // rlwe encrypt of vec_znx_pt into vec_znx_ct - self.at_mut(row_i, col_i) - .encrypt_sk(module, &tmp_pt, sk, source_xa, source_xe, scrach_1); - }); - }); + self.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + row_i * dsize, pt, col_i); + self.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scrach_1); + self.glwe_encrypt_sk( + &mut res.at_mut(row_i, col_i), + &tmp_pt, + sk, + source_xa, + source_xe, + scrach_1, + ); + } + } + } +} + +impl GGLWE { + #[allow(clippy::too_many_arguments)] + pub fn encrypt_sk( + &mut self, + module: &Module, + pt: &ScalarZnx, + sk: &GLWESecretPrepared, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + Module: GGLWEEncryptSk, + { + module.gglwe_encrypt_sk(self, pt, sk, source_xa, source_xe, scratch); } } diff --git a/poulpy-core/src/encryption/gglwe_ksk.rs b/poulpy-core/src/encryption/gglwe_ksk.rs index 0629bec..ef9b5bf 100644 --- a/poulpy-core/src/encryption/gglwe_ksk.rs +++ b/poulpy-core/src/encryption/gglwe_ksk.rs @@ -1,41 +1,37 @@ use poulpy_hal::{ api::{ - ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, - VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, - VecZnxSubInplace, VecZnxSwitchRing, + ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, + VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, + VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, }, layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch}, source::Source, }; -use crate::{ - TakeGLWESecretPrepared, - layouts::{ - Degree, GGLWECiphertext, GGLWEInfos, GGLWESwitchingKey, GLWEInfos, GLWESecret, LWEInfos, prepared::GLWESecretPrepared, - }, +use crate::layouts::{ + GGLWE, GGLWEInfos, GLWEInfos, GLWESecret, GLWESwitchingKey, LWEInfos, RingDegree, prepared::GLWESecretPrepared, }; -impl GGLWESwitchingKey> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize +impl GLWESwitchingKey> { + pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGLWEInfos, - Module: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes, + Module: SvpPPolBytesOf + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes, { - (GGLWECiphertext::encrypt_sk_scratch_space(module, infos) | ScalarZnx::alloc_bytes(module.n(), 1)) - + ScalarZnx::alloc_bytes(module.n(), infos.rank_in().into()) - + GLWESecretPrepared::alloc_bytes(module, &infos.glwe_layout()) + (GGLWE::encrypt_sk_tmp_bytes(module, infos) | ScalarZnx::bytes_of(module.n(), 1)) + + ScalarZnx::bytes_of(module.n(), infos.rank_in().into()) + + GLWESecretPrepared::bytes_of_from_infos(module, &infos.glwe_layout()) } - pub fn encrypt_pk_scratch_space(module: &Module, _infos: &A) -> usize + pub fn encrypt_pk_tmp_bytes(module: &Module, _infos: &A) -> usize where A: GGLWEInfos, { - GGLWECiphertext::encrypt_pk_scratch_space(module, _infos) + GGLWE::encrypt_pk_tmp_bytes(module, _infos) } } -impl GGLWESwitchingKey { +impl GLWESwitchingKey { #[allow(clippy::too_many_arguments)] pub fn encrypt_sk( &mut self, @@ -47,7 +43,7 @@ impl GGLWESwitchingKey { scratch: &mut Scratch, ) where Module: VecZnxAddScalarInplace - + VecZnxDftAllocBytes + + VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -62,18 +58,18 @@ impl GGLWESwitchingKey { + VecZnxSub + SvpPrepare + VecZnxSwitchRing - + SvpPPolAllocBytes, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared, + + SvpPPolBytesOf, + Scratch: ScratchAvailable, { #[cfg(debug_assertions)] { assert!(sk_in.n().0 <= module.n() as u32); assert!(sk_out.n().0 <= module.n() as u32); assert!( - scratch.available() >= GGLWESwitchingKey::encrypt_sk_scratch_space(module, self), - "scratch.available()={} < GLWESwitchingKey::encrypt_sk_scratch_space={}", + scratch.available() >= GLWESwitchingKey::encrypt_sk_tmp_bytes(module, self), + "scratch.available()={} < GLWESwitchingKey::encrypt_sk_tmp_bytes={}", scratch.available(), - GGLWESwitchingKey::encrypt_sk_scratch_space(module, self) + GLWESwitchingKey::encrypt_sk_tmp_bytes(module, self) ) } @@ -89,7 +85,7 @@ impl GGLWESwitchingKey { ); }); - let (mut sk_out_tmp, scratch_2) = scratch_1.take_glwe_secret_prepared(Degree(n as u32), sk_out.rank()); + let (mut sk_out_tmp, scratch_2) = scratch_1.take_glwe_secret_prepared(RingDegree(n as u32), sk_out.rank()); { let (mut tmp, _) = scratch_2.take_scalar_znx(n, 1); (0..sk_out.rank().into()).for_each(|i| { diff --git a/poulpy-core/src/encryption/gglwe_tsk.rs b/poulpy-core/src/encryption/gglwe_tsk.rs index 1946929..62d1a15 100644 --- a/poulpy-core/src/encryption/gglwe_tsk.rs +++ b/poulpy-core/src/encryption/gglwe_tsk.rs @@ -1,39 +1,34 @@ use poulpy_hal::{ api::{ - ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, - TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAllocBytes, - VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, - VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, + SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, + VecZnxAddScalarInplace, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, + VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, + VecZnxSubInplace, VecZnxSwitchRing, }, layouts::{Backend, DataMut, DataRef, Module, Scratch}, source::Source, }; -use crate::{ - TakeGLWESecret, TakeGLWESecretPrepared, - layouts::{ - Degree, GGLWEInfos, GGLWESwitchingKey, GGLWETensorKey, GLWEInfos, GLWESecret, LWEInfos, Rank, - prepared::{GLWESecretPrepared, Prepare}, - }, +use crate::layouts::{ + GGLWEInfos, GLWEInfos, GLWESecret, GLWESwitchingKey, LWEInfos, Rank, RingDegree, TensorKey, prepared::GLWESecretPrepared, }; -impl GGLWETensorKey> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize +impl TensorKey> { + pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGLWEInfos, - Module: - SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigAllocBytes, + Module: SvpPPolBytesOf + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes + VecZnxBigBytesOf, { - GLWESecretPrepared::alloc_bytes_with(module, infos.rank_out()) - + module.vec_znx_dft_alloc_bytes(infos.rank_out().into(), 1) - + module.vec_znx_big_alloc_bytes(1, 1) - + module.vec_znx_dft_alloc_bytes(1, 1) - + GLWESecret::alloc_bytes_with(Degree(module.n() as u32), Rank(1)) - + GGLWESwitchingKey::encrypt_sk_scratch_space(module, infos) + GLWESecretPrepared::bytes_of(module, infos.rank_out()) + + module.bytes_of_vec_znx_dft(infos.rank_out().into(), 1) + + module.bytes_of_vec_znx_big(1, 1) + + module.bytes_of_vec_znx_dft(1, 1) + + GLWESecret::bytes_of(RingDegree(module.n() as u32), Rank(1)) + + GLWESwitchingKey::encrypt_sk_tmp_bytes(module, infos) } } -impl GGLWETensorKey { +impl TensorKey { pub fn encrypt_sk( &mut self, module: &Module, @@ -45,7 +40,7 @@ impl GGLWETensorKey { Module: SvpApplyDftToDft + VecZnxIdftApplyTmpA + VecZnxAddScalarInplace - + VecZnxDftAllocBytes + + VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -60,9 +55,8 @@ impl GGLWETensorKey { + VecZnxSub + SvpPrepare + VecZnxSwitchRing - + SvpPPolAllocBytes, - Scratch: - TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared + TakeVecZnxBig, + + SvpPPolBytesOf, + Scratch:, { #[cfg(debug_assertions)] { @@ -70,7 +64,7 @@ impl GGLWETensorKey { assert_eq!(self.n(), sk.n()); } - let n: Degree = sk.n(); + let n: RingDegree = sk.n(); let rank: Rank = self.rank_out(); let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(n, rank); diff --git a/poulpy-core/src/encryption/ggsw_ct.rs b/poulpy-core/src/encryption/ggsw_ct.rs index 6195458..b044ae3 100644 --- a/poulpy-core/src/encryption/ggsw_ct.rs +++ b/poulpy-core/src/encryption/ggsw_ct.rs @@ -1,33 +1,112 @@ use poulpy_hal::{ - api::{ - ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, - VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, - }, - layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, VecZnx, ZnxZero}, + api::{VecZnxAddScalarInplace, VecZnxDftBytesOf, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes}, + layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, ZnxZero}, source::Source, }; use crate::{ - TakeGLWEPt, - layouts::{GGSWCiphertext, GGSWInfos, GLWECiphertext, GLWEInfos, LWEInfos, prepared::GLWESecretPrepared}, + SIGMA, + encryption::glwe_ct::GLWEEncryptSkInternal, + layouts::{ + GGSW, GGSWInfos, GGSWToMut, GLWE, GLWEInfos, LWEInfos, + prepared::{GLWESecretPrepared, GLWESecretPreparedToRef}, + }, }; -impl GGSWCiphertext> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize +impl GGSW> { + pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGSWInfos, - Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes, + Module: VecZnxNormalizeTmpBytes + VecZnxDftBytesOf, { let size = infos.size(); - GLWECiphertext::encrypt_sk_scratch_space(module, &infos.glwe_layout()) - + VecZnx::alloc_bytes(module.n(), (infos.rank() + 1).into(), size) - + VecZnx::alloc_bytes(module.n(), 1, size) - + module.vec_znx_dft_alloc_bytes((infos.rank() + 1).into(), size) + GLWE::encrypt_sk_tmp_bytes(module, &infos.glwe_layout()) + + VecZnx::bytes_of(module.n(), (infos.rank() + 1).into(), size) + + VecZnx::bytes_of(module.n(), 1, size) + + module.bytes_of_vec_znx_dft((infos.rank() + 1).into(), size) } } -impl GGSWCiphertext { +pub trait GGSWEncryptSk { + fn ggsw_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGSWToMut, + P: ScalarZnxToRef, + S: GLWESecretPreparedToRef; +} + +impl GGSWEncryptSk for Module +where + Module: GLWEEncryptSkInternal + VecZnxAddScalarInplace + VecZnxNormalizeInplace, + Scratch:, +{ + fn ggsw_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGSWToMut, + P: ScalarZnxToRef, + S: GLWESecretPreparedToRef, + { + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + let pt: &ScalarZnx<&[u8]> = &pt.to_ref(); + let sk: &GLWESecretPrepared<&[u8], B> = &sk.to_ref(); + + #[cfg(debug_assertions)] + { + use poulpy_hal::layouts::ZnxInfos; + + assert_eq!(res.rank(), sk.rank()); + assert_eq!(res.n(), self.n() as u32); + assert_eq!(pt.n(), self.n()); + assert_eq!(sk.n(), self.n() as u32); + } + + let k: usize = res.k().into(); + let base2k: usize = res.base2k().into(); + let rank: usize = res.rank().into(); + let dsize: usize = res.dsize().into(); + let cols: usize = (rank + 1).into(); + + let (mut tmp_pt, scratch_1) = scratch.take_glwe_pt(&res.glwe_layout()); + + for row_i in 0..res.dnum().into() { + tmp_pt.data.zero(); + // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt + self.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + row_i * dsize, pt, 0); + self.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scratch_1); + for col_j in 0..rank + 1 { + self.glwe_encrypt_sk_internal( + base2k, + k, + res.at_mut(row_i, col_j).data_mut(), + cols, + false, + Some((&tmp_pt, col_j)), + sk, + source_xa, + source_xe, + SIGMA, + scratch_1, + ); + } + } + } +} + +impl GGSW { #[allow(clippy::too_many_arguments)] pub fn encrypt_sk( &mut self, @@ -38,56 +117,8 @@ impl GGSWCiphertext { source_xe: &mut Source, scratch: &mut Scratch, ) where - Module: VecZnxAddScalarInplace - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Module: GGSWEncryptSk, { - #[cfg(debug_assertions)] - { - use poulpy_hal::layouts::ZnxInfos; - - assert_eq!(self.rank(), sk.rank()); - assert_eq!(self.n(), sk.n()); - assert_eq!(pt.n() as u32, sk.n()); - } - - let base2k: usize = self.base2k().into(); - let rank: usize = self.rank().into(); - let dsize: usize = self.dsize().into(); - - let (mut tmp_pt, scratch_1) = scratch.take_glwe_pt(&self.glwe_layout()); - - (0..self.dnum().into()).for_each(|row_i| { - tmp_pt.data.zero(); - - // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt - module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + row_i * dsize, pt, 0); - module.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scratch_1); - - (0..rank + 1).for_each(|col_j| { - // rlwe encrypt of vec_znx_pt into vec_znx_ct - - self.at_mut(row_i, col_j).encrypt_sk_internal( - module, - Some((&tmp_pt, col_j)), - sk, - source_xa, - source_xe, - scratch_1, - ); - }); - }); + module.ggsw_encrypt_sk(self, pt, sk, source_xa, source_xe, scratch); } } diff --git a/poulpy-core/src/encryption/glwe_ct.rs b/poulpy-core/src/encryption/glwe_ct.rs index 8ecacc6..16bbadf 100644 --- a/poulpy-core/src/encryption/glwe_ct.rs +++ b/poulpy-core/src/encryption/glwe_ct.rs @@ -1,11 +1,11 @@ use poulpy_hal::{ api::{ - ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeSvpPPol, - TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxBigAddNormal, VecZnxBigAddSmallInplace, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, - VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, + ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, + VecZnxAddNormal, VecZnxBigAddNormal, VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxDftApply, + VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, + VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, }, - layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, VecZnx, VecZnxBig, ZnxInfos, ZnxZero}, + layouts::{Backend, DataMut, Module, ScalarZnx, Scratch, VecZnx, VecZnxBig, VecZnxToMut, ZnxInfos, ZnxZero}, source::Source, }; @@ -13,157 +13,155 @@ use crate::{ dist::Distribution, encryption::{SIGMA, SIGMA_BOUND}, layouts::{ - GLWECiphertext, GLWEInfos, GLWEPlaintext, LWEInfos, - prepared::{GLWEPublicKeyPrepared, GLWESecretPrepared}, + GLWE, GLWEInfos, GLWEPlaintext, GLWEPlaintextToRef, GLWEToMut, LWEInfos, + prepared::{GLWEPublicKeyPrepared, GLWEPublicKeyPreparedToRef, GLWESecretPrepared, GLWESecretPreparedToRef}, }, }; -impl GLWECiphertext> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize +impl GLWE> { + pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GLWEInfos, - Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes, + Module: VecZnxNormalizeTmpBytes + VecZnxDftBytesOf, { let size: usize = infos.size(); assert_eq!(module.n() as u32, infos.n()); - module.vec_znx_normalize_tmp_bytes() - + 2 * VecZnx::alloc_bytes(module.n(), 1, size) - + module.vec_znx_dft_alloc_bytes(1, size) + module.vec_znx_normalize_tmp_bytes() + 2 * VecZnx::bytes_of(module.n(), 1, size) + module.bytes_of_vec_znx_dft(1, size) } - pub fn encrypt_pk_scratch_space(module: &Module, infos: &A) -> usize + pub fn encrypt_pk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GLWEInfos, - Module: VecZnxDftAllocBytes + SvpPPolAllocBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes, + Module: VecZnxDftBytesOf + SvpPPolBytesOf + VecZnxBigBytesOf + VecZnxNormalizeTmpBytes, { let size: usize = infos.size(); assert_eq!(module.n() as u32, infos.n()); - ((module.vec_znx_dft_alloc_bytes(1, size) + module.vec_znx_big_alloc_bytes(1, size)) - | ScalarZnx::alloc_bytes(module.n(), 1)) - + module.svp_ppol_alloc_bytes(1) + ((module.bytes_of_vec_znx_dft(1, size) + module.bytes_of_vec_znx_big(1, size)) | ScalarZnx::bytes_of(module.n(), 1)) + + module.bytes_of_svp_ppol(1) + module.vec_znx_normalize_tmp_bytes() } } -impl GLWECiphertext { - #[allow(clippy::too_many_arguments)] - pub fn encrypt_sk( +impl GLWE { + pub fn encrypt_sk( &mut self, module: &Module, - pt: &GLWEPlaintext, - sk: &GLWESecretPrepared, + pt: &P, + sk: &S, source_xa: &mut Source, source_xe: &mut Source, scratch: &mut Scratch, ) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + P: GLWEPlaintextToRef, + S: GLWESecretPreparedToRef, + Module: GLWEEncryptSk, { + module.glwe_encrypt_sk(self, pt, sk, source_xa, source_xe, scratch); + } + + pub fn encrypt_zero_sk( + &mut self, + module: &Module, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + S: GLWESecretPreparedToRef, + Module: GLWEEncryptZeroSk, + { + module.glwe_encrypt_zero_sk(self, sk, source_xa, source_xe, scratch); + } + + pub fn encrypt_pk( + &mut self, + module: &Module, + pt: &P, + pk: &K, + source_xu: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + P: GLWEPlaintextToRef, + K: GLWEPublicKeyPreparedToRef, + Module: GLWEEncryptPk, + { + module.glwe_encrypt_pk(self, pt, pk, source_xu, source_xe, scratch); + } + + pub fn encrypt_zero_pk( + &mut self, + module: &Module, + pk: &K, + source_xu: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + K: GLWEPublicKeyPreparedToRef, + Module: GLWEEncryptZeroPk, + { + module.glwe_encrypt_zero_pk(self, pk, source_xu, source_xe, scratch); + } +} + +pub trait GLWEEncryptSk { + fn glwe_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GLWEToMut, + P: GLWEPlaintextToRef, + S: GLWESecretPreparedToRef; +} + +impl GLWEEncryptSk for Module +where + Module: GLWEEncryptSkInternal + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf, + Scratch: ScratchAvailable, +{ + fn glwe_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GLWEToMut, + P: GLWEPlaintextToRef, + S: GLWESecretPreparedToRef, + { + let mut res: GLWE<&mut [u8]> = res.to_mut(); + let pt: GLWEPlaintext<&[u8]> = pt.to_ref(); + #[cfg(debug_assertions)] { - assert_eq!(self.rank(), sk.rank()); - assert_eq!(sk.n(), self.n()); - assert_eq!(pt.n(), self.n()); + let sk: GLWESecretPrepared<&[u8], B> = sk.to_ref(); + assert_eq!(res.rank(), sk.rank()); + assert_eq!(res.n(), self.n() as u32); + assert_eq!(sk.n(), self.n() as u32); + assert_eq!(pt.n(), self.n() as u32); assert!( - scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self), - "scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}", + scratch.available() >= GLWE::encrypt_sk_tmp_bytes(self, &res), + "scratch.available(): {} < GLWECiphertext::encrypt_sk_tmp_bytes: {}", scratch.available(), - GLWECiphertext::encrypt_sk_scratch_space(module, self) + GLWE::encrypt_sk_tmp_bytes(self, &res) ) } - self.encrypt_sk_internal(module, Some((pt, 0)), sk, source_xa, source_xe, scratch); - } - - pub fn encrypt_zero_sk( - &mut self, - module: &Module, - sk: &GLWESecretPrepared, - source_xa: &mut Source, - source_xe: &mut Source, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, - { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank(), sk.rank()); - assert_eq!(sk.n(), self.n()); - assert!( - scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self), - "scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}", - scratch.available(), - GLWECiphertext::encrypt_sk_scratch_space(module, self) - ) - } - self.encrypt_sk_internal( - module, - None::<(&GLWEPlaintext>, usize)>, - sk, - source_xa, - source_xe, - scratch, - ); - } - - #[allow(clippy::too_many_arguments)] - pub(crate) fn encrypt_sk_internal( - &mut self, - module: &Module, - pt: Option<(&GLWEPlaintext, usize)>, - sk: &GLWESecretPrepared, - source_xa: &mut Source, - source_xe: &mut Source, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, - { - let cols: usize = (self.rank() + 1).into(); - glwe_encrypt_sk_internal( - module, - self.base2k().into(), - self.k().into(), - &mut self.data, + let cols: usize = (res.rank() + 1).into(); + self.glwe_encrypt_sk_internal( + res.base2k().into(), + res.k().into(), + res.data_mut(), cols, false, - pt, + Some((&pt, 0)), sk, source_xa, source_xe, @@ -171,46 +169,136 @@ impl GLWECiphertext { scratch, ); } +} - #[allow(clippy::too_many_arguments)] - pub fn encrypt_pk( - &mut self, - module: &Module, - pt: &GLWEPlaintext, - pk: &GLWEPublicKeyPrepared, - source_xu: &mut Source, +pub trait GLWEEncryptZeroSk { + fn glwe_encrypt_zero_sk( + &self, + res: &mut R, + sk: &S, + source_xa: &mut Source, source_xe: &mut Source, scratch: &mut Scratch, ) where - Module: SvpPrepare - + SvpApplyDftToDft - + VecZnxIdftApplyConsume - + VecZnxBigAddNormal - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize, - Scratch: TakeSvpPPol + TakeScalarZnx + TakeVecZnxDft, + R: GLWEToMut, + S: GLWESecretPreparedToRef; +} + +impl GLWEEncryptZeroSk for Module +where + Module: GLWEEncryptSkInternal + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf, + Scratch: ScratchAvailable, +{ + fn glwe_encrypt_zero_sk( + &self, + res: &mut R, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GLWEToMut, + S: GLWESecretPreparedToRef, { - self.encrypt_pk_internal::(module, Some((pt, 0)), pk, source_xu, source_xe, scratch); + let mut res: GLWE<&mut [u8]> = res.to_mut(); + + #[cfg(debug_assertions)] + { + let sk: GLWESecretPrepared<&[u8], B> = sk.to_ref(); + assert_eq!(res.rank(), sk.rank()); + assert_eq!(res.n(), self.n() as u32); + assert_eq!(sk.n(), self.n() as u32); + assert!( + scratch.available() >= GLWE::encrypt_sk_tmp_bytes(self, &res), + "scratch.available(): {} < GLWECiphertext::encrypt_sk_tmp_bytes: {}", + scratch.available(), + GLWE::encrypt_sk_tmp_bytes(self, &res) + ) + } + + let cols: usize = (res.rank() + 1).into(); + self.glwe_encrypt_sk_internal( + res.base2k().into(), + res.k().into(), + res.data_mut(), + cols, + false, + None::<(&GLWEPlaintext>, usize)>, + sk, + source_xa, + source_xe, + SIGMA, + scratch, + ); } +} - pub fn encrypt_zero_pk( - &mut self, - module: &Module, - pk: &GLWEPublicKeyPrepared, +pub trait GLWEEncryptPk { + fn glwe_encrypt_pk( + &self, + res: &mut R, + pt: &P, + pk: &K, source_xu: &mut Source, source_xe: &mut Source, scratch: &mut Scratch, ) where - Module: SvpPrepare - + SvpApplyDftToDft - + VecZnxIdftApplyConsume - + VecZnxBigAddNormal - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize, - Scratch: TakeSvpPPol + TakeScalarZnx + TakeVecZnxDft, + R: GLWEToMut, + P: GLWEPlaintextToRef, + K: GLWEPublicKeyPreparedToRef; +} + +impl GLWEEncryptPk for Module +where + Module: GLWEEncryptPkInternal, +{ + fn glwe_encrypt_pk( + &self, + res: &mut R, + pt: &P, + pk: &K, + source_xu: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GLWEToMut, + P: GLWEPlaintextToRef, + K: GLWEPublicKeyPreparedToRef, { - self.encrypt_pk_internal::, DataPk, B>( - module, + self.glwe_encrypt_pk_internal(res, Some((pt, 0)), pk, source_xu, source_xe, scratch); + } +} + +pub trait GLWEEncryptZeroPk { + fn glwe_encrypt_zero_pk( + &self, + res: &mut R, + pk: &K, + source_xu: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GLWEToMut, + K: GLWEPublicKeyPreparedToRef; +} + +impl GLWEEncryptZeroPk for Module +where + Module: GLWEEncryptPkInternal, +{ + fn glwe_encrypt_zero_pk( + &self, + res: &mut R, + pk: &K, + source_xu: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GLWEToMut, + K: GLWEPublicKeyPreparedToRef, + { + self.glwe_encrypt_pk_internal( + res, None::<(&GLWEPlaintext>, usize)>, pk, source_xu, @@ -218,45 +306,69 @@ impl GLWECiphertext { scratch, ); } +} - #[allow(clippy::too_many_arguments)] - pub(crate) fn encrypt_pk_internal( - &mut self, - module: &Module, - pt: Option<(&GLWEPlaintext, usize)>, - pk: &GLWEPublicKeyPrepared, +pub(crate) trait GLWEEncryptPkInternal { + fn glwe_encrypt_pk_internal( + &self, + res: &mut R, + pt: Option<(&P, usize)>, + pk: &K, source_xu: &mut Source, source_xe: &mut Source, scratch: &mut Scratch, ) where - Module: SvpPrepare - + SvpApplyDftToDft - + VecZnxIdftApplyConsume - + VecZnxBigAddNormal - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize, - Scratch: TakeSvpPPol + TakeScalarZnx + TakeVecZnxDft, + R: GLWEToMut, + P: GLWEPlaintextToRef, + K: GLWEPublicKeyPreparedToRef; +} + +impl GLWEEncryptPkInternal for Module +where + Module: SvpPrepare + + SvpApplyDftToDft + + VecZnxIdftApplyConsume + + VecZnxBigAddNormal + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize, + Scratch:, +{ + fn glwe_encrypt_pk_internal( + &self, + res: &mut R, + pt: Option<(&P, usize)>, + pk: &K, + source_xu: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GLWEToMut, + P: GLWEPlaintextToRef, + K: GLWEPublicKeyPreparedToRef, { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let pk: &GLWEPublicKeyPrepared<&[u8], B> = &pk.to_ref(); + #[cfg(debug_assertions)] { - assert_eq!(self.base2k(), pk.base2k()); - assert_eq!(self.n(), pk.n()); - assert_eq!(self.rank(), pk.rank()); + assert_eq!(res.base2k(), pk.base2k()); + assert_eq!(res.n(), pk.n()); + assert_eq!(res.rank(), pk.rank()); if let Some((pt, _)) = pt { - assert_eq!(pt.base2k(), pk.base2k()); - assert_eq!(pt.n(), pk.n()); + assert_eq!(pt.to_ref().base2k(), pk.base2k()); + assert_eq!(pt.to_ref().n(), pk.n()); } } let base2k: usize = pk.base2k().into(); let size_pk: usize = pk.size(); - let cols: usize = (self.rank() + 1).into(); + let cols: usize = (res.rank() + 1).into(); // Generates u according to the underlying secret distribution. - let (mut u_dft, scratch_1) = scratch.take_svp_ppol(self.n().into(), 1); + let (mut u_dft, scratch_1) = scratch.take_svp_ppol(res.n().into(), 1); { - let (mut u, _) = scratch_1.take_scalar_znx(self.n().into(), 1); + let (mut u, _) = scratch_1.take_scalar_znx(res.n().into(), 1); match pk.dist { Distribution::NONE => panic!( "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through \ @@ -270,20 +382,20 @@ impl GLWECiphertext { Distribution::ZERO => {} } - module.svp_prepare(&mut u_dft, 0, &u, 0); + self.svp_prepare(&mut u_dft, 0, &u, 0); } // ct[i] = pk[i] * u + ei (+ m if col = i) (0..cols).for_each(|i| { - let (mut ci_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n().into(), 1, size_pk); + let (mut ci_dft, scratch_2) = scratch_1.take_vec_znx_dft(res.n().into(), 1, size_pk); // ci_dft = DFT(u) * DFT(pk[i]) - module.svp_apply_dft_to_dft(&mut ci_dft, 0, &u_dft, 0, &pk.data, i); + self.svp_apply_dft_to_dft(&mut ci_dft, 0, &u_dft, 0, &pk.data, i); // ci_big = u * p[i] - let mut ci_big = module.vec_znx_idft_apply_consume(ci_dft); + let mut ci_big = self.vec_znx_idft_apply_consume(ci_dft); // ci_big = u * pk[i] + e - module.vec_znx_big_add_normal( + self.vec_znx_big_add_normal( base2k, &mut ci_big, 0, @@ -297,31 +409,38 @@ impl GLWECiphertext { if let Some((pt, col)) = pt && col == i { - module.vec_znx_big_add_small_inplace(&mut ci_big, 0, &pt.data, 0); + self.vec_znx_big_add_small_inplace(&mut ci_big, 0, &pt.to_ref().data, 0); } // ct[i] = norm(ci_big) - module.vec_znx_big_normalize(base2k, &mut self.data, i, base2k, &ci_big, 0, scratch_2); + self.vec_znx_big_normalize(base2k, &mut res.data, i, base2k, &ci_big, 0, scratch_2); }); } } -#[allow(clippy::too_many_arguments)] -pub(crate) fn glwe_encrypt_sk_internal( - module: &Module, - base2k: usize, - k: usize, - ct: &mut VecZnx, - cols: usize, - compressed: bool, - pt: Option<(&GLWEPlaintext, usize)>, - sk: &GLWESecretPrepared, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - scratch: &mut Scratch, -) where - Module: VecZnxDftAllocBytes +pub(crate) trait GLWEEncryptSkInternal { + fn glwe_encrypt_sk_internal( + &self, + base2k: usize, + k: usize, + res: &mut R, + cols: usize, + compressed: bool, + pt: Option<(&P, usize)>, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + P: GLWEPlaintextToRef, + S: GLWESecretPreparedToRef; +} + +impl GLWEEncryptSkInternal for Module +where + Module: VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -334,74 +453,96 @@ pub(crate) fn glwe_encrypt_sk_internal + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Scratch: ScratchAvailable, { - #[cfg(debug_assertions)] + fn glwe_encrypt_sk_internal( + &self, + base2k: usize, + k: usize, + res: &mut R, + cols: usize, + compressed: bool, + pt: Option<(&P, usize)>, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + P: GLWEPlaintextToRef, + S: GLWESecretPreparedToRef, { - if compressed { - assert_eq!( - ct.cols(), - 1, - "invalid ciphertext: compressed tag=true but #cols={} != 1", - ct.cols() - ) - } - } + let ct: &mut VecZnx<&mut [u8]> = &mut res.to_mut(); + let sk: GLWESecretPrepared<&[u8], B> = sk.to_ref(); - let size: usize = ct.size(); - - let (mut c0, scratch_1) = scratch.take_vec_znx(ct.n(), 1, size); - c0.zero(); - - { - let (mut ci, scratch_2) = scratch_1.take_vec_znx(ct.n(), 1, size); - - // ct[i] = uniform - // ct[0] -= c[i] * s[i], - (1..cols).for_each(|i| { - let col_ct: usize = if compressed { 0 } else { i }; - - // ct[i] = uniform (+ pt) - module.vec_znx_fill_uniform(base2k, ct, col_ct, source_xa); - - let (mut ci_dft, scratch_3) = scratch_2.take_vec_znx_dft(ct.n(), 1, size); - - // ci = ct[i] - pt - // i.e. we act as we sample ct[i] already as uniform + pt - // and if there is a pt, then we subtract it before applying DFT - if let Some((pt, col)) = pt { - if i == col { - module.vec_znx_sub(&mut ci, 0, ct, col_ct, &pt.data, 0); - module.vec_znx_normalize_inplace(base2k, &mut ci, 0, scratch_3); - module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, &ci, 0); - } else { - module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, ct, col_ct); - } - } else { - module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, ct, col_ct); + #[cfg(debug_assertions)] + { + if compressed { + assert_eq!( + ct.cols(), + 1, + "invalid ciphertext: compressed tag=true but #cols={} != 1", + ct.cols() + ) } + } - module.svp_apply_dft_to_dft_inplace(&mut ci_dft, 0, &sk.data, i - 1); - let ci_big: VecZnxBig<&mut [u8], B> = module.vec_znx_idft_apply_consume(ci_dft); + let size: usize = ct.size(); - // use c[0] as buffer, which is overwritten later by the normalization step - module.vec_znx_big_normalize(base2k, &mut ci, 0, base2k, &ci_big, 0, scratch_3); + let (mut c0, scratch_1) = scratch.take_vec_znx(ct.n(), 1, size); + c0.zero(); - // c0_tmp = -c[i] * s[i] (use c[0] as buffer) - module.vec_znx_sub_inplace(&mut c0, 0, &ci, 0); - }); + { + let (mut ci, scratch_2) = scratch_1.take_vec_znx(ct.n(), 1, size); + + // ct[i] = uniform + // ct[0] -= c[i] * s[i], + (1..cols).for_each(|i| { + let col_ct: usize = if compressed { 0 } else { i }; + + // ct[i] = uniform (+ pt) + self.vec_znx_fill_uniform(base2k, ct, col_ct, source_xa); + + let (mut ci_dft, scratch_3) = scratch_2.take_vec_znx_dft(ct.n(), 1, size); + + // ci = ct[i] - pt + // i.e. we act as we sample ct[i] already as uniform + pt + // and if there is a pt, then we subtract it before applying DFT + if let Some((pt, col)) = pt { + if i == col { + self.vec_znx_sub(&mut ci, 0, ct, col_ct, &pt.to_ref().data, 0); + self.vec_znx_normalize_inplace(base2k, &mut ci, 0, scratch_3); + self.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, &ci, 0); + } else { + self.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, ct, col_ct); + } + } else { + self.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, ct, col_ct); + } + + self.svp_apply_dft_to_dft_inplace(&mut ci_dft, 0, &sk.data, i - 1); + let ci_big: VecZnxBig<&mut [u8], B> = self.vec_znx_idft_apply_consume(ci_dft); + + // use c[0] as buffer, which is overwritten later by the normalization step + self.vec_znx_big_normalize(base2k, &mut ci, 0, base2k, &ci_big, 0, scratch_3); + + // c0_tmp = -c[i] * s[i] (use c[0] as buffer) + self.vec_znx_sub_inplace(&mut c0, 0, &ci, 0); + }); + } + + // c[0] += e + self.vec_znx_add_normal(base2k, &mut c0, 0, k, source_xe, sigma, SIGMA_BOUND); + + // c[0] += m if col = 0 + if let Some((pt, col)) = pt + && col == 0 + { + self.vec_znx_add_inplace(&mut c0, 0, &pt.to_ref().data, 0); + } + + // c[0] = norm(c[0]) + self.vec_znx_normalize(base2k, ct, 0, base2k, &c0, 0, scratch_1); } - - // c[0] += e - module.vec_znx_add_normal(base2k, &mut c0, 0, k, source_xe, sigma, SIGMA_BOUND); - - // c[0] += m if col = 0 - if let Some((pt, col)) = pt - && col == 0 - { - module.vec_znx_add_inplace(&mut c0, 0, &pt.data, 0); - } - - // c[0] = norm(c[0]) - module.vec_znx_normalize(base2k, ct, 0, base2k, &c0, 0, scratch_1); } diff --git a/poulpy-core/src/encryption/glwe_pk.rs b/poulpy-core/src/encryption/glwe_pk.rs index c7cdaeb..d89f515 100644 --- a/poulpy-core/src/encryption/glwe_pk.rs +++ b/poulpy-core/src/encryption/glwe_pk.rs @@ -1,50 +1,43 @@ use poulpy_hal::{ - api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, VecZnxAddInplace, VecZnxAddNormal, VecZnxBigNormalize, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, - }, + api::{ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxDftBytesOf, VecZnxNormalizeTmpBytes}, layouts::{Backend, DataMut, DataRef, Module, ScratchOwned}, - oep::{ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxDftImpl, TakeVecZnxImpl}, source::Source, }; -use crate::layouts::{GLWECiphertext, GLWEPublicKey, prepared::GLWESecretPrepared}; +use crate::{ + encryption::glwe_ct::GLWEEncryptZeroSk, + layouts::{ + GLWE, GLWEPublicKey, GLWEPublicKeyToMut, + prepared::{GLWESecretPrepared, GLWESecretPreparedToRef}, + }, +}; -impl GLWEPublicKey { - pub fn generate_from_sk( - &mut self, - module: &Module, - sk: &GLWESecretPrepared, - source_xa: &mut Source, - source_xe: &mut Source, - ) where - Module:, - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - B: Backend - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + TakeVecZnxDftImpl - + ScratchAvailableImpl - + TakeVecZnxImpl, +pub trait GLWEPublicKeyGenerate { + fn glwe_public_key_generate(&self, res: &mut R, sk: &S, source_xa: &mut Source, source_xe: &mut Source) + where + R: GLWEPublicKeyToMut, + S: GLWESecretPreparedToRef; +} + +impl GLWEPublicKeyGenerate for Module +where + Module: GLWEEncryptZeroSk + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + fn glwe_public_key_generate(&self, res: &mut R, sk: &S, source_xa: &mut Source, source_xe: &mut Source) + where + R: GLWEPublicKeyToMut, + S: GLWESecretPreparedToRef, { + let res: &mut GLWEPublicKey<&mut [u8]> = &mut res.to_mut(); + let sk: &GLWESecretPrepared<&[u8], B> = &sk.to_ref(); + #[cfg(debug_assertions)] { use crate::{Distribution, layouts::LWEInfos}; - assert_eq!(self.n(), sk.n()); + assert_eq!(res.n(), self.n() as u32); + assert_eq!(sk.n(), self.n() as u32); if sk.dist == Distribution::NONE { panic!("invalid sk: SecretDistribution::NONE") @@ -52,10 +45,25 @@ impl GLWEPublicKey { } // Its ok to allocate scratch space here since pk is usually generated only once. - let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWECiphertext::encrypt_sk_scratch_space(module, self)); + let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWE::encrypt_sk_tmp_bytes(self, res)); - let mut tmp: GLWECiphertext> = GLWECiphertext::alloc(self); - tmp.encrypt_zero_sk(module, sk, source_xa, source_xe, scratch.borrow()); - self.dist = sk.dist; + let mut tmp: GLWE> = GLWE::alloc_from_infos(res); + + tmp.encrypt_zero_sk(self, sk, source_xa, source_xe, scratch.borrow()); + res.dist = sk.dist; + } +} + +impl GLWEPublicKey { + pub fn generate( + &mut self, + module: &Module, + sk: &GLWESecretPrepared, + source_xa: &mut Source, + source_xe: &mut Source, + ) where + Module: GLWEPublicKeyGenerate, + { + module.glwe_public_key_generate(self, sk, source_xa, source_xe); } } diff --git a/poulpy-core/src/encryption/glwe_to_lwe_ksk.rs b/poulpy-core/src/encryption/glwe_to_lwe_ksk.rs index b65ce4e..971766e 100644 --- a/poulpy-core/src/encryption/glwe_to_lwe_ksk.rs +++ b/poulpy-core/src/encryption/glwe_to_lwe_ksk.rs @@ -1,32 +1,30 @@ use poulpy_hal::{ api::{ - ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigNormalize, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, + ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, + VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, + VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, + VecZnxSubInplace, VecZnxSwitchRing, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero}, source::Source, }; -use crate::{ - TakeGLWESecret, TakeGLWESecretPrepared, - layouts::{GGLWEInfos, GGLWESwitchingKey, GLWESecret, GLWEToLWEKey, LWEInfos, LWESecret, Rank, prepared::GLWESecretPrepared}, +use crate::layouts::{ + GGLWEInfos, GLWESecret, GLWESwitchingKey, GLWEToLWESwitchingKey, LWEInfos, LWESecret, Rank, prepared::GLWESecretPrepared, }; -impl GLWEToLWEKey> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize +impl GLWEToLWESwitchingKey> { + pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGLWEInfos, - Module: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes, + Module: SvpPPolBytesOf + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes, { - GLWESecretPrepared::alloc_bytes_with(module, infos.rank_in()) - + (GGLWESwitchingKey::encrypt_sk_scratch_space(module, infos) - | GLWESecret::alloc_bytes_with(infos.n(), infos.rank_in())) + GLWESecretPrepared::bytes_of(module, infos.rank_in()) + + (GLWESwitchingKey::encrypt_sk_tmp_bytes(module, infos) | GLWESecret::bytes_of(infos.n(), infos.rank_in())) } } -impl GLWEToLWEKey { +impl GLWEToLWESwitchingKey { #[allow(clippy::too_many_arguments)] pub fn encrypt_sk( &mut self, @@ -41,7 +39,7 @@ impl GLWEToLWEKey { DGlwe: DataRef, Module: VecZnxAutomorphismInplace + VecZnxAddScalarInplace - + VecZnxDftAllocBytes + + VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -56,8 +54,8 @@ impl GLWEToLWEKey { + VecZnxSub + SvpPrepare + VecZnxSwitchRing - + SvpPPolAllocBytes, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared, + + SvpPPolBytesOf, + Scratch: ScratchAvailable, { #[cfg(debug_assertions)] { diff --git a/poulpy-core/src/encryption/lwe_ct.rs b/poulpy-core/src/encryption/lwe_ct.rs index 4dd09ac..a01d95f 100644 --- a/poulpy-core/src/encryption/lwe_ct.rs +++ b/poulpy-core/src/encryption/lwe_ct.rs @@ -7,10 +7,10 @@ use poulpy_hal::{ use crate::{ encryption::{SIGMA, SIGMA_BOUND}, - layouts::{LWECiphertext, LWEInfos, LWEPlaintext, LWESecret}, + layouts::{LWE, LWEInfos, LWEPlaintext, LWESecret}, }; -impl LWECiphertext { +impl LWE { pub fn encrypt_sk( &mut self, module: &Module, diff --git a/poulpy-core/src/encryption/lwe_ksk.rs b/poulpy-core/src/encryption/lwe_ksk.rs index 66ae685..2fd60ff 100644 --- a/poulpy-core/src/encryption/lwe_ksk.rs +++ b/poulpy-core/src/encryption/lwe_ksk.rs @@ -1,27 +1,24 @@ use poulpy_hal::{ api::{ - ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigNormalize, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, + SvpApplyDftToDftInplace, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, + VecZnxAutomorphismInplace, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, + VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, + VecZnxSwitchRing, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut}, source::Source, }; -use crate::{ - TakeGLWESecret, TakeGLWESecretPrepared, - layouts::{ - Degree, GGLWEInfos, GGLWESwitchingKey, GLWESecret, LWEInfos, LWESecret, LWESwitchingKey, Rank, - prepared::GLWESecretPrepared, - }, +use crate::layouts::{ + GGLWEInfos, GLWESecret, GLWESwitchingKey, LWEInfos, LWESecret, LWESwitchingKey, Rank, RingDegree, + prepared::GLWESecretPrepared, }; impl LWESwitchingKey> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize + pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGLWEInfos, - Module: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes, + Module: SvpPPolBytesOf + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes, { debug_assert_eq!( infos.dsize().0, @@ -38,9 +35,9 @@ impl LWESwitchingKey> { 1, "rank_out > 1 is not supported for LWESwitchingKey" ); - GLWESecret::alloc_bytes_with(Degree(module.n() as u32), Rank(1)) - + GLWESecretPrepared::alloc_bytes_with(module, Rank(1)) - + GGLWESwitchingKey::encrypt_sk_scratch_space(module, infos) + GLWESecret::bytes_of(RingDegree(module.n() as u32), Rank(1)) + + GLWESecretPrepared::bytes_of(module, Rank(1)) + + GLWESwitchingKey::encrypt_sk_tmp_bytes(module, infos) } } @@ -59,7 +56,7 @@ impl LWESwitchingKey { DOut: DataRef, Module: VecZnxAutomorphismInplace + VecZnxAddScalarInplace - + VecZnxDftAllocBytes + + VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -74,8 +71,8 @@ impl LWESwitchingKey { + VecZnxSub + SvpPrepare + VecZnxSwitchRing - + SvpPPolAllocBytes, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared, + + SvpPPolBytesOf, + Scratch:, { #[cfg(debug_assertions)] { diff --git a/poulpy-core/src/encryption/lwe_to_glwe_ksk.rs b/poulpy-core/src/encryption/lwe_to_glwe_ksk.rs index 204e84b..041c7c4 100644 --- a/poulpy-core/src/encryption/lwe_to_glwe_ksk.rs +++ b/poulpy-core/src/encryption/lwe_to_glwe_ksk.rs @@ -1,32 +1,29 @@ use poulpy_hal::{ api::{ - ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigNormalize, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, + ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, + VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, + VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, + VecZnxSubInplace, VecZnxSwitchRing, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut}, source::Source, }; -use crate::{ - TakeGLWESecret, TakeGLWESecretPrepared, - layouts::{Degree, GGLWEInfos, GGLWESwitchingKey, GLWESecret, LWEInfos, LWESecret, LWEToGLWESwitchingKey, Rank}, -}; +use crate::layouts::{GGLWEInfos, GLWESecret, GLWESwitchingKey, LWEInfos, LWESecret, LWEToGLWESwitchingKey, Rank, RingDegree}; impl LWEToGLWESwitchingKey> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize + pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGLWEInfos, - Module: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes, + Module: SvpPPolBytesOf + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes, { debug_assert_eq!( infos.rank_in(), Rank(1), "rank_in != 1 is not supported for LWEToGLWESwitchingKey" ); - GGLWESwitchingKey::encrypt_sk_scratch_space(module, infos) - + GLWESecret::alloc_bytes_with(Degree(module.n() as u32), infos.rank_in()) + GLWESwitchingKey::encrypt_sk_tmp_bytes(module, infos) + + GLWESecret::bytes_of(RingDegree(module.n() as u32), infos.rank_in()) } } @@ -45,7 +42,7 @@ impl LWEToGLWESwitchingKey { DGlwe: DataRef, Module: VecZnxAutomorphismInplace + VecZnxAddScalarInplace - + VecZnxDftAllocBytes + + VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -60,8 +57,8 @@ impl LWEToGLWESwitchingKey { + VecZnxSub + SvpPrepare + VecZnxSwitchRing - + SvpPPolAllocBytes, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared, + + SvpPPolBytesOf, + Scratch: ScratchAvailable, { #[cfg(debug_assertions)] { diff --git a/poulpy-core/src/encryption/mod.rs b/poulpy-core/src/encryption/mod.rs index 9380933..fb9a459 100644 --- a/poulpy-core/src/encryption/mod.rs +++ b/poulpy-core/src/encryption/mod.rs @@ -11,7 +11,5 @@ mod lwe_ct; mod lwe_ksk; mod lwe_to_glwe_ksk; -pub(crate) use glwe_ct::glwe_encrypt_sk_internal; - pub const SIGMA: f64 = 3.2; pub(crate) const SIGMA_BOUND: f64 = 6.0 * SIGMA; diff --git a/poulpy-core/src/external_product/gglwe_atk.rs b/poulpy-core/src/external_product/gglwe_atk.rs index cb35a4c..871eab8 100644 --- a/poulpy-core/src/external_product/gglwe_atk.rs +++ b/poulpy-core/src/external_product/gglwe_atk.rs @@ -1,83 +1,46 @@ -use poulpy_hal::{ - api::{ - ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, - }, - layouts::{Backend, DataMut, DataRef, Module, Scratch}, +use poulpy_hal::layouts::{Backend, DataMut, Scratch}; + +use crate::{ + ScratchTakeCore, + external_product::gglwe_ksk::GGLWEExternalProduct, + layouts::{AutomorphismKey, AutomorphismKeyToRef, GGLWEInfos, GGSWInfos, prepared::GGSWPreparedToRef}, }; -use crate::layouts::{GGLWEAutomorphismKey, GGLWEInfos, GGLWESwitchingKey, GGSWInfos, prepared::GGSWCiphertextPrepared}; - -impl GGLWEAutomorphismKey> { - pub fn external_product_scratch_space( - module: &Module, - out_infos: &OUT, - in_infos: &IN, - ggsw_infos: &GGSW, +impl AutomorphismKey> { + pub fn external_product_tmp_bytes( + &self, + module: &M, + res_infos: &R, + a_infos: &A, + b_infos: &B, ) -> usize where - OUT: GGLWEInfos, - IN: GGLWEInfos, - GGSW: GGSWInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes, + R: GGLWEInfos, + A: GGLWEInfos, + B: GGSWInfos, + M: GGLWEExternalProduct, { - GGLWESwitchingKey::external_product_scratch_space(module, out_infos, in_infos, ggsw_infos) - } - - pub fn external_product_inplace_scratch_space( - module: &Module, - out_infos: &OUT, - ggsw_infos: &GGSW, - ) -> usize - where - OUT: GGLWEInfos, - GGSW: GGSWInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes, - { - GGLWESwitchingKey::external_product_inplace_scratch_space(module, out_infos, ggsw_infos) + module.gglwe_external_product_tmp_bytes(res_infos, a_infos, b_infos) } } -impl GGLWEAutomorphismKey { - pub fn external_product( - &mut self, - module: &Module, - lhs: &GGLWEAutomorphismKey, - rhs: &GGSWCiphertextPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxNormalizeTmpBytes - + VecZnxDftApply - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxIdftApplyConsume - + VecZnxBigNormalize - + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, +impl AutomorphismKey { + pub fn external_product(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch) + where + M: GGLWEExternalProduct, + A: AutomorphismKeyToRef, + B: GGSWPreparedToRef, + Scratch: ScratchTakeCore, { - self.key.external_product(module, &lhs.key, rhs, scratch); + module.gglwe_external_product(&mut self.key.key, &a.to_ref().key.key, b, scratch); } - pub fn external_product_inplace( - &mut self, - module: &Module, - rhs: &GGSWCiphertextPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxNormalizeTmpBytes - + VecZnxDftApply - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxIdftApplyConsume - + VecZnxBigNormalize - + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + pub fn external_product_inplace(&mut self, module: &M, a: &A, scratch: &mut Scratch) + where + M: GGLWEExternalProduct, + A: GGSWPreparedToRef, + Scratch: ScratchTakeCore, { - self.key.external_product_inplace(module, rhs, scratch); + module.gglwe_external_product_inplace(&mut self.key.key, a, scratch); } } diff --git a/poulpy-core/src/external_product/gglwe_ksk.rs b/poulpy-core/src/external_product/gglwe_ksk.rs index 2eff45c..5bb4557 100644 --- a/poulpy-core/src/external_product/gglwe_ksk.rs +++ b/poulpy-core/src/external_product/gglwe_ksk.rs @@ -1,144 +1,134 @@ -use poulpy_hal::{ - api::{ - ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, +use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch, ZnxZero}; + +use crate::{ + GLWEExternalProduct, ScratchTakeCore, + layouts::{ + GGLWE, GGLWEInfos, GGLWEToMut, GGLWEToRef, GGSWInfos, GLWEInfos, GLWESwitchingKey, GLWESwitchingKeyToRef, + prepared::{GGSWPrepared, GGSWPreparedToRef}, }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero}, }; -use crate::layouts::{GGLWEInfos, GGLWESwitchingKey, GGSWInfos, GLWECiphertext, prepared::GGSWCiphertextPrepared}; - -impl GGLWESwitchingKey> { - pub fn external_product_scratch_space( - module: &Module, - out_infos: &OUT, - in_infos: &IN, - ggsw_infos: &GGSW, - ) -> usize +pub trait GGLWEExternalProduct +where + Self: GLWEExternalProduct, +{ + fn gglwe_external_product_tmp_bytes(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize where - OUT: GGLWEInfos, - IN: GGLWEInfos, - GGSW: GGSWInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes, + R: GGLWEInfos, + A: GGLWEInfos, + B: GGSWInfos, { - GLWECiphertext::external_product_scratch_space( - module, - &out_infos.glwe_layout(), - &in_infos.glwe_layout(), - ggsw_infos, - ) + self.glwe_external_product_tmp_bytes(res_infos, a_infos, b_infos) } - pub fn external_product_inplace_scratch_space( - module: &Module, - out_infos: &OUT, - ggsw_infos: &GGSW, - ) -> usize + fn gglwe_external_product(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) where - OUT: GGLWEInfos, - GGSW: GGSWInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes, + R: GGLWEToMut, + A: GGLWEToRef, + B: GGSWPreparedToRef, + Scratch: ScratchTakeCore, { - GLWECiphertext::external_product_inplace_scratch_space(module, &out_infos.glwe_layout(), ggsw_infos) + let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GGLWE<&[u8]> = &a.to_ref(); + let b: &GGSWPrepared<&[u8], BE> = &b.to_ref(); + + assert_eq!( + res.rank_in(), + a.rank_in(), + "res input rank_in: {} != a input rank_in: {}", + res.rank_in(), + a.rank_in() + ); + assert_eq!( + a.rank_out(), + b.rank(), + "a output rank_out: {} != b rank: {}", + a.rank_out(), + b.rank() + ); + assert_eq!( + res.rank_out(), + b.rank(), + "res output rank_out: {} != b rank: {}", + res.rank_out(), + b.rank() + ); + + for row in 0..res.dnum().into() { + for col in 0..res.rank_in().into() { + self.glwe_external_product(&mut res.at_mut(row, col), &a.at(row, col), b, scratch); + } + } + + for row in res.dnum().min(a.dnum()).into()..res.dnum().into() { + for col in 0..res.rank_in().into() { + res.at_mut(row, col).data_mut().zero(); + } + } + } + + fn gglwe_external_product_inplace(&self, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: GGLWEToMut, + A: GGSWPreparedToRef, + Scratch: ScratchTakeCore, + { + let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GGSWPrepared<&[u8], BE> = &a.to_ref(); + + assert_eq!( + res.rank_out(), + a.rank(), + "res output rank: {} != a rank: {}", + res.rank_out(), + a.rank() + ); + + for row in 0..res.dnum().into() { + for col in 0..res.rank_in().into() { + self.glwe_external_product_inplace(&mut res.at_mut(row, col), a, scratch); + } + } } } -impl GGLWESwitchingKey { - pub fn external_product( - &mut self, - module: &Module, - lhs: &GGLWESwitchingKey, - rhs: &GGSWCiphertextPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxNormalizeTmpBytes - + VecZnxDftApply - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxIdftApplyConsume - + VecZnxBigNormalize - + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, +impl GGLWEExternalProduct for Module where Self: GLWEExternalProduct {} + +impl GLWESwitchingKey> { + pub fn external_product_tmp_bytes( + &self, + module: &M, + res_infos: &R, + a_infos: &A, + b_infos: &B, + ) -> usize + where + R: GGLWEInfos, + A: GGLWEInfos, + B: GGSWInfos, + M: GGLWEExternalProduct, { - #[cfg(debug_assertions)] - { - use crate::layouts::GLWEInfos; - - assert_eq!( - self.rank_in(), - lhs.rank_in(), - "ksk_out input rank: {} != ksk_in input rank: {}", - self.rank_in(), - lhs.rank_in() - ); - assert_eq!( - lhs.rank_out(), - rhs.rank(), - "ksk_in output rank: {} != ggsw rank: {}", - self.rank_out(), - rhs.rank() - ); - assert_eq!( - self.rank_out(), - rhs.rank(), - "ksk_out output rank: {} != ggsw rank: {}", - self.rank_out(), - rhs.rank() - ); - } - - (0..self.rank_in().into()).for_each(|col_i| { - (0..self.dnum().into()).for_each(|row_j| { - self.at_mut(row_j, col_i) - .external_product(module, &lhs.at(row_j, col_i), rhs, scratch); - }); - }); - - (self.dnum().min(lhs.dnum()).into()..self.dnum().into()).for_each(|row_i| { - (0..self.rank_in().into()).for_each(|col_j| { - self.at_mut(row_i, col_j).data.zero(); - }); - }); - } - - pub fn external_product_inplace( - &mut self, - module: &Module, - rhs: &GGSWCiphertextPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxNormalizeTmpBytes - + VecZnxDftApply - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxIdftApplyConsume - + VecZnxBigNormalize - + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, - { - #[cfg(debug_assertions)] - { - use crate::layouts::GLWEInfos; - - assert_eq!( - self.rank_out(), - rhs.rank(), - "ksk_out output rank: {} != ggsw rank: {}", - self.rank_out(), - rhs.rank() - ); - } - - (0..self.rank_in().into()).for_each(|col_i| { - (0..self.dnum().into()).for_each(|row_j| { - self.at_mut(row_j, col_i) - .external_product_inplace(module, rhs, scratch); - }); - }); + module.gglwe_external_product_tmp_bytes(res_infos, a_infos, b_infos) + } +} + +impl GLWESwitchingKey { + pub fn external_product(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch) + where + M: GGLWEExternalProduct, + A: GLWESwitchingKeyToRef, + B: GGSWPreparedToRef, + Scratch: ScratchTakeCore, + { + module.gglwe_external_product(&mut self.key, &a.to_ref().key, b, scratch); + } + + pub fn external_product_inplace(&mut self, module: &M, a: &A, scratch: &mut Scratch) + where + M: GGLWEExternalProduct, + A: GGSWPreparedToRef, + Scratch: ScratchTakeCore, + { + module.gglwe_external_product_inplace(&mut self.key, a, scratch); } } diff --git a/poulpy-core/src/external_product/ggsw_ct.rs b/poulpy-core/src/external_product/ggsw_ct.rs index a458de1..d3a59a6 100644 --- a/poulpy-core/src/external_product/ggsw_ct.rs +++ b/poulpy-core/src/external_product/ggsw_ct.rs @@ -1,143 +1,136 @@ use poulpy_hal::{ - api::{ - ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, - }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero}, + api::ScratchAvailable, + layouts::{Backend, DataMut, Module, Scratch, ZnxZero}, }; -use crate::layouts::{GGSWCiphertext, GGSWInfos, GLWECiphertext, GLWEInfos, prepared::GGSWCiphertextPrepared}; +use crate::{ + GLWEExternalProduct, ScratchTakeCore, + layouts::{ + GGSW, GGSWInfos, GGSWToMut, GGSWToRef, GLWEInfos, LWEInfos, + prepared::{GGSWPrepared, GGSWPreparedToRef}, + }, +}; -impl GGSWCiphertext> { - #[allow(clippy::too_many_arguments)] - pub fn external_product_scratch_space( - module: &Module, - out_infos: &OUT, - in_infos: &IN, - apply_infos: &GGSW, - ) -> usize +pub trait GGSWExternalProduct +where + Self: GLWEExternalProduct, +{ + fn ggsw_external_product_tmp_bytes(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize where - OUT: GGSWInfos, - IN: GGSWInfos, - GGSW: GGSWInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes, + R: GGSWInfos, + A: GGSWInfos, + B: GGSWInfos, { - GLWECiphertext::external_product_scratch_space( - module, - &out_infos.glwe_layout(), - &in_infos.glwe_layout(), - apply_infos, - ) + self.glwe_external_product_tmp_bytes(res_infos, a_infos, b_infos) } - pub fn external_product_inplace_scratch_space( - module: &Module, - out_infos: &OUT, - apply_infos: &GGSW, - ) -> usize + fn ggsw_external_product(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) where - OUT: GGSWInfos, - GGSW: GGSWInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes, + R: GGSWToMut, + A: GGSWToRef, + B: GGSWPreparedToRef, + Scratch: ScratchTakeCore, { - GLWECiphertext::external_product_inplace_scratch_space(module, &out_infos.glwe_layout(), apply_infos) + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + let a: &GGSW<&[u8]> = &a.to_ref(); + let b: &GGSWPrepared<&[u8], BE> = &b.to_ref(); + + assert_eq!( + res.rank(), + a.rank(), + "res rank: {} != a rank: {}", + res.rank(), + a.rank() + ); + assert_eq!( + res.rank(), + b.rank(), + "res rank: {} != b rank: {}", + res.rank(), + b.rank() + ); + + assert!(scratch.available() >= self.ggsw_external_product_tmp_bytes(res, a, b)); + + let min_dnum: usize = res.dnum().min(a.dnum()).into(); + + for row in 0..min_dnum { + for col in 0..(res.rank() + 1).into() { + self.glwe_external_product(&mut res.at_mut(row, col), &a.at(row, col), b, scratch); + } + } + + for row in min_dnum..res.dnum().into() { + for col in 0..(res.rank() + 1).into() { + res.at_mut(row, col).data.zero(); + } + } + } + + fn ggsw_external_product_inplace(&self, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: GGSWToMut, + A: GGSWPreparedToRef, + Scratch: ScratchTakeCore, + { + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + let a: &GGSWPrepared<&[u8], BE> = &a.to_ref(); + + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.n(), self.n() as u32); + assert_eq!( + res.rank(), + a.rank(), + "res rank: {} != a rank: {}", + res.rank(), + a.rank() + ); + + for row in 0..res.dnum().into() { + for col in 0..(res.rank() + 1).into() { + self.glwe_external_product_inplace(&mut res.at_mut(row, col), a, scratch); + } + } } } -impl GGSWCiphertext { - pub fn external_product( - &mut self, - module: &Module, - lhs: &GGSWCiphertext, - rhs: &GGSWCiphertextPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxNormalizeTmpBytes - + VecZnxDftApply - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxIdftApplyConsume - + VecZnxBigNormalize - + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, +impl GGSWExternalProduct for Module where Self: GLWEExternalProduct {} + +impl GGSW> { + pub fn external_product_tmp_bytes( + &self, + module: &M, + res_infos: &R, + a_infos: &A, + b_infos: &B, + ) -> usize + where + R: GGSWInfos, + A: GGSWInfos, + B: GGSWInfos, + M: GGSWExternalProduct, { - #[cfg(debug_assertions)] - { - use crate::layouts::LWEInfos; - - assert_eq!(lhs.n(), self.n()); - assert_eq!(rhs.n(), self.n()); - - assert_eq!( - self.rank(), - lhs.rank(), - "ggsw_out rank: {} != ggsw_in rank: {}", - self.rank(), - lhs.rank() - ); - assert_eq!( - self.rank(), - rhs.rank(), - "ggsw_in rank: {} != ggsw_apply rank: {}", - self.rank(), - rhs.rank() - ); - - assert!(scratch.available() >= GGSWCiphertext::external_product_scratch_space(module, self, lhs, rhs)) - } - - let min_dnum: usize = self.dnum().min(lhs.dnum()).into(); - - (0..(self.rank() + 1).into()).for_each(|col_i| { - (0..min_dnum).for_each(|row_j| { - self.at_mut(row_j, col_i) - .external_product(module, &lhs.at(row_j, col_i), rhs, scratch); - }); - (min_dnum..self.dnum().into()).for_each(|row_i| { - self.at_mut(row_i, col_i).data.zero(); - }); - }); - } - - pub fn external_product_inplace( - &mut self, - module: &Module, - rhs: &GGSWCiphertextPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxNormalizeTmpBytes - + VecZnxDftApply - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxIdftApplyConsume - + VecZnxBigNormalize - + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, - { - #[cfg(debug_assertions)] - { - use crate::layouts::LWEInfos; - - assert_eq!(rhs.n(), self.n()); - assert_eq!( - self.rank(), - rhs.rank(), - "ggsw_out rank: {} != ggsw_apply: {}", - self.rank(), - rhs.rank() - ); - } - - (0..(self.rank() + 1).into()).for_each(|col_i| { - (0..self.dnum().into()).for_each(|row_j| { - self.at_mut(row_j, col_i) - .external_product_inplace(module, rhs, scratch); - }); - }); + module.ggsw_external_product_tmp_bytes(res_infos, a_infos, b_infos) + } +} + +impl GGSW { + pub fn external_product(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch) + where + M: GGSWExternalProduct, + A: GGSWToRef, + B: GGSWPreparedToRef, + Scratch: ScratchTakeCore, + { + module.ggsw_external_product(self, a, b, scratch); + } + + pub fn external_product_inplace(&mut self, module: &M, a: &A, scratch: &mut Scratch) + where + M: GGSWExternalProduct, + A: GGSWPreparedToRef, + Scratch: ScratchTakeCore, + { + module.ggsw_external_product_inplace(self, a, scratch); } } diff --git a/poulpy-core/src/external_product/glwe_ct.rs b/poulpy-core/src/external_product/glwe_ct.rs index d764507..ab9968c 100644 --- a/poulpy-core/src/external_product/glwe_ct.rs +++ b/poulpy-core/src/external_product/glwe_ct.rs @@ -1,102 +1,57 @@ use poulpy_hal::{ api::{ - ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, + ModuleN, ScratchTakeBasic, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, + VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, - layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch, VecZnx, VecZnxBig}, + layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig}, }; use crate::{ - GLWEExternalProduct, GLWEExternalProductInplace, + ScratchTakeCore, layouts::{ - GGSWInfos, GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEInfos, LWEInfos, - prepared::{GGSWCiphertextPrepared, GGSWCiphertextPreparedToRef}, + GGSWInfos, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, + prepared::{GGSWPrepared, GGSWPreparedToRef}, }, }; -impl GLWECiphertext> { - #[allow(clippy::too_many_arguments)] - pub fn external_product_scratch_space( - module: &Module, - out_infos: &OUT, - in_infos: &IN, - apply_infos: &GGSW, - ) -> usize +impl GLWE> { + pub fn external_product_tmp_bytes(module: &M, res_infos: &R, a_infos: &A, b_infos: &B) -> usize where - OUT: GLWEInfos, - IN: GLWEInfos, - GGSW: GGSWInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes, + R: GLWEInfos, + A: GLWEInfos, + B: GGSWInfos, + M: GLWEExternalProduct, { - let in_size: usize = in_infos - .k() - .div_ceil(apply_infos.base2k()) - .div_ceil(apply_infos.dsize().into()) as usize; - let out_size: usize = out_infos.size(); - let ggsw_size: usize = apply_infos.size(); - let res_dft: usize = module.vec_znx_dft_alloc_bytes((apply_infos.rank() + 1).into(), ggsw_size); - let a_dft: usize = module.vec_znx_dft_alloc_bytes((apply_infos.rank() + 1).into(), in_size); - let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes( - out_size, - in_size, - in_size, // rows - (apply_infos.rank() + 1).into(), // cols in - (apply_infos.rank() + 1).into(), // cols out - ggsw_size, - ); - let normalize_big: usize = module.vec_znx_normalize_tmp_bytes(); - - if in_infos.base2k() == apply_infos.base2k() { - res_dft + a_dft + (vmp | normalize_big) - } else { - let normalize_conv: usize = VecZnx::alloc_bytes(module.n(), (apply_infos.rank() + 1).into(), in_size); - res_dft + ((a_dft + normalize_conv + (module.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big) - } - } - - pub fn external_product_inplace_scratch_space( - module: &Module, - out_infos: &OUT, - apply_infos: &GGSW, - ) -> usize - where - OUT: GLWEInfos, - GGSW: GGSWInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes, - { - Self::external_product_scratch_space(module, out_infos, out_infos, apply_infos) + module.glwe_external_product_tmp_bytes(res_infos, a_infos, b_infos) } } -impl GLWECiphertext { - pub fn external_product( - &mut self, - module: &Module, - lhs: &GLWECiphertext, - rhs: &GGSWCiphertextPrepared, - scratch: &mut Scratch, - ) where - Module: GLWEExternalProduct, +impl GLWE { + pub fn external_product(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch) + where + A: GLWEToRef, + B: GGSWPreparedToRef, + M: GLWEExternalProduct, + Scratch: ScratchTakeCore, { - module.external_product(self, lhs, rhs, scratch); + module.glwe_external_product(self, a, b, scratch); } - pub fn external_product_inplace( - &mut self, - module: &Module, - rhs: &GGSWCiphertextPrepared, - scratch: &mut Scratch, - ) where - Module: GLWEExternalProductInplace, + pub fn external_product_inplace(&mut self, module: &M, a: &A, scratch: &mut Scratch) + where + A: GGSWPreparedToRef, + M: GLWEExternalProduct, + Scratch: ScratchTakeCore, { - module.external_product_inplace(self, rhs, scratch); + module.glwe_external_product_inplace(self, a, scratch); } } -impl GLWEExternalProductInplace for Module +pub trait GLWEExternalProduct where - Module: VecZnxDftAllocBytes + Self: Sized + + ModuleN + + VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes + VecZnxDftApply @@ -105,15 +60,47 @@ where + VecZnxIdftApplyConsume + VecZnxBigNormalize + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { - fn external_product_inplace(&self, res: &mut R, ggsw: &D, scratch: &mut Scratch) + fn glwe_external_product_tmp_bytes(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize where - R: GLWECiphertextToMut, - D: GGSWCiphertextPreparedToRef, + R: GLWEInfos, + A: GLWEInfos, + B: GGSWInfos, { - let res: &mut GLWECiphertext<&mut [u8]> = &mut res.to_mut(); - let rhs: &GGSWCiphertextPrepared<&[u8], BE> = &ggsw.to_ref(); + let in_size: usize = a_infos + .k() + .div_ceil(b_infos.base2k()) + .div_ceil(b_infos.dsize().into()) as usize; + let out_size: usize = res_infos.size(); + let ggsw_size: usize = b_infos.size(); + let res_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), ggsw_size); + let a_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), in_size); + let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes( + out_size, + in_size, + in_size, // rows + (b_infos.rank() + 1).into(), // cols in + (b_infos.rank() + 1).into(), // cols out + ggsw_size, + ); + let normalize_big: usize = self.vec_znx_normalize_tmp_bytes(); + + if a_infos.base2k() == b_infos.base2k() { + res_dft + a_dft + (vmp | normalize_big) + } else { + let normalize_conv: usize = VecZnx::bytes_of(self.n(), (b_infos.rank() + 1).into(), in_size); + res_dft + ((a_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big) + } + } + + fn glwe_external_product_inplace(&self, res: &mut R, a: &D, scratch: &mut Scratch) + where + R: GLWEToMut, + D: GGSWPreparedToRef, + Scratch: ScratchTakeCore, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let rhs: &GGSWPrepared<&[u8], BE> = &a.to_ref(); let basek_in: usize = res.base2k().into(); let basek_ggsw: usize = rhs.base2k().into(); @@ -124,15 +111,15 @@ where assert_eq!(rhs.rank(), res.rank()); assert_eq!(rhs.n(), res.n()); - assert!(scratch.available() >= GLWECiphertext::external_product_inplace_scratch_space(self, res, rhs)); + assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, res, rhs)); } let cols: usize = (rhs.rank() + 1).into(); let dsize: usize = rhs.dsize().into(); let a_size: usize = (res.size() * basek_in).div_ceil(basek_ggsw); - let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(res.n().into(), cols, rhs.size()); // Todo optimise - let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(res.n().into(), cols, a_size.div_ceil(dsize)); + let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, rhs.size()); // Todo optimise + let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize)); a_dft.data_mut().fill(0); if basek_in == basek_ggsw { @@ -160,7 +147,7 @@ where } } } else { - let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n(), cols, a_size); + let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self, cols, a_size); for j in 0..cols { self.vec_znx_normalize( @@ -213,31 +200,18 @@ where ); } } -} -impl GLWEExternalProduct for Module -where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxNormalizeTmpBytes - + VecZnxDftApply - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxIdftApplyConsume - + VecZnxBigNormalize - + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, -{ - fn external_product(&self, res: &mut R, lhs: &A, rhs: &D, scratch: &mut Scratch) + fn glwe_external_product(&self, res: &mut R, lhs: &A, rhs: &D, scratch: &mut Scratch) where - R: GLWECiphertextToMut, - A: GLWECiphertextToRef, - D: GGSWCiphertextPreparedToRef, + R: GLWEToMut, + A: GLWEToRef, + D: GGSWPreparedToRef, + Scratch: ScratchTakeCore, { - let res: &mut GLWECiphertext<&mut [u8]> = &mut res.to_mut(); - let lhs: &GLWECiphertext<&[u8]> = &lhs.to_ref(); + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let lhs: &GLWE<&[u8]> = &lhs.to_ref(); - let rhs: &GGSWCiphertextPrepared<&[u8], BE> = &rhs.to_ref(); + let rhs: &GGSWPrepared<&[u8], BE> = &rhs.to_ref(); let basek_in: usize = lhs.base2k().into(); let basek_ggsw: usize = rhs.base2k().into(); @@ -251,7 +225,7 @@ where assert_eq!(rhs.rank(), res.rank()); assert_eq!(rhs.n(), res.n()); assert_eq!(lhs.n(), res.n()); - assert!(scratch.available() >= GLWECiphertext::external_product_scratch_space(self, res, lhs, rhs)); + assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, lhs, rhs)); } let cols: usize = (rhs.rank() + 1).into(); @@ -259,8 +233,8 @@ where let a_size: usize = (lhs.size() * basek_in).div_ceil(basek_ggsw); - let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), cols, rhs.size()); // Todo optimise - let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n(), cols, a_size.div_ceil(dsize)); + let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, rhs.size()); // Todo optimise + let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize)); a_dft.data_mut().fill(0); if basek_in == basek_ggsw { @@ -288,7 +262,7 @@ where } } } else { - let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n(), cols, a_size); + let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self, cols, a_size); for j in 0..cols { self.vec_znx_normalize( @@ -342,3 +316,20 @@ where }); } } + +impl GLWEExternalProduct for Module where + Self: ModuleN + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VecZnxNormalizeTmpBytes + + VecZnxDftApply + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxIdftApplyConsume + + VecZnxBigNormalize + + VecZnxNormalize + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VecZnxNormalizeTmpBytes +{ +} diff --git a/poulpy-core/src/external_product/mod.rs b/poulpy-core/src/external_product/mod.rs index f46b82c..532406b 100644 --- a/poulpy-core/src/external_product/mod.rs +++ b/poulpy-core/src/external_product/mod.rs @@ -1,23 +1,8 @@ -use poulpy_hal::layouts::{Backend, Scratch}; - -use crate::layouts::{GLWECiphertextToMut, GLWECiphertextToRef, prepared::GGSWCiphertextPreparedToRef}; - mod gglwe_atk; mod gglwe_ksk; mod ggsw_ct; mod glwe_ct; -pub trait GLWEExternalProduct { - fn external_product(&self, res: &mut R, a: &A, ggsw: &D, scratch: &mut Scratch) - where - R: GLWECiphertextToMut, - A: GLWECiphertextToRef, - D: GGSWCiphertextPreparedToRef; -} - -pub trait GLWEExternalProductInplace { - fn external_product_inplace(&self, res: &mut R, ggsw: &D, scratch: &mut Scratch) - where - R: GLWECiphertextToMut, - D: GGSWCiphertextPreparedToRef; -} +pub use gglwe_ksk::*; +pub use ggsw_ct::*; +pub use glwe_ct::*; diff --git a/poulpy-core/src/glwe_packing.rs b/poulpy-core/src/glwe_packing.rs index 3da962a..7dacb97 100644 --- a/poulpy-core/src/glwe_packing.rs +++ b/poulpy-core/src/glwe_packing.rs @@ -2,18 +2,18 @@ use std::collections::HashMap; use poulpy_hal::{ api::{ - ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, - VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNegateInplace, - VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, - VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + ScratchAvailable, VecZnxAddInplace, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, + VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy, VecZnxDftApply, + VecZnxDftBytesOf, VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNegateInplace, VecZnxNormalize, + VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, + VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch}, }; use crate::{ - GLWEOperations, TakeGLWECt, - layouts::{GGLWEInfos, GLWECiphertext, GLWEInfos, LWEInfos, prepared::GGLWEAutomorphismKeyPrepared}, + GLWEOperations, + layouts::{GGLWEInfos, GLWE, GLWEInfos, LWEInfos, prepared::AutomorphismKeyPrepared}, }; /// [GLWEPacker] enables only the fly GLWE packing @@ -29,7 +29,7 @@ pub struct GLWEPacker { /// [Accumulator] stores intermediate packing result. /// There are Log(N) such accumulators in a [GLWEPacker]. struct Accumulator { - data: GLWECiphertext>, + data: GLWE>, value: bool, // Implicit flag for zero ciphertext control: bool, // Can be combined with incoming value } @@ -43,12 +43,12 @@ impl Accumulator { /// * `base2k`: base 2 logarithm of the GLWE ciphertext in memory digit representation. /// * `k`: base 2 precision of the GLWE ciphertext precision over the Torus. /// * `rank`: rank of the GLWE ciphertext. - pub fn alloc(infos: &A) -> Self + pub fn alloc(module: &Module, infos: &A) -> Self where A: GLWEInfos, { Self { - data: GLWECiphertext::alloc(infos), + data: GLWE::alloc_from_infos(module, infos), value: false, control: false, } @@ -66,13 +66,13 @@ impl GLWEPacker { /// and N GLWE ciphertext can be packed. With `log_batch=2` all coefficients /// which are multiples of X^{N/4} are packed. Meaning that N/4 ciphertexts /// can be packed. - pub fn new(infos: &A, log_batch: usize) -> Self + pub fn new(module: Module, infos: &A, log_batch: usize) -> Self where A: GLWEInfos, { let mut accumulators: Vec = Vec::::new(); let log_n: usize = infos.n().log2(); - (0..log_n - log_batch).for_each(|_| accumulators.push(Accumulator::alloc(infos))); + (0..log_n - log_batch).for_each(|_| accumulators.push(Accumulator::alloc(module, infos))); Self { accumulators, log_batch, @@ -90,17 +90,17 @@ impl GLWEPacker { } /// Number of scratch space bytes required to call [Self::add]. - pub fn scratch_space(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize + pub fn tmp_bytes(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize where OUT: GLWEInfos, KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - pack_core_scratch_space(module, out_infos, key_infos) + pack_core_tmp_bytes(module, out_infos, key_infos) } pub fn galois_elements(module: &Module) -> Vec { - GLWECiphertext::trace_galois_elements(module) + GLWE::trace_galois_elements(module) } /// Adds a GLWE ciphertext to the [GLWEPacker]. @@ -111,15 +111,15 @@ impl GLWEPacker { /// of packed ciphertexts reaches N/2^log_batch is a result written. /// * `a`: ciphertext to pack. Can optionally give None to pack a 0 ciphertext. /// * `auto_keys`: a [HashMap] containing the [AutomorphismKeyExec]s. - /// * `scratch`: scratch space of size at least [Self::scratch_space]. + /// * `scratch`: scratch space of size at least [Self::tmp_bytes]. pub fn add( &mut self, module: &Module, - a: Option<&GLWECiphertext>, - auto_keys: &HashMap>, + a: Option<&GLWE>, + auto_keys: &HashMap>, scratch: &mut Scratch, ) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft @@ -142,7 +142,7 @@ impl GLWEPacker { + VecZnxBigAutomorphismInplace + VecZnxNormalize + VecZnxNormalizeTmpBytes, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Scratch: ScratchAvailable, { assert!( (self.counter as u32) < self.accumulators[0].data.n(), @@ -162,7 +162,7 @@ impl GLWEPacker { } /// Flush result to`res`. - pub fn flush(&mut self, module: &Module, res: &mut GLWECiphertext) + pub fn flush(&mut self, module: &Module, res: &mut GLWE) where Module: VecZnxCopy, { @@ -177,24 +177,24 @@ impl GLWEPacker { } } -fn pack_core_scratch_space(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize +fn pack_core_tmp_bytes(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize where OUT: GLWEInfos, KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - combine_scratch_space(module, out_infos, key_infos) + combine_tmp_bytes(module, out_infos, key_infos) } fn pack_core( module: &Module, - a: Option<&GLWECiphertext>, + a: Option<&GLWE>, accumulators: &mut [Accumulator], i: usize, - auto_keys: &HashMap>, + auto_keys: &HashMap>, scratch: &mut Scratch, ) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft @@ -217,7 +217,7 @@ fn pack_core( + VecZnxBigAutomorphismInplace + VecZnxNormalize + VecZnxNormalizeTmpBytes, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Scratch: ScratchAvailable, { let log_n: usize = module.log_n(); @@ -258,7 +258,7 @@ fn pack_core( } else { pack_core( module, - None::<&GLWECiphertext>>, + None::<&GLWE>>, acc_next, i + 1, auto_keys, @@ -268,27 +268,26 @@ fn pack_core( } } -fn combine_scratch_space(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize +fn combine_tmp_bytes(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize where OUT: GLWEInfos, KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - GLWECiphertext::alloc_bytes(out_infos) - + (GLWECiphertext::rsh_scratch_space(module.n()) - | GLWECiphertext::automorphism_inplace_scratch_space(module, out_infos, key_infos)) + GLWE::bytes_of_from_infos(module, out_infos) + + (GLWE::rsh_tmp_bytes(module.n()) | GLWE::automorphism_inplace_tmp_bytes(module, out_infos, key_infos)) } /// [combine] merges two ciphertexts together. fn combine( module: &Module, acc: &mut Accumulator, - b: Option<&GLWECiphertext>, + b: Option<&GLWE>, i: usize, - auto_keys: &HashMap>, + auto_keys: &HashMap>, scratch: &mut Scratch, ) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft @@ -311,10 +310,10 @@ fn combine( + VecZnxBigAutomorphismInplace + VecZnxNormalize + VecZnxNormalizeTmpBytes, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeGLWECt, + Scratch: ScratchAvailable, { let log_n: usize = acc.data.n().log2(); - let a: &mut GLWECiphertext> = &mut acc.data; + let a: &mut GLWE> = &mut acc.data; let gal_el: i64 = if i == 0 { -1 @@ -395,9 +394,9 @@ fn combine( /// to [0: GLWE(m_0 * X^x_0 + m_1 * X^x_1 + ... + m_i * X^x_i)] pub fn glwe_packing( module: &Module, - cts: &mut HashMap>, + cts: &mut HashMap>, log_gap_out: usize, - auto_keys: &HashMap>, + auto_keys: &HashMap>, scratch: &mut Scratch, ) where ATK: DataRef, @@ -414,7 +413,7 @@ pub fn glwe_packing( + VecZnxNegateInplace + VecZnxCopy + VecZnxSubInplace - + VecZnxDftAllocBytes + + VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft @@ -427,7 +426,7 @@ pub fn glwe_packing( + VecZnxBigSubSmallNegateInplace + VecZnxRotate + VecZnxNormalize, - Scratch: TakeVecZnx + TakeVecZnxDft + ScratchAvailable, + Scratch: ScratchAvailable, { #[cfg(debug_assertions)] { @@ -439,15 +438,15 @@ pub fn glwe_packing( (0..log_n - log_gap_out).for_each(|i| { let t: usize = (1 << log_n).min(1 << (log_n - 1 - i)); - let auto_key: &GGLWEAutomorphismKeyPrepared = if i == 0 { + let auto_key: &AutomorphismKeyPrepared = if i == 0 { auto_keys.get(&-1).unwrap() } else { auto_keys.get(&module.galois_element(1 << (i - 1))).unwrap() }; (0..t).for_each(|j| { - let mut a: Option<&mut GLWECiphertext> = cts.remove(&j); - let mut b: Option<&mut GLWECiphertext> = cts.remove(&(j + t)); + let mut a: Option<&mut GLWE> = cts.remove(&j); + let mut b: Option<&mut GLWE> = cts.remove(&(j + t)); pack_internal(module, &mut a, &mut b, i, auto_key, scratch); @@ -463,10 +462,10 @@ pub fn glwe_packing( #[allow(clippy::too_many_arguments)] fn pack_internal( module: &Module, - a: &mut Option<&mut GLWECiphertext>, - b: &mut Option<&mut GLWECiphertext>, + a: &mut Option<&mut GLWE>, + b: &mut Option<&mut GLWE>, i: usize, - auto_key: &GGLWEAutomorphismKeyPrepared, + auto_key: &AutomorphismKeyPrepared, scratch: &mut Scratch, ) where Module: VecZnxRotateInplace @@ -481,7 +480,7 @@ fn pack_internal( + VecZnxNegateInplace + VecZnxCopy + VecZnxSubInplace - + VecZnxDftAllocBytes + + VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft @@ -494,7 +493,7 @@ fn pack_internal( + VecZnxBigSubSmallNegateInplace + VecZnxRotate + VecZnxNormalize, - Scratch: TakeVecZnx + TakeVecZnxDft + ScratchAvailable, + Scratch: ScratchAvailable, { // Goal is to evaluate: a = a + b*X^t + phi(a - b*X^t)) // We also use the identity: AUTO(a * X^t, g) = -X^t * AUTO(a, g) diff --git a/poulpy-core/src/glwe_trace.rs b/poulpy-core/src/glwe_trace.rs index 4e1769e..36aabb9 100644 --- a/poulpy-core/src/glwe_trace.rs +++ b/poulpy-core/src/glwe_trace.rs @@ -2,22 +2,19 @@ use std::collections::HashMap; use poulpy_hal::{ api::{ - ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize, - VecZnxNormalizeTmpBytes, VecZnxRshInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + ScratchAvailable, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxCopy, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, + VecZnxRshInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx}, }; use crate::{ - TakeGLWECt, - layouts::{ - Base2K, GGLWEInfos, GLWECiphertext, GLWECiphertextLayout, GLWEInfos, LWEInfos, prepared::GGLWEAutomorphismKeyPrepared, - }, + layouts::{Base2K, GGLWEInfos, GLWE, GLWEInfos, GLWELayout, LWEInfos, prepared::AutomorphismKeyPrepared}, operations::GLWEOperations, }; -impl GLWECiphertext> { +impl GLWE> { pub fn trace_galois_elements(module: &Module) -> Vec { let mut gal_els: Vec = Vec::new(); (0..module.log_n()).for_each(|i| { @@ -30,21 +27,16 @@ impl GLWECiphertext> { gal_els } - pub fn trace_scratch_space( - module: &Module, - out_infos: &OUT, - in_infos: &IN, - key_infos: &KEY, - ) -> usize + pub fn trace_tmp_bytes(module: &Module, out_infos: &OUT, in_infos: &IN, key_infos: &KEY) -> usize where OUT: GLWEInfos, IN: GLWEInfos, KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - let trace: usize = Self::automorphism_inplace_scratch_space(module, out_infos, key_infos); + let trace: usize = Self::automorphism_inplace_tmp_bytes(module, out_infos, key_infos); if in_infos.base2k() != key_infos.base2k() { - let glwe_conv: usize = VecZnx::alloc_bytes( + let glwe_conv: usize = VecZnx::bytes_of( module.n(), (key_infos.rank_out() + 1).into(), out_infos.k().min(in_infos.k()).div_ceil(key_infos.base2k()) as usize, @@ -55,27 +47,27 @@ impl GLWECiphertext> { trace } - pub fn trace_inplace_scratch_space(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize + pub fn trace_inplace_tmp_bytes(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize where OUT: GLWEInfos, KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - Self::trace_scratch_space(module, out_infos, out_infos, key_infos) + Self::trace_tmp_bytes(module, out_infos, out_infos, key_infos) } } -impl GLWECiphertext { +impl GLWE { pub fn trace( &mut self, module: &Module, start: usize, end: usize, - lhs: &GLWECiphertext, - auto_keys: &HashMap>, + lhs: &GLWE, + auto_keys: &HashMap>, scratch: &mut Scratch, ) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft @@ -89,7 +81,7 @@ impl GLWECiphertext { + VecZnxCopy + VecZnxNormalizeTmpBytes + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Scratch: ScratchAvailable, { self.copy(module, lhs); self.trace_inplace(module, start, end, auto_keys, scratch); @@ -100,10 +92,10 @@ impl GLWECiphertext { module: &Module, start: usize, end: usize, - auto_keys: &HashMap>, + auto_keys: &HashMap>, scratch: &mut Scratch, ) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft @@ -116,7 +108,7 @@ impl GLWECiphertext { + VecZnxRshInplace + VecZnxNormalizeTmpBytes + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Scratch: ScratchAvailable, { let basek_ksk: Base2K = auto_keys .get(auto_keys.keys().next().unwrap()) @@ -137,7 +129,7 @@ impl GLWECiphertext { } if self.base2k() != basek_ksk { - let (mut self_conv, scratch_1) = scratch.take_glwe_ct(&GLWECiphertextLayout { + let (mut self_conv, scratch_1) = scratch.take_glwe_ct(&GLWELayout { n: module.n().into(), base2k: basek_ksk, k: self.k(), diff --git a/poulpy-core/src/keyswitching/gglwe_ct.rs b/poulpy-core/src/keyswitching/gglwe_ct.rs index 9f6caa5..edda267 100644 --- a/poulpy-core/src/keyswitching/gglwe_ct.rs +++ b/poulpy-core/src/keyswitching/gglwe_ct.rs @@ -1,224 +1,205 @@ -use poulpy_hal::{ - api::{ - ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, - VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, +use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch}; + +use crate::{ + ScratchTakeCore, + keyswitching::glwe_ct::GLWEKeySwitch, + layouts::{ + AutomorphismKey, AutomorphismKeyToRef, GGLWE, GGLWEInfos, GGLWEToMut, GGLWEToRef, GLWESwitchingKey, + GLWESwitchingKeyToRef, + prepared::{GLWESwitchingKeyPrepared, GLWESwitchingKeyPreparedToRef}, }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero}, }; -use crate::layouts::{ - GGLWEAutomorphismKey, GGLWEInfos, GGLWESwitchingKey, GLWECiphertext, GLWEInfos, - prepared::{GGLWEAutomorphismKeyPrepared, GGLWESwitchingKeyPrepared}, -}; - -impl GGLWEAutomorphismKey> { - pub fn keyswitch_scratch_space( - module: &Module, - out_infos: &OUT, - in_infos: &IN, - key_infos: &KEY, - ) -> usize +impl AutomorphismKey> { + pub fn keyswitch_inplace_tmp_bytes(module: &M, res_infos: &R, a_infos: &A, key_infos: &K) -> usize where - OUT: GGLWEInfos, - IN: GGLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + R: GGLWEInfos, + A: GGLWEInfos, + K: GGLWEInfos, + M: GGLWEKeySwitch, { - GGLWESwitchingKey::keyswitch_scratch_space(module, out_infos, in_infos, key_infos) - } - - pub fn keyswitch_inplace_scratch_space(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize - where - OUT: GGLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, - { - GGLWESwitchingKey::keyswitch_inplace_scratch_space(module, out_infos, key_infos) + module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) } } -impl GGLWEAutomorphismKey { - pub fn keyswitch( - &mut self, - module: &Module, - lhs: &GGLWEAutomorphismKey, - rhs: &GGLWESwitchingKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, +impl AutomorphismKey { + pub fn keyswitch(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch) + where + A: AutomorphismKeyToRef, + B: GLWESwitchingKeyPreparedToRef, + Scratch: ScratchTakeCore, + M: GGLWEKeySwitch, { - self.key.keyswitch(module, &lhs.key, rhs, scratch); + module.gglwe_keyswitch(&mut self.key.key, &a.to_ref().key.key, b, scratch); } - pub fn keyswitch_inplace( - &mut self, - module: &Module, - rhs: &GGLWEAutomorphismKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + pub fn keyswitch_inplace(&mut self, module: &M, a: &A, scratch: &mut Scratch) + where + A: GLWESwitchingKeyPreparedToRef, + Scratch: ScratchTakeCore, + M: GGLWEKeySwitch, { - self.key.keyswitch_inplace(module, &rhs.key, scratch); + module.gglwe_keyswitch_inplace(&mut self.key.key, a, scratch); } } -impl GGLWESwitchingKey> { - pub fn keyswitch_scratch_space( - module: &Module, - out_infos: &OUT, - in_infos: &IN, - key_apply: &KEY, - ) -> usize +impl GLWESwitchingKey> { + pub fn keyswitch_inplace_tmp_bytes(module: &M, res_infos: &R, a_infos: &A, key_infos: &K) -> usize where - OUT: GGLWEInfos, - IN: GGLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + R: GGLWEInfos, + A: GGLWEInfos, + K: GGLWEInfos, + M: GGLWEKeySwitch, { - GLWECiphertext::keyswitch_scratch_space(module, out_infos, in_infos, key_apply) - } - - pub fn keyswitch_inplace_scratch_space(module: &Module, out_infos: &OUT, key_apply: &KEY) -> usize - where - OUT: GGLWEInfos + GLWEInfos, - KEY: GGLWEInfos + GLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, - { - GLWECiphertext::keyswitch_inplace_scratch_space(module, out_infos, key_apply) + module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) } } -impl GGLWESwitchingKey { - pub fn keyswitch( - &mut self, - module: &Module, - lhs: &GGLWESwitchingKey, - rhs: &GGLWESwitchingKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnx, +impl GLWESwitchingKey { + pub fn keyswitch(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch) + where + A: GLWESwitchingKeyToRef, + B: GLWESwitchingKeyPreparedToRef, + Scratch: ScratchTakeCore, + M: GGLWEKeySwitch, { - #[cfg(debug_assertions)] - { - assert_eq!( - self.rank_in(), - lhs.rank_in(), - "ksk_out input rank: {} != ksk_in input rank: {}", - self.rank_in(), - lhs.rank_in() - ); - assert_eq!( - lhs.rank_out(), - rhs.rank_in(), - "ksk_in output rank: {} != ksk_apply input rank: {}", - self.rank_out(), - rhs.rank_in() - ); - assert_eq!( - self.rank_out(), - rhs.rank_out(), - "ksk_out output rank: {} != ksk_apply output rank: {}", - self.rank_out(), - rhs.rank_out() - ); - assert!( - self.dnum() <= lhs.dnum(), - "self.dnum()={} > lhs.dnum()={}", - self.dnum(), - lhs.dnum() - ); - assert_eq!( - self.dsize(), - lhs.dsize(), - "ksk_out dsize: {} != ksk_in dsize: {}", - self.dsize(), - lhs.dsize() - ) + module.gglwe_keyswitch(&mut self.key, &a.to_ref().key, b, scratch); + } + + pub fn keyswitch_inplace(&mut self, module: &M, a: &A, scratch: &mut Scratch) + where + A: GLWESwitchingKeyPreparedToRef, + Scratch: ScratchTakeCore, + M: GGLWEKeySwitch, + { + module.gglwe_keyswitch_inplace(&mut self.key, a, scratch); + } +} + +impl GGLWE> { + pub fn keyswitch_inplace_tmp_bytes(module: &M, res_infos: &R, a_infos: &A, key_infos: &K) -> usize + where + R: GGLWEInfos, + A: GGLWEInfos, + K: GGLWEInfos, + M: GGLWEKeySwitch, + { + module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) + } +} + +impl GGLWE { + pub fn keyswitch(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch) + where + A: GGLWEToRef, + B: GLWESwitchingKeyPreparedToRef, + Scratch: ScratchTakeCore, + M: GGLWEKeySwitch, + { + module.gglwe_keyswitch(self, a, b, scratch); + } + + pub fn keyswitch_inplace(&mut self, module: &M, a: &A, scratch: &mut Scratch) + where + A: GLWESwitchingKeyPreparedToRef, + Scratch: ScratchTakeCore, + M: GGLWEKeySwitch, + { + module.gglwe_keyswitch_inplace(self, a, scratch); + } +} + +impl GGLWEKeySwitch for Module where Self: GLWEKeySwitch {} + +pub trait GGLWEKeySwitch +where + Self: GLWEKeySwitch, +{ + fn gglwe_keyswitch_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize + where + R: GGLWEInfos, + A: GGLWEInfos, + K: GGLWEInfos, + { + self.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) + } + + fn gglwe_keyswitch(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) + where + R: GGLWEToMut, + A: GGLWEToRef, + B: GLWESwitchingKeyPreparedToRef, + Scratch: ScratchTakeCore, + { + let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GGLWE<&[u8]> = &a.to_ref(); + let b: &GLWESwitchingKeyPrepared<&[u8], BE> = &b.to_ref(); + + assert_eq!( + res.rank_in(), + a.rank_in(), + "res input rank: {} != a input rank: {}", + res.rank_in(), + a.rank_in() + ); + assert_eq!( + a.rank_out(), + b.rank_in(), + "res output rank: {} != b input rank: {}", + a.rank_out(), + b.rank_in() + ); + assert_eq!( + res.rank_out(), + b.rank_out(), + "res output rank: {} != b output rank: {}", + res.rank_out(), + b.rank_out() + ); + assert!( + res.dnum() <= a.dnum(), + "res.dnum()={} > a.dnum()={}", + res.dnum(), + a.dnum() + ); + assert_eq!( + res.dsize(), + a.dsize(), + "res dsize: {} != a dsize: {}", + res.dsize(), + a.dsize() + ); + + for row in 0..res.dnum().into() { + for col in 0..res.rank_in().into() { + self.glwe_keyswitch(&mut res.at_mut(row, col), &a.at(row, col), b, scratch); + } } - - (0..self.rank_in().into()).for_each(|col_i| { - (0..self.dnum().into()).for_each(|row_j| { - self.at_mut(row_j, col_i) - .keyswitch(module, &lhs.at(row_j, col_i), rhs, scratch); - }); - }); - - (self.dnum().min(lhs.dnum()).into()..self.dnum().into()).for_each(|row_i| { - (0..self.rank_in().into()).for_each(|col_j| { - self.at_mut(row_i, col_j).data.zero(); - }); - }); } - pub fn keyswitch_inplace( - &mut self, - module: &Module, - rhs: &GGLWESwitchingKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnx, + fn gglwe_keyswitch_inplace(&self, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: GGLWEToMut, + A: GLWESwitchingKeyPreparedToRef, + Scratch: ScratchTakeCore, { - #[cfg(debug_assertions)] - { - assert_eq!( - self.rank_out(), - rhs.rank_out(), - "ksk_out output rank: {} != ksk_apply output rank: {}", - self.rank_out(), - rhs.rank_out() - ); - } + let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWESwitchingKeyPrepared<&[u8], BE> = &a.to_ref(); - (0..self.rank_in().into()).for_each(|col_i| { - (0..self.dnum().into()).for_each(|row_j| { - self.at_mut(row_j, col_i) - .keyswitch_inplace(module, rhs, scratch) - }); - }); + assert_eq!( + res.rank_out(), + a.rank_out(), + "res output rank: {} != a output rank: {}", + res.rank_out(), + a.rank_out() + ); + + for row in 0..res.dnum().into() { + for col in 0..res.rank_in().into() { + self.glwe_keyswitch_inplace(&mut res.at_mut(row, col), a, scratch); + } + } } } + +impl GLWESwitchingKey {} diff --git a/poulpy-core/src/keyswitching/ggsw_ct.rs b/poulpy-core/src/keyswitching/ggsw_ct.rs index d261f03..cfb4d8e 100644 --- a/poulpy-core/src/keyswitching/ggsw_ct.rs +++ b/poulpy-core/src/keyswitching/ggsw_ct.rs @@ -1,366 +1,131 @@ -use poulpy_hal::{ - api::{ - ScratchAvailable, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, - VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftApply, - VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, - VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, - }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx, VmpPMat, ZnxInfos}, -}; +use poulpy_hal::layouts::{Backend, DataMut, Scratch, VecZnx}; use crate::{ + GGSWExpandRows, ScratchTakeCore, + keyswitching::glwe_ct::GLWEKeySwitch, layouts::{ - GGLWECiphertext, GGLWEInfos, GGSWCiphertext, GGSWInfos, GLWECiphertext, GLWEInfos, LWEInfos, - prepared::{GGLWESwitchingKeyPrepared, GGLWETensorKeyPrepared}, + GGLWEInfos, GGSW, GGSWInfos, GGSWToMut, GGSWToRef, + prepared::{GLWESwitchingKeyPreparedToRef, TensorKeyPreparedToRef}, }, - operations::GLWEOperations, }; -impl GGSWCiphertext> { - pub(crate) fn expand_row_scratch_space(module: &Module, out_infos: &OUT, tsk_infos: &TSK) -> usize - where - OUT: GGSWInfos, - TSK: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes, - { - let tsk_size: usize = tsk_infos.k().div_ceil(tsk_infos.base2k()) as usize; - let size_in: usize = out_infos - .k() - .div_ceil(tsk_infos.base2k()) - .div_ceil(tsk_infos.dsize().into()) as usize; - - let tmp_dft_i: usize = module.vec_znx_dft_alloc_bytes((tsk_infos.rank_out() + 1).into(), tsk_size); - let tmp_a: usize = module.vec_znx_dft_alloc_bytes(1, size_in); - let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes( - tsk_size, - size_in, - size_in, - (tsk_infos.rank_in()).into(), // Verify if rank+1 - (tsk_infos.rank_out()).into(), // Verify if rank+1 - tsk_size, - ); - let tmp_idft: usize = module.vec_znx_big_alloc_bytes(1, tsk_size); - let norm: usize = module.vec_znx_normalize_tmp_bytes(); - - tmp_dft_i + ((tmp_a + vmp) | (tmp_idft + norm)) - } - - #[allow(clippy::too_many_arguments)] - pub fn keyswitch_scratch_space( - module: &Module, - out_infos: &OUT, - in_infos: &IN, - apply_infos: &KEY, - tsk_infos: &TSK, +impl GGSW> { + pub fn keyswitch_tmp_bytes( + module: &M, + res_infos: &R, + a_infos: &A, + key_infos: &K, + tsk_infos: &T, ) -> usize where - OUT: GGSWInfos, - IN: GGSWInfos, - KEY: GGLWEInfos, - TSK: GGLWEInfos, - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigAllocBytes - + VecZnxNormalizeTmpBytes - + VecZnxBigNormalizeTmpBytes, + R: GGSWInfos, + A: GGSWInfos, + K: GGLWEInfos, + T: GGLWEInfos, + M: GGSWKeySwitch, { - #[cfg(debug_assertions)] - { - assert_eq!(apply_infos.rank_in(), apply_infos.rank_out()); - assert_eq!(tsk_infos.rank_in(), tsk_infos.rank_out()); - assert_eq!(apply_infos.rank_in(), tsk_infos.rank_in()); - } + module.ggsw_keyswitch_tmp_bytes(res_infos, a_infos, key_infos, tsk_infos) + } +} - let rank: usize = apply_infos.rank_out().into(); +impl GGSW { + pub fn keyswitch(&mut self, module: &M, a: &A, key: &K, tsk: &T, scratch: &mut Scratch) + where + A: GGSWToRef, + K: GLWESwitchingKeyPreparedToRef, + T: TensorKeyPreparedToRef, + Scratch: ScratchTakeCore, + M: GGSWKeySwitch, + { + module.ggsw_keyswitch(self, a, key, tsk, scratch); + } - let size_out: usize = out_infos.k().div_ceil(out_infos.base2k()) as usize; - let res_znx: usize = VecZnx::alloc_bytes(module.n(), rank + 1, size_out); - let ci_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, size_out); - let ks: usize = GLWECiphertext::keyswitch_scratch_space(module, out_infos, in_infos, apply_infos); - let expand_rows: usize = GGSWCiphertext::expand_row_scratch_space(module, out_infos, tsk_infos); - let res_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, size_out); + pub fn keyswitch_inplace(&mut self, module: &M, key: &K, tsk: &T, scratch: &mut Scratch) + where + K: GLWESwitchingKeyPreparedToRef, + T: TensorKeyPreparedToRef, + Scratch: ScratchTakeCore, + M: GGSWKeySwitch, + { + module.ggsw_keyswitch_inplace(self, key, tsk, scratch); + } +} - if in_infos.base2k() == tsk_infos.base2k() { +pub trait GGSWKeySwitch +where + Self: GLWEKeySwitch + GGSWExpandRows, +{ + fn ggsw_keyswitch_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K, tsk_infos: &T) -> usize + where + R: GGSWInfos, + A: GGSWInfos, + K: GGLWEInfos, + T: GGLWEInfos, + { + assert_eq!(key_infos.rank_in(), key_infos.rank_out()); + assert_eq!(tsk_infos.rank_in(), tsk_infos.rank_out()); + assert_eq!(key_infos.rank_in(), tsk_infos.rank_in()); + + let rank: usize = key_infos.rank_out().into(); + + let size_out: usize = res_infos.k().div_ceil(res_infos.base2k()) as usize; + let res_znx: usize = VecZnx::bytes_of(self.n(), rank + 1, size_out); + let ci_dft: usize = self.bytes_of_vec_znx_dft(rank + 1, size_out); + let ks: usize = self.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos); + let expand_rows: usize = self.ggsw_expand_rows_tmp_bytes(res_infos, tsk_infos); + let res_dft: usize = self.bytes_of_vec_znx_dft(rank + 1, size_out); + + if a_infos.base2k() == tsk_infos.base2k() { res_znx + ci_dft + (ks | expand_rows | res_dft) } else { - let a_conv: usize = VecZnx::alloc_bytes( - module.n(), + let a_conv: usize = VecZnx::bytes_of( + self.n(), 1, - out_infos.k().div_ceil(tsk_infos.base2k()) as usize, - ) + module.vec_znx_normalize_tmp_bytes(); + res_infos.k().div_ceil(tsk_infos.base2k()) as usize, + ) + self.vec_znx_normalize_tmp_bytes(); res_znx + ci_dft + (a_conv | ks | expand_rows | res_dft) } } - #[allow(clippy::too_many_arguments)] - pub fn keyswitch_inplace_scratch_space( - module: &Module, - out_infos: &OUT, - apply_infos: &KEY, - tsk_infos: &TSK, - ) -> usize + fn ggsw_keyswitch(&self, res: &mut R, a: &A, key: &K, tsk: &T, scratch: &mut Scratch) where - OUT: GGSWInfos, - KEY: GGLWEInfos, - TSK: GGLWEInfos, - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigAllocBytes - + VecZnxNormalizeTmpBytes - + VecZnxBigNormalizeTmpBytes, + R: GGSWToMut, + A: GGSWToRef, + K: GLWESwitchingKeyPreparedToRef, + T: TensorKeyPreparedToRef, + Scratch: ScratchTakeCore, { - GGSWCiphertext::keyswitch_scratch_space(module, out_infos, out_infos, apply_infos, tsk_infos) + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + let a: &GGSW<&[u8]> = &a.to_ref(); + + assert_eq!(res.ggsw_layout(), a.ggsw_layout()); + + for row in 0..a.dnum().into() { + // Key-switch column 0, i.e. + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) + self.glwe_keyswitch(&mut res.at_mut(row, 0), &a.at(row, 0), key, scratch); + } + + self.ggsw_expand_row(res, tsk, scratch); + } + + fn ggsw_keyswitch_inplace(&self, res: &mut R, key: &K, tsk: &T, scratch: &mut Scratch) + where + R: GGSWToMut, + K: GLWESwitchingKeyPreparedToRef, + T: TensorKeyPreparedToRef, + Scratch: ScratchTakeCore, + { + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + + for row in 0..res.dnum().into() { + // Key-switch column 0, i.e. + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) + self.glwe_keyswitch_inplace(&mut res.at_mut(row, 0), key, scratch); + } + + self.ggsw_expand_row(res, tsk, scratch); } } -impl GGSWCiphertext { - pub fn from_gglwe( - &mut self, - module: &Module, - a: &GGLWECiphertext, - tsk: &GGLWETensorKeyPrepared, - scratch: &mut Scratch, - ) where - DataA: DataRef, - DataTsk: DataRef, - Module: VecZnxCopy - + VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigAllocBytes - + VecZnxNormalizeTmpBytes - + VecZnxDftApply - + VecZnxDftCopy - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftAddInplace - + VecZnxBigNormalize - + VecZnxIdftApplyTmpA - + VecZnxNormalize, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnxBig + TakeVecZnx, - { - #[cfg(debug_assertions)] - { - use crate::layouts::{GLWEInfos, LWEInfos}; - - assert_eq!(self.rank(), a.rank_out()); - assert_eq!(self.dnum(), a.dnum()); - assert_eq!(self.n(), module.n() as u32); - assert_eq!(a.n(), module.n() as u32); - assert_eq!(tsk.n(), module.n() as u32); - } - (0..self.dnum().into()).for_each(|row_i| { - self.at_mut(row_i, 0).copy(module, &a.at(row_i, 0)); - }); - self.expand_row(module, tsk, scratch); - } - - pub fn keyswitch( - &mut self, - module: &Module, - lhs: &GGSWCiphertext, - ksk: &GGLWESwitchingKeyPrepared, - tsk: &GGLWETensorKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxDftAllocBytes - + VecZnxBigAllocBytes - + VecZnxNormalizeTmpBytes - + VecZnxDftCopy - + VecZnxDftAddInplace - + VecZnxIdftApplyTmpA - + VecZnxNormalize, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnxBig + TakeVecZnx, - { - (0..lhs.dnum().into()).for_each(|row_i| { - // Key-switch column 0, i.e. - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) - self.at_mut(row_i, 0) - .keyswitch(module, &lhs.at(row_i, 0), ksk, scratch); - }); - self.expand_row(module, tsk, scratch); - } - - pub fn keyswitch_inplace( - &mut self, - module: &Module, - ksk: &GGLWESwitchingKeyPrepared, - tsk: &GGLWETensorKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxDftAllocBytes - + VecZnxBigAllocBytes - + VecZnxNormalizeTmpBytes - + VecZnxDftCopy - + VecZnxDftAddInplace - + VecZnxIdftApplyTmpA - + VecZnxNormalize, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnxBig + TakeVecZnx, - { - (0..self.dnum().into()).for_each(|row_i| { - // Key-switch column 0, i.e. - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) - self.at_mut(row_i, 0) - .keyswitch_inplace(module, ksk, scratch); - }); - self.expand_row(module, tsk, scratch); - } - - pub fn expand_row( - &mut self, - module: &Module, - tsk: &GGLWETensorKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigAllocBytes - + VecZnxNormalizeTmpBytes - + VecZnxDftApply - + VecZnxDftCopy - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftAddInplace - + VecZnxBigNormalize - + VecZnxIdftApplyTmpA - + VecZnxNormalize, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnxBig + TakeVecZnx, - { - let basek_in: usize = self.base2k().into(); - let basek_tsk: usize = tsk.base2k().into(); - - assert!(scratch.available() >= GGSWCiphertext::expand_row_scratch_space(module, self, tsk)); - - let n: usize = self.n().into(); - let rank: usize = self.rank().into(); - let cols: usize = rank + 1; - - let a_size: usize = (self.size() * basek_in).div_ceil(basek_tsk); - - // Keyswitch the j-th row of the col 0 - for row_i in 0..self.dnum().into() { - let a = &self.at(row_i, 0).data; - - // Pre-compute DFT of (a0, a1, a2) - let (mut ci_dft, scratch_1) = scratch.take_vec_znx_dft(n, cols, a_size); - - if basek_in == basek_tsk { - for i in 0..cols { - module.vec_znx_dft_apply(1, 0, &mut ci_dft, i, a, i); - } - } else { - let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(n, 1, a_size); - for i in 0..cols { - module.vec_znx_normalize(basek_tsk, &mut a_conv, 0, basek_in, a, i, scratch_2); - module.vec_znx_dft_apply(1, 0, &mut ci_dft, i, &a_conv, 0); - } - } - - for col_j in 1..cols { - // Example for rank 3: - // - // Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is - // actually composed of that many dnum and we focus on a specific row here - // implicitely given ci_dft. - // - // # Input - // - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) - // col 1: (0, 0, 0, 0) - // col 2: (0, 0, 0, 0) - // col 3: (0, 0, 0, 0) - // - // # Output - // - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) - // col 1: (-(b0s0 + b1s1 + b2s2) , b0 + M[i], b1 , b2 ) - // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M[i], c2 ) - // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M[i]) - - let dsize: usize = tsk.dsize().into(); - - let (mut tmp_dft_i, scratch_2) = scratch_1.take_vec_znx_dft(n, cols, tsk.size()); - let (mut tmp_a, scratch_3) = scratch_2.take_vec_znx_dft(n, 1, ci_dft.size().div_ceil(dsize)); - - { - // Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2 - // - // # Example for col=1 - // - // a0 * (-(f0s0 + f1s1 + f1s2) + s0^2, f0, f1, f2) = (-(a0f0s0 + a0f1s1 + a0f1s2) + a0s0^2, a0f0, a0f1, a0f2) - // + - // a1 * (-(g0s0 + g1s1 + g1s2) + s0s1, g0, g1, g2) = (-(a1g0s0 + a1g1s1 + a1g1s2) + a1s0s1, a1g0, a1g1, a1g2) - // + - // a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2) - // = - // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2) - for col_i in 1..cols { - let pmat: &VmpPMat = &tsk.at(col_i - 1, col_j - 1).key.data; // Selects Enc(s[i]s[j]) - - // Extracts a[i] and multipies with Enc(s[i]s[j]) - for di in 0..dsize { - tmp_a.set_size((ci_dft.size() + di) / dsize); - - // Small optimization for dsize > 2 - // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then - // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}. - // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last dsize-1 limbs, but this introduce - // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same - // noise is kept with respect to the ideal functionality. - tmp_dft_i.set_size(tsk.size() - ((dsize - di) as isize - 2).max(0) as usize); - - module.vec_znx_dft_copy(dsize, dsize - 1 - di, &mut tmp_a, 0, &ci_dft, col_i); - if di == 0 && col_i == 1 { - module.vmp_apply_dft_to_dft(&mut tmp_dft_i, &tmp_a, pmat, scratch_3); - } else { - module.vmp_apply_dft_to_dft_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch_3); - } - } - } - } - - // Adds -(sum a[i] * s[i]) + m) on the i-th column of tmp_idft_i - // - // (-(x0s0 + x1s1 + x2s2) + a0s0s0 + a1s0s1 + a2s0s2, x0, x1, x2) - // + - // (0, -(a0s0 + a1s1 + a2s2) + M[i], 0, 0) - // = - // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0 -(a0s0 + a1s1 + a2s2) + M[i], x1, x2) - // = - // (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2) - module.vec_znx_dft_add_inplace(&mut tmp_dft_i, col_j, &ci_dft, 0); - let (mut tmp_idft, scratch_3) = scratch_2.take_vec_znx_big(n, 1, tsk.size()); - for i in 0..cols { - module.vec_znx_idft_apply_tmpa(&mut tmp_idft, 0, &mut tmp_dft_i, i); - module.vec_znx_big_normalize( - basek_in, - &mut self.at_mut(row_i, col_j).data, - i, - basek_tsk, - &tmp_idft, - 0, - scratch_3, - ); - } - } - } - } -} +impl GGSW {} diff --git a/poulpy-core/src/keyswitching/glwe_ct.rs b/poulpy-core/src/keyswitching/glwe_ct.rs index 07d95e9..6d7bff9 100644 --- a/poulpy-core/src/keyswitching/glwe_ct.rs +++ b/poulpy-core/src/keyswitching/glwe_ct.rs @@ -1,186 +1,179 @@ use poulpy_hal::{ api::{ - ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, + ModuleN, ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, ZnxInfos}, }; -use crate::layouts::{GGLWEInfos, GLWECiphertext, GLWEInfos, LWEInfos, prepared::GGLWESwitchingKeyPrepared}; +use crate::{ + ScratchTakeCore, + layouts::{ + GGLWEInfos, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, + prepared::{GLWESwitchingKeyPrepared, GLWESwitchingKeyPreparedToRef}, + }, +}; -impl GLWECiphertext> { - pub fn keyswitch_scratch_space( - module: &Module, - out_infos: &OUT, - in_infos: &IN, - key_apply: &KEY, - ) -> usize +impl GLWE> { + pub fn keyswitch_tmp_bytes(module: &M, res_infos: &R, a_infos: &A, b_infos: &B) -> usize where - OUT: GLWEInfos, - IN: GLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + R: GLWEInfos, + A: GLWEInfos, + B: GGLWEInfos, + M: GLWEKeySwitch, { - let in_size: usize = in_infos + module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, b_infos) + } +} + +impl GLWE { + pub fn keyswitch(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch) + where + A: GLWEToRef, + B: GLWESwitchingKeyPreparedToRef, + M: GLWEKeySwitch, + Scratch: ScratchTakeCore, + { + module.glwe_keyswitch(self, a, b, scratch); + } + + pub fn keyswitch_inplace(&mut self, module: &M, a: &A, scratch: &mut Scratch) + where + A: GLWESwitchingKeyPreparedToRef, + M: GLWEKeySwitch, + Scratch: ScratchTakeCore, + { + module.glwe_keyswitch_inplace(self, a, scratch); + } +} + +impl GLWEKeySwitch for Module where + Self: Sized + + ModuleN + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VecZnxBigNormalizeTmpBytes + + VecZnxNormalizeTmpBytes + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VecZnxBigNormalizeTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxDftApply + + VecZnxIdftApplyConsume + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize + + VecZnxNormalize + + VecZnxNormalizeTmpBytes +{ +} + +pub trait GLWEKeySwitch +where + Self: Sized + + ModuleN + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VecZnxBigNormalizeTmpBytes + + VecZnxNormalizeTmpBytes + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VecZnxBigNormalizeTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxDftApply + + VecZnxIdftApplyConsume + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize + + VecZnxNormalize + + VecZnxNormalizeTmpBytes, +{ + fn glwe_keyswitch_tmp_bytes(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize + where + R: GLWEInfos, + A: GLWEInfos, + B: GGLWEInfos, + { + let in_size: usize = a_infos .k() - .div_ceil(key_apply.base2k()) - .div_ceil(key_apply.dsize().into()) as usize; - let out_size: usize = out_infos.size(); - let ksk_size: usize = key_apply.size(); - let res_dft: usize = module.vec_znx_dft_alloc_bytes((key_apply.rank_out() + 1).into(), ksk_size); // TODO OPTIMIZE - let ai_dft: usize = module.vec_znx_dft_alloc_bytes((key_apply.rank_in()).into(), in_size); - let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes( + .div_ceil(b_infos.base2k()) + .div_ceil(b_infos.dsize().into()) as usize; + let out_size: usize = res_infos.size(); + let ksk_size: usize = b_infos.size(); + let res_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank_out() + 1).into(), ksk_size); // TODO OPTIMIZE + let ai_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank_in()).into(), in_size); + let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes( out_size, in_size, in_size, - (key_apply.rank_in()).into(), - (key_apply.rank_out() + 1).into(), + (b_infos.rank_in()).into(), + (b_infos.rank_out() + 1).into(), ksk_size, - ) + module.vec_znx_dft_alloc_bytes((key_apply.rank_in()).into(), in_size); - let normalize_big: usize = module.vec_znx_big_normalize_tmp_bytes(); - if in_infos.base2k() == key_apply.base2k() { + ) + self.bytes_of_vec_znx_dft((b_infos.rank_in()).into(), in_size); + let normalize_big: usize = self.vec_znx_big_normalize_tmp_bytes(); + if a_infos.base2k() == b_infos.base2k() { res_dft + ((ai_dft + vmp) | normalize_big) - } else if key_apply.dsize() == 1 { + } else if b_infos.dsize() == 1 { // In this case, we only need one column, temporary, that we can drop once a_dft is computed. - let normalize_conv: usize = VecZnx::alloc_bytes(module.n(), 1, in_size) + module.vec_znx_normalize_tmp_bytes(); + let normalize_conv: usize = VecZnx::bytes_of(self.n(), 1, in_size) + self.vec_znx_normalize_tmp_bytes(); res_dft + (((ai_dft + normalize_conv) | vmp) | normalize_big) } else { // Since we stride over a to get a_dft when dsize > 1, we need to store the full columns of a with in the base conversion. - let normalize_conv: usize = VecZnx::alloc_bytes(module.n(), (key_apply.rank_in()).into(), in_size); - res_dft + ((ai_dft + normalize_conv + (module.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big) + let normalize_conv: usize = VecZnx::bytes_of(self.n(), (b_infos.rank_in()).into(), in_size); + res_dft + ((ai_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big) } } - pub fn keyswitch_inplace_scratch_space(module: &Module, out_infos: &OUT, key_apply: &KEY) -> usize + fn glwe_keyswitch(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) where - OUT: GLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + R: GLWEToMut, + A: GLWEToRef, + B: GLWESwitchingKeyPreparedToRef, + Scratch: ScratchTakeCore, { - Self::keyswitch_scratch_space(module, out_infos, out_infos, key_apply) - } -} + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); + let b: &GLWESwitchingKeyPrepared<&[u8], BE> = &b.to_ref(); -impl GLWECiphertext { - #[allow(dead_code)] - pub(crate) fn assert_keyswitch( - &self, - module: &Module, - lhs: &GLWECiphertext, - rhs: &GGLWESwitchingKeyPrepared, - scratch: &Scratch, - ) where - DataLhs: DataRef, - DataRhs: DataRef, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable, - { assert_eq!( - lhs.rank(), - rhs.rank_in(), - "lhs.rank(): {} != rhs.rank_in(): {}", - lhs.rank(), - rhs.rank_in() + a.rank(), + b.rank_in(), + "a.rank(): {} != b.rank_in(): {}", + a.rank(), + b.rank_in() ); assert_eq!( - self.rank(), - rhs.rank_out(), - "self.rank(): {} != rhs.rank_out(): {}", - self.rank(), - rhs.rank_out() + res.rank(), + b.rank_out(), + "res.rank(): {} != b.rank_out(): {}", + res.rank(), + b.rank_out() ); - assert_eq!(rhs.n(), self.n()); - assert_eq!(lhs.n(), self.n()); - let scrach_needed: usize = GLWECiphertext::keyswitch_scratch_space(module, self, lhs, rhs); + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.n(), self.n() as u32); + assert_eq!(b.n(), self.n() as u32); + + let scrach_needed: usize = self.glwe_keyswitch_tmp_bytes(res, a, b); assert!( scratch.available() >= scrach_needed, - "scratch.available()={} < GLWECiphertext::keyswitch_scratch_space( - module, - self.base2k(), - self.k(), - lhs.base2k(), - lhs.k(), - rhs.base2k(), - rhs.k(), - rhs.dsize(), - rhs.rank_in(), - rhs.rank_out(), - )={scrach_needed}", + "scratch.available()={} < glwe_keyswitch_tmp_bytes={scrach_needed}", scratch.available(), ); - } - #[allow(dead_code)] - pub(crate) fn assert_keyswitch_inplace( - &self, - module: &Module, - rhs: &GGLWESwitchingKeyPrepared, - scratch: &Scratch, - ) where - DataRhs: DataRef, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable, - { - assert_eq!( - self.rank(), - rhs.rank_out(), - "self.rank(): {} != rhs.rank_out(): {}", - self.rank(), - rhs.rank_out() - ); + let basek_out: usize = res.base2k().into(); + let base2k_out: usize = b.base2k().into(); - assert_eq!(rhs.n(), self.n()); - - let scrach_needed: usize = GLWECiphertext::keyswitch_inplace_scratch_space(module, self, rhs); - - assert!( - scratch.available() >= scrach_needed, - "scratch.available()={} < GLWECiphertext::keyswitch_scratch_space()={scrach_needed}", - scratch.available(), - ); - } -} - -impl GLWECiphertext { - pub fn keyswitch( - &mut self, - module: &Module, - glwe_in: &GLWECiphertext, - rhs: &GGLWESwitchingKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnx, - { - #[cfg(debug_assertions)] - { - self.assert_keyswitch(module, glwe_in, rhs, scratch); - } - - let basek_out: usize = self.base2k().into(); - let basek_ksk: usize = rhs.base2k().into(); - - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // Todo optimise - let res_big: VecZnxBig<_, B> = glwe_in.keyswitch_internal(module, res_dft, rhs, scratch_1); - (0..(self.rank() + 1).into()).for_each(|i| { - module.vec_znx_big_normalize( + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), b.size()); // Todo optimise + let res_big: VecZnxBig<&mut [u8], BE> = keyswitch_internal(self, res_dft, a, b, scratch_1); + (0..(res.rank() + 1).into()).for_each(|i| { + self.vec_znx_big_normalize( basek_out, - &mut self.data, + &mut res.data, i, - basek_ksk, + base2k_out, &res_big, i, scratch_1, @@ -188,227 +181,190 @@ impl GLWECiphertext { }) } - pub fn keyswitch_inplace( - &mut self, - module: &Module, - rhs: &GGLWESwitchingKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnx, - { - #[cfg(debug_assertions)] - { - self.assert_keyswitch_inplace(module, rhs, scratch); - } - - let basek_in: usize = self.base2k().into(); - let basek_ksk: usize = rhs.base2k().into(); - - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // Todo optimise - let res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, rhs, scratch_1); - (0..(self.rank() + 1).into()).for_each(|i| { - module.vec_znx_big_normalize( - basek_in, - &mut self.data, - i, - basek_ksk, - &res_big, - i, - scratch_1, - ); - }) - } -} - -impl GLWECiphertext { - pub(crate) fn keyswitch_internal( - &self, - module: &Module, - res_dft: VecZnxDft, - rhs: &GGLWESwitchingKeyPrepared, - scratch: &mut Scratch, - ) -> VecZnxBig + fn glwe_keyswitch_inplace(&self, res: &mut R, a: &A, scratch: &mut Scratch) where - DataRes: DataMut, - DataKey: DataRef, - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalize, - Scratch: TakeVecZnxDft + TakeVecZnx, + R: GLWEToMut, + A: GLWESwitchingKeyPreparedToRef, + Scratch: ScratchTakeCore, { - if rhs.dsize() == 1 { - return keyswitch_vmp_one_digit( - module, - self.base2k().into(), - rhs.base2k().into(), - res_dft, - &self.data, - &rhs.key.data, - scratch, + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWESwitchingKeyPrepared<&[u8], BE> = &a.to_ref(); + + assert_eq!( + res.rank(), + a.rank_in(), + "res.rank(): {} != a.rank_in(): {}", + res.rank(), + a.rank_in() + ); + assert_eq!( + res.rank(), + a.rank_out(), + "res.rank(): {} != b.rank_out(): {}", + res.rank(), + a.rank_out() + ); + + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.n(), self.n() as u32); + + let scrach_needed: usize = self.glwe_keyswitch_tmp_bytes(res, res, a); + + assert!( + scratch.available() >= scrach_needed, + "scratch.available()={} < glwe_keyswitch_tmp_bytes={scrach_needed}", + scratch.available(), + ); + + let base2k_in: usize = res.base2k().into(); + let base2k_out: usize = a.base2k().into(); + + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), a.size()); // Todo optimise + let res_big: VecZnxBig<&mut [u8], BE> = keyswitch_internal(self, res_dft, res, a, scratch_1); + (0..(res.rank() + 1).into()).for_each(|i| { + self.vec_znx_big_normalize( + base2k_in, + &mut res.data, + i, + base2k_out, + &res_big, + i, + scratch_1, ); - } - - keyswitch_vmp_multiple_digits( - module, - self.base2k().into(), - rhs.base2k().into(), - res_dft, - &self.data, - &rhs.key.data, - rhs.dsize().into(), - scratch, - ) + }) } } -fn keyswitch_vmp_one_digit( - module: &Module, - basek_in: usize, - basek_ksk: usize, - mut res_dft: VecZnxDft, - a: &VecZnx, - mat: &VmpPMat, - scratch: &mut Scratch, -) -> VecZnxBig +impl GLWE> {} + +impl GLWE {} + +fn keyswitch_internal( + module: &M, + mut res: VecZnxDft, + a: &GLWE, + b: &GLWESwitchingKeyPrepared, + scratch: &mut Scratch, +) -> VecZnxBig where - DataRes: DataMut, - DataIn: DataRef, - DataVmp: DataRef, - Module: VecZnxDftAllocBytes - + VecZnxDftApply - + VmpApplyDftToDft - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxNormalize, - Scratch: TakeVecZnxDft + TakeVecZnx, + DR: DataMut, + DA: DataRef, + DB: DataRef, + M: ModuleN + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VecZnxBigNormalizeTmpBytes + + VmpApplyDftToDftTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxDftApply + + VecZnxIdftApplyConsume + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize + + VecZnxNormalize, + Scratch: ScratchTakeCore, { - let cols: usize = a.cols(); + let base2k_in: usize = a.base2k().into(); + let base2k_out: usize = b.base2k().into(); + let cols: usize = (a.rank() + 1).into(); + let a_size: usize = (a.size() * base2k_in).div_ceil(base2k_out); + let pmat: &VmpPMat = &b.key.data; - let a_size: usize = (a.size() * basek_in).div_ceil(basek_ksk); - let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(a.n(), cols - 1, a.size()); + if b.dsize() == 1 { + let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols - 1, a.size()); - if basek_in == basek_ksk { - (0..cols - 1).for_each(|col_i| { - module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, a, col_i + 1); - }); + if base2k_in == base2k_out { + (0..cols - 1).for_each(|col_i| { + module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, a.data(), col_i + 1); + }); + } else { + let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(module, 1, a_size); + (0..cols - 1).for_each(|col_i| { + module.vec_znx_normalize( + base2k_out, + &mut a_conv, + 0, + base2k_in, + a.data(), + col_i + 1, + scratch_2, + ); + module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, &a_conv, 0); + }); + } + + module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_1); } else { - let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(a.n(), 1, a_size); - (0..cols - 1).for_each(|col_i| { - module.vec_znx_normalize(basek_ksk, &mut a_conv, 0, basek_in, a, col_i + 1, scratch_2); - module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, &a_conv, 0); - }); + let dsize: usize = b.dsize().into(); + + let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols - 1, a_size.div_ceil(dsize)); + ai_dft.data_mut().fill(0); + + if base2k_in == base2k_out { + for di in 0..dsize { + ai_dft.set_size((a_size + di) / dsize); + + // Small optimization for dsize > 2 + // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then + // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}. + // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. + // It is possible to further ignore the last dsize-1 limbs, but this introduce + // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same + // noise is kept with respect to the ideal functionality. + res.set_size(pmat.size() - ((dsize - di) as isize - 2).max(0) as usize); + + for j in 0..cols - 1 { + module.vec_znx_dft_apply(dsize, dsize - di - 1, &mut ai_dft, j, a.data(), j + 1); + } + + if di == 0 { + module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_1); + } else { + module.vmp_apply_dft_to_dft_add(&mut res, &ai_dft, pmat, di, scratch_1); + } + } + } else { + let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(module, cols - 1, a_size); + for j in 0..cols - 1 { + module.vec_znx_normalize( + base2k_out, + &mut a_conv, + j, + base2k_in, + a.data(), + j + 1, + scratch_2, + ); + } + + for di in 0..dsize { + ai_dft.set_size((a_size + di) / dsize); + + // Small optimization for dsize > 2 + // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then + // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}. + // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. + // It is possible to further ignore the last dsize-1 limbs, but this introduce + // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same + // noise is kept with respect to the ideal functionality. + res.set_size(pmat.size() - ((dsize - di) as isize - 2).max(0) as usize); + + for j in 0..cols - 1 { + module.vec_znx_dft_apply(dsize, dsize - di - 1, &mut ai_dft, j, &a_conv, j); + } + + if di == 0 { + module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_2); + } else { + module.vmp_apply_dft_to_dft_add(&mut res, &ai_dft, pmat, di, scratch_2); + } + } + } + + res.set_size(res.max_size()); } - module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch_1); - let mut res_big: VecZnxBig = module.vec_znx_idft_apply_consume(res_dft); - module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0); - res_big -} - -#[allow(clippy::too_many_arguments)] -fn keyswitch_vmp_multiple_digits( - module: &Module, - basek_in: usize, - basek_ksk: usize, - mut res_dft: VecZnxDft, - a: &VecZnx, - mat: &VmpPMat, - dsize: usize, - scratch: &mut Scratch, -) -> VecZnxBig -where - DataRes: DataMut, - DataIn: DataRef, - DataVmp: DataRef, - Module: VecZnxDftAllocBytes - + VecZnxDftApply - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxNormalize, - Scratch: TakeVecZnxDft + TakeVecZnx, -{ - let cols: usize = a.cols(); - let a_size: usize = (a.size() * basek_in).div_ceil(basek_ksk); - let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(a.n(), cols - 1, a_size.div_ceil(dsize)); - ai_dft.data_mut().fill(0); - - if basek_in == basek_ksk { - for di in 0..dsize { - ai_dft.set_size((a_size + di) / dsize); - - // Small optimization for dsize > 2 - // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then - // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}. - // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last dsize-1 limbs, but this introduce - // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same - // noise is kept with respect to the ideal functionality. - res_dft.set_size(mat.size() - ((dsize - di) as isize - 2).max(0) as usize); - - for j in 0..cols - 1 { - module.vec_znx_dft_apply(dsize, dsize - di - 1, &mut ai_dft, j, a, j + 1); - } - - if di == 0 { - module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch_1); - } else { - module.vmp_apply_dft_to_dft_add(&mut res_dft, &ai_dft, mat, di, scratch_1); - } - } - } else { - let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(a.n(), cols - 1, a_size); - for j in 0..cols - 1 { - module.vec_znx_normalize(basek_ksk, &mut a_conv, j, basek_in, a, j + 1, scratch_2); - } - - for di in 0..dsize { - ai_dft.set_size((a_size + di) / dsize); - - // Small optimization for dsize > 2 - // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then - // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}. - // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last dsize-1 limbs, but this introduce - // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same - // noise is kept with respect to the ideal functionality. - res_dft.set_size(mat.size() - ((dsize - di) as isize - 2).max(0) as usize); - - for j in 0..cols - 1 { - module.vec_znx_dft_apply(dsize, dsize - di - 1, &mut ai_dft, j, &a_conv, j); - } - - if di == 0 { - module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch_2); - } else { - module.vmp_apply_dft_to_dft_add(&mut res_dft, &ai_dft, mat, di, scratch_2); - } - } - } - - res_dft.set_size(res_dft.max_size()); - let mut res_big: VecZnxBig = module.vec_znx_idft_apply_consume(res_dft); - module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0); + let mut res_big: VecZnxBig = module.vec_znx_idft_apply_consume(res); + module.vec_znx_big_add_small_inplace(&mut res_big, 0, a.data(), 0); res_big } diff --git a/poulpy-core/src/keyswitching/lwe_ct.rs b/poulpy-core/src/keyswitching/lwe_ct.rs index 7d9e08e..8546ccb 100644 --- a/poulpy-core/src/keyswitching/lwe_ct.rs +++ b/poulpy-core/src/keyswitching/lwe_ct.rs @@ -1,116 +1,116 @@ use poulpy_hal::{ - api::{ - ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, - VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, - }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero}, + api::ScratchAvailable, + layouts::{Backend, DataMut, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero}, }; use crate::{ - TakeGLWECt, + ScratchTakeCore, + keyswitching::glwe_ct::GLWEKeySwitch, layouts::{ - GGLWEInfos, GLWECiphertext, GLWECiphertextLayout, LWECiphertext, LWEInfos, Rank, TorusPrecision, - prepared::LWESwitchingKeyPrepared, + GGLWEInfos, GLWE, GLWEAlloc, GLWELayout, LWE, LWEInfos, LWEToMut, LWEToRef, Rank, TorusPrecision, + prepared::{LWESwitchingKeyPrepared, LWESwitchingKeyPreparedToRef}, }, }; -impl LWECiphertext> { - pub fn keyswitch_scratch_space( - module: &Module, - out_infos: &OUT, - in_infos: &IN, - key_infos: &KEY, - ) -> usize +impl LWE> { + pub fn keyswitch_tmp_bytes(module: &M, res_infos: &R, a_infos: &A, key_infos: &K) -> usize where - OUT: LWEInfos, - IN: LWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalizeTmpBytes, + R: LWEInfos, + A: LWEInfos, + K: GGLWEInfos, + M: LWEKeySwitch, { - let max_k: TorusPrecision = in_infos.k().max(out_infos.k()); - - let glwe_in_infos: GLWECiphertextLayout = GLWECiphertextLayout { - n: module.n().into(), - base2k: in_infos.base2k(), - k: max_k, - rank: Rank(1), - }; - - let glwe_out_infos: GLWECiphertextLayout = GLWECiphertextLayout { - n: module.n().into(), - base2k: out_infos.base2k(), - k: max_k, - rank: Rank(1), - }; - - let glwe_in: usize = GLWECiphertext::alloc_bytes(&glwe_in_infos); - let glwe_out: usize = GLWECiphertext::alloc_bytes(&glwe_out_infos); - let ks: usize = GLWECiphertext::keyswitch_scratch_space(module, &glwe_out_infos, &glwe_in_infos, key_infos); - - glwe_in + glwe_out + ks + module.lwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) } } -impl LWECiphertext { - pub fn keyswitch( - &mut self, - module: &Module, - a: &LWECiphertext, - ksk: &LWESwitchingKeyPrepared, - scratch: &mut Scratch, - ) where - A: DataRef, - DKs: DataRef, - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalize - + VecZnxNormalizeTmpBytes - + VecZnxCopy, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, +impl LWE { + pub fn keyswitch(&mut self, module: &M, a: &A, ksk: &K, scratch: &mut Scratch) + where + A: LWEToRef, + K: LWESwitchingKeyPreparedToRef, + Scratch: ScratchTakeCore, + M: LWEKeySwitch, { - #[cfg(debug_assertions)] - { - assert!(self.n() <= module.n() as u32); - assert!(a.n() <= module.n() as u32); - assert!(scratch.available() >= LWECiphertext::keyswitch_scratch_space(module, self, a, ksk)); - } + module.lwe_keyswitch(self, a, ksk, scratch); + } +} - let max_k: TorusPrecision = self.k().max(a.k()); +impl LWEKeySwitch for Module where Self: LWEKeySwitch {} + +pub trait LWEKeySwitch +where + Self: GLWEKeySwitch + GLWEAlloc, +{ + fn lwe_keyswitch_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize + where + R: LWEInfos, + A: LWEInfos, + K: GGLWEInfos, + { + let max_k: TorusPrecision = a_infos.k().max(res_infos.k()); + + let glwe_a_infos: GLWELayout = GLWELayout { + n: self.ring_degree(), + base2k: a_infos.base2k(), + k: max_k, + rank: Rank(1), + }; + + let glwe_res_infos: GLWELayout = GLWELayout { + n: self.ring_degree(), + base2k: res_infos.base2k(), + k: max_k, + rank: Rank(1), + }; + + let glwe_in: usize = GLWE::bytes_of_from_infos(self, &glwe_a_infos); + let glwe_out: usize = GLWE::bytes_of_from_infos(self, &glwe_res_infos); + let ks: usize = self.glwe_keyswitch_tmp_bytes(&glwe_res_infos, &glwe_a_infos, key_infos); + + glwe_in + glwe_out + ks + } + + fn lwe_keyswitch(&self, res: &mut R, a: &A, ksk: &K, scratch: &mut Scratch) + where + R: LWEToMut, + A: LWEToRef, + K: LWESwitchingKeyPreparedToRef, + Scratch: ScratchTakeCore, + { + let res: &mut LWE<&mut [u8]> = &mut res.to_mut(); + let a: &LWE<&[u8]> = &a.to_ref(); + let ksk: &LWESwitchingKeyPrepared<&[u8], BE> = &ksk.to_ref(); + + assert!(res.n().as_usize() <= self.n()); + assert!(a.n().as_usize() <= self.n()); + assert_eq!(ksk.n(), self.n() as u32); + assert!(scratch.available() >= self.lwe_keyswitch_tmp_bytes(res, a, ksk)); + + let max_k: TorusPrecision = res.k().max(a.k()); let a_size: usize = a.k().div_ceil(ksk.base2k()) as usize; - let (mut glwe_in, scratch_1) = scratch.take_glwe_ct(&GLWECiphertextLayout { - n: ksk.n(), - base2k: a.base2k(), - k: max_k, - rank: Rank(1), - }); + let (mut glwe_in, scratch_1) = scratch.take_glwe_ct( + self, + &GLWELayout { + n: ksk.n(), + base2k: a.base2k(), + k: max_k, + rank: Rank(1), + }, + ); glwe_in.data.zero(); - let (mut glwe_out, scratch_1) = scratch_1.take_glwe_ct(&GLWECiphertextLayout { - n: ksk.n(), - base2k: self.base2k(), - k: max_k, - rank: Rank(1), - }); + let (mut glwe_out, scratch_1) = scratch_1.take_glwe_ct( + self, + &GLWELayout { + n: ksk.n(), + base2k: res.base2k(), + k: max_k, + rank: Rank(1), + }, + ); let n_lwe: usize = a.n().into(); @@ -120,7 +120,7 @@ impl LWECiphertext { glwe_in.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]); } - glwe_out.keyswitch(module, &glwe_in, &ksk.0, scratch_1); - self.sample_extract(&glwe_out); + self.glwe_keyswitch(&mut glwe_out, &glwe_in, &ksk.0, scratch_1); + res.sample_extract(&glwe_out); } } diff --git a/poulpy-core/src/layouts/compressed/gglwe_atk.rs b/poulpy-core/src/layouts/compressed/gglwe_atk.rs index 2a10765..61ba79a 100644 --- a/poulpy-core/src/layouts/compressed/gglwe_atk.rs +++ b/poulpy-core/src/layouts/compressed/gglwe_atk.rs @@ -1,24 +1,28 @@ use poulpy_hal::{ - api::{VecZnxCopy, VecZnxFillUniform}, layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo}, source::Source, }; use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEAutomorphismKey, GGLWEInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision, - compressed::{Decompress, GGLWESwitchingKeyCompressed}, + AutomorphismKey, AutomorphismKeyToMut, Base2K, Dnum, Dsize, GGLWEInfos, GLWEInfos, LWEInfos, Rank, RingDegree, + TorusPrecision, + compressed::{ + GLWESwitchingKeyCompressed, GLWESwitchingKeyCompressedAlloc, GLWESwitchingKeyCompressedToMut, + GLWESwitchingKeyCompressedToRef, GLWESwitchingKeyDecompress, + }, + prepared::{GetAutomorphismGaloisElement, SetAutomorphismGaloisElement}, }; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::fmt; #[derive(PartialEq, Eq, Clone)] -pub struct GGLWEAutomorphismKeyCompressed { - pub(crate) key: GGLWESwitchingKeyCompressed, +pub struct AutomorphismKeyCompressed { + pub(crate) key: GLWESwitchingKeyCompressed, pub(crate) p: i64, } -impl LWEInfos for GGLWEAutomorphismKeyCompressed { - fn n(&self) -> Degree { +impl LWEInfos for AutomorphismKeyCompressed { + fn n(&self) -> RingDegree { self.key.n() } @@ -34,13 +38,13 @@ impl LWEInfos for GGLWEAutomorphismKeyCompressed { self.key.size() } } -impl GLWEInfos for GGLWEAutomorphismKeyCompressed { +impl GLWEInfos for AutomorphismKeyCompressed { fn rank(&self) -> Rank { self.rank_out() } } -impl GGLWEInfos for GGLWEAutomorphismKeyCompressed { +impl GGLWEInfos for AutomorphismKeyCompressed { fn rank_in(&self) -> Rank { self.key.rank_in() } @@ -58,76 +62,185 @@ impl GGLWEInfos for GGLWEAutomorphismKeyCompressed { } } -impl fmt::Debug for GGLWEAutomorphismKeyCompressed { +impl fmt::Debug for AutomorphismKeyCompressed { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{self}") } } -impl FillUniform for GGLWEAutomorphismKeyCompressed { +impl FillUniform for AutomorphismKeyCompressed { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.key.fill_uniform(log_bound, source); } } -impl fmt::Display for GGLWEAutomorphismKeyCompressed { +impl fmt::Display for AutomorphismKeyCompressed { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "(AutomorphismKeyCompressed: p={}) {}", self.p, self.key) } } -impl GGLWEAutomorphismKeyCompressed> { - pub fn alloc(infos: &A) -> Self - where - A: GGLWEInfos, - { - debug_assert_eq!(infos.rank_in(), infos.rank_out()); - Self { - key: GGLWESwitchingKeyCompressed::alloc(infos), +pub trait AutomorphismKeyCompressedAlloc +where + Self: GLWESwitchingKeyCompressedAlloc, +{ + fn alloc_automorphism_key_compressed( + &self, + base2k: Base2K, + k: TorusPrecision, + rank: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> AutomorphismKeyCompressed> { + AutomorphismKeyCompressed { + key: self.alloc_glwe_switching_key_compressed(base2k, k, rank, rank, dnum, dsize), p: 0, } } - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self { - Self { - key: GGLWESwitchingKeyCompressed::alloc_with(n, base2k, k, rank, rank, dnum, dsize), - p: 0, - } - } - - pub fn alloc_bytes(infos: &A) -> usize + fn alloc_automorphism_key_compressed_from_infos(&self, infos: &A) -> AutomorphismKeyCompressed> where A: GGLWEInfos, { - debug_assert_eq!(infos.rank_in(), infos.rank_out()); - GGLWESwitchingKeyCompressed::alloc_bytes(infos) + assert_eq!(infos.rank_in(), infos.rank_out()); + self.alloc_automorphism_key_compressed( + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) } - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { - GGLWESwitchingKeyCompressed::alloc_bytes_with(n, base2k, k, rank, dnum, dsize) + fn bytes_of_automorphism_key_compressed( + &self, + base2k: Base2K, + k: TorusPrecision, + rank: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> usize { + self.bytes_of_glwe_switching_key_compressed(base2k, k, rank, dnum, dsize) + } + + fn bytes_of_automorphism_key_compressed_from_infos(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + assert_eq!(infos.rank_in(), infos.rank_out()); + self.bytes_of_automorphism_key_compressed( + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) } } -impl ReaderFrom for GGLWEAutomorphismKeyCompressed { +impl AutomorphismKeyCompressed> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GGLWEInfos, + M: AutomorphismKeyCompressedAlloc, + { + module.alloc_automorphism_key_compressed_from_infos(infos) + } + + pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self + where + M: AutomorphismKeyCompressedAlloc, + { + module.alloc_automorphism_key_compressed(base2k, k, rank, dnum, dsize) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: AutomorphismKeyCompressedAlloc, + { + module.bytes_of_automorphism_key_compressed_from_infos(infos) + } + + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize + where + M: AutomorphismKeyCompressedAlloc, + { + module.bytes_of_automorphism_key_compressed(base2k, k, rank, dnum, dsize) + } +} + +impl ReaderFrom for AutomorphismKeyCompressed { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.p = reader.read_u64::()? as i64; self.key.read_from(reader) } } -impl WriterTo for GGLWEAutomorphismKeyCompressed { +impl WriterTo for AutomorphismKeyCompressed { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { writer.write_u64::(self.p as u64)?; self.key.write_to(writer) } } -impl Decompress> for GGLWEAutomorphismKey +pub trait AutomorphismKeyDecompress where - Module: VecZnxFillUniform + VecZnxCopy, + Self: GLWESwitchingKeyDecompress, { - fn decompress(&mut self, module: &Module, other: &GGLWEAutomorphismKeyCompressed) { - self.key.decompress(module, &other.key); - self.p = other.p; + fn decompress_automorphism_key(&self, res: &mut R, other: &O) + where + R: AutomorphismKeyToMut + SetAutomorphismGaloisElement, + O: AutomorphismKeyCompressedToRef + GetAutomorphismGaloisElement, + { + self.decompress_glwe_switching_key(&mut res.to_mut().key, &other.to_ref().key); + res.set_p(other.p()); + } +} + +impl AutomorphismKeyDecompress for Module where Self: AutomorphismKeyDecompress {} + +impl AutomorphismKey +where + Self: SetAutomorphismGaloisElement, +{ + pub fn decompress(&mut self, module: &M, other: &O) + where + O: AutomorphismKeyCompressedToRef + GetAutomorphismGaloisElement, + M: AutomorphismKeyDecompress, + { + module.decompress_automorphism_key(self, other); + } +} + +pub trait AutomorphismKeyCompressedToRef { + fn to_ref(&self) -> AutomorphismKeyCompressed<&[u8]>; +} + +impl AutomorphismKeyCompressedToRef for AutomorphismKeyCompressed +where + GLWESwitchingKeyCompressed: GLWESwitchingKeyCompressedToRef, +{ + fn to_ref(&self) -> AutomorphismKeyCompressed<&[u8]> { + AutomorphismKeyCompressed { + key: self.key.to_ref(), + p: self.p, + } + } +} + +pub trait AutomorphismKeyCompressedToMut { + fn to_mut(&mut self) -> AutomorphismKeyCompressed<&mut [u8]>; +} + +impl AutomorphismKeyCompressedToMut for AutomorphismKeyCompressed +where + GLWESwitchingKeyCompressed: GLWESwitchingKeyCompressedToMut, +{ + fn to_mut(&mut self) -> AutomorphismKeyCompressed<&mut [u8]> { + AutomorphismKeyCompressed { + p: self.p, + key: self.key.to_mut(), + } } } diff --git a/poulpy-core/src/layouts/compressed/gglwe_ct.rs b/poulpy-core/src/layouts/compressed/gglwe_ct.rs index f7a4df9..9216ec4 100644 --- a/poulpy-core/src/layouts/compressed/gglwe_ct.rs +++ b/poulpy-core/src/layouts/compressed/gglwe_ct.rs @@ -1,18 +1,20 @@ use poulpy_hal::{ api::{VecZnxCopy, VecZnxFillUniform}, - layouts::{Backend, Data, DataMut, DataRef, FillUniform, MatZnx, Module, ReaderFrom, WriterTo, ZnxInfos}, + layouts::{ + Backend, Data, DataMut, DataRef, FillUniform, MatZnx, MatZnxToMut, MatZnxToRef, Module, ReaderFrom, WriterTo, ZnxInfos, + }, source::Source, }; use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWECiphertext, GGLWEInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision, - compressed::{Decompress, GLWECiphertextCompressed}, + Base2K, Dnum, Dsize, GGLWE, GGLWEInfos, GGLWEToMut, GLWEInfos, GetRingDegree, LWEInfos, Rank, RingDegree, TorusPrecision, + compressed::{GLWECompressed, GLWEDecompress}, }; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::fmt; #[derive(PartialEq, Eq, Clone)] -pub struct GGLWECiphertextCompressed { +pub struct GGLWECompressed { pub(crate) data: MatZnx, pub(crate) base2k: Base2K, pub(crate) k: TorusPrecision, @@ -21,9 +23,9 @@ pub struct GGLWECiphertextCompressed { pub(crate) seed: Vec<[u8; 32]>, } -impl LWEInfos for GGLWECiphertextCompressed { - fn n(&self) -> Degree { - Degree(self.data.n() as u32) +impl LWEInfos for GGLWECompressed { + fn n(&self) -> RingDegree { + RingDegree(self.data.n() as u32) } fn base2k(&self) -> Base2K { @@ -38,13 +40,13 @@ impl LWEInfos for GGLWECiphertextCompressed { self.data.size() } } -impl GLWEInfos for GGLWECiphertextCompressed { +impl GLWEInfos for GGLWECompressed { fn rank(&self) -> Rank { self.rank_out() } } -impl GGLWEInfos for GGLWECiphertextCompressed { +impl GGLWEInfos for GGLWECompressed { fn rank_in(&self) -> Rank { Rank(self.data.cols_in() as u32) } @@ -62,53 +64,41 @@ impl GGLWEInfos for GGLWECiphertextCompressed { } } -impl fmt::Debug for GGLWECiphertextCompressed { +impl fmt::Debug for GGLWECompressed { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{self}") } } -impl FillUniform for GGLWECiphertextCompressed { +impl FillUniform for GGLWECompressed { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.data.fill_uniform(log_bound, source); } } -impl fmt::Display for GGLWECiphertextCompressed { +impl fmt::Display for GGLWECompressed { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "(GGLWECiphertextCompressed: base2k={} k={} dsize={}) {}", + "(GGLWECompressed: base2k={} k={} dsize={}) {}", self.base2k.0, self.k.0, self.dsize.0, self.data ) } } -impl GGLWECiphertextCompressed> { - pub fn alloc(infos: &A) -> Self - where - A: GGLWEInfos, - { - Self::alloc_with( - infos.n(), - infos.base2k(), - infos.k(), - infos.rank_in(), - infos.rank_out(), - infos.dnum(), - infos.dsize(), - ) - } - - pub fn alloc_with( - n: Degree, +pub trait GGLWECompressedAlloc +where + Self: GetRingDegree, +{ + fn alloc_gglwe_compressed( + &self, base2k: Base2K, k: TorusPrecision, rank_in: Rank, rank_out: Rank, dnum: Dnum, dsize: Dsize, - ) -> Self { + ) -> GGLWECompressed> { let size: usize = k.0.div_ceil(base2k.0) as usize; debug_assert!( size as u32 > dsize.0, @@ -123,9 +113,9 @@ impl GGLWECiphertextCompressed> { dsize.0, ); - Self { + GGLWECompressed { data: MatZnx::alloc( - n.into(), + self.ring_degree().into(), dnum.into(), rank_in.into(), 1, @@ -139,21 +129,22 @@ impl GGLWECiphertextCompressed> { } } - pub fn alloc_bytes(infos: &A) -> usize + fn alloc_gglwe_compressed_from_infos(&self, infos: &A) -> GGLWECompressed> where A: GGLWEInfos, { - Self::alloc_bytes_with( - infos.n(), + assert_eq!(infos.n(), self.ring_degree()); + self.alloc_gglwe_compressed( infos.base2k(), infos.k(), infos.rank_in(), + infos.rank_out(), infos.dnum(), infos.dsize(), ) } - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum, dsize: Dsize) -> usize { + fn bytes_of_gglwe_compressed(&self, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum, dsize: Dsize) -> usize { let size: usize = k.0.div_ceil(base2k.0) as usize; debug_assert!( size as u32 > dsize.0, @@ -168,20 +159,76 @@ impl GGLWECiphertextCompressed> { dsize.0, ); - MatZnx::alloc_bytes( - n.into(), + MatZnx::bytes_of( + self.ring_degree().into(), dnum.into(), rank_in.into(), 1, k.0.div_ceil(base2k.0) as usize, ) } + + fn bytes_of_gglwe_compressed_from_infos(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + assert_eq!(infos.n(), self.ring_degree()); + self.bytes_of_gglwe_compressed( + infos.base2k(), + infos.k(), + infos.rank_in(), + infos.dnum(), + infos.dsize(), + ) + } } -impl GGLWECiphertextCompressed { - pub(crate) fn at(&self, row: usize, col: usize) -> GLWECiphertextCompressed<&[u8]> { +impl GGLWECompressedAlloc for Module where Self: GetRingDegree {} + +impl GGLWECompressed> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GGLWEInfos, + M: GGLWECompressedAlloc, + { + module.alloc_gglwe_compressed_from_infos(infos) + } + + pub fn alloc( + module: &M, + base2k: Base2K, + k: TorusPrecision, + rank_in: Rank, + rank_out: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> Self + where + M: GGLWECompressedAlloc, + { + module.alloc_gglwe_compressed(base2k, k, rank_in, rank_out, dnum, dsize) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: GGLWECompressedAlloc, + { + module.bytes_of_gglwe_compressed_from_infos(infos) + } + + pub fn byte_of(module: &M, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum, dsize: Dsize) -> usize + where + M: GGLWECompressedAlloc, + { + module.bytes_of_gglwe_compressed(base2k, k, rank_in, dnum, dsize) + } +} + +impl GGLWECompressed { + pub(crate) fn at(&self, row: usize, col: usize) -> GLWECompressed<&[u8]> { let rank_in: usize = self.rank_in().into(); - GLWECiphertextCompressed { + GLWECompressed { data: self.data.at(row, col), k: self.k, base2k: self.base2k, @@ -191,10 +238,10 @@ impl GGLWECiphertextCompressed { } } -impl GGLWECiphertextCompressed { - pub(crate) fn at_mut(&mut self, row: usize, col: usize) -> GLWECiphertextCompressed<&mut [u8]> { +impl GGLWECompressed { + pub(crate) fn at_mut(&mut self, row: usize, col: usize) -> GLWECompressed<&mut [u8]> { let rank_in: usize = self.rank_in().into(); - GLWECiphertextCompressed { + GLWECompressed { k: self.k, base2k: self.base2k, rank: self.rank_out, @@ -204,7 +251,7 @@ impl GGLWECiphertextCompressed { } } -impl ReaderFrom for GGLWECiphertextCompressed { +impl ReaderFrom for GGLWECompressed { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.k = TorusPrecision(reader.read_u32::()?); self.base2k = Base2K(reader.read_u32::()?); @@ -219,7 +266,7 @@ impl ReaderFrom for GGLWECiphertextCompressed { } } -impl WriterTo for GGLWECiphertextCompressed { +impl WriterTo for GGLWECompressed { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { writer.write_u32::(self.k.into())?; writer.write_u32::(self.base2k.into())?; @@ -233,59 +280,73 @@ impl WriterTo for GGLWECiphertextCompressed { } } -impl Decompress> for GGLWECiphertext +pub trait GGLWEDecompress where - Module: VecZnxFillUniform + VecZnxCopy, + Self: GLWEDecompress, { - fn decompress(&mut self, module: &Module, other: &GGLWECiphertextCompressed) { - #[cfg(debug_assertions)] - { - assert_eq!( - self.n(), - other.n(), - "invalid receiver: self.n()={} != other.n()={}", - self.n(), - other.n() - ); - assert_eq!( - self.size(), - other.size(), - "invalid receiver: self.size()={} != other.size()={}", - self.size(), - other.size() - ); - assert_eq!( - self.rank_in(), - other.rank_in(), - "invalid receiver: self.rank_in()={} != other.rank_in()={}", - self.rank_in(), - other.rank_in() - ); - assert_eq!( - self.rank_out(), - other.rank_out(), - "invalid receiver: self.rank_out()={} != other.rank_out()={}", - self.rank_out(), - other.rank_out() - ); + fn decompress_gglwe(&self, res: &mut R, other: &O) + where + R: GGLWEToMut, + O: GGLWECompressedToRef, + { + let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut(); + let other: &GGLWECompressed<&[u8]> = &other.to_ref(); - assert_eq!( - self.dnum(), - other.dnum(), - "invalid receiver: self.dnum()={} != other.dnum()={}", - self.dnum(), - other.dnum() - ); + assert_eq!(res.gglwe_layout(), other.gglwe_layout()); + + let rank_in: usize = res.rank_in().into(); + let dnum: usize = res.dnum().into(); + + for row_i in 0..dnum { + for col_i in 0..rank_in { + self.decompress_glwe(&mut res.at_mut(row_i, col_i), &other.at(row_i, col_i)); + } + } + } +} + +impl GGLWEDecompress for Module where Self: VecZnxFillUniform + VecZnxCopy {} + +impl GGLWE { + pub fn decompress(&mut self, module: &M, other: &O) + where + O: GGLWECompressedToRef, + M: GGLWEDecompress, + { + module.decompress_gglwe(self, other); + } +} + +pub trait GGLWECompressedToMut { + fn to_mut(&mut self) -> GGLWECompressed<&mut [u8]>; +} + +impl GGLWECompressedToMut for GGLWECompressed { + fn to_mut(&mut self) -> GGLWECompressed<&mut [u8]> { + GGLWECompressed { + k: self.k(), + base2k: self.base2k(), + dsize: self.dsize(), + seed: self.seed.clone(), + rank_out: self.rank_out, + data: self.data.to_mut(), + } + } +} + +pub trait GGLWECompressedToRef { + fn to_ref(&self) -> GGLWECompressed<&[u8]>; +} + +impl GGLWECompressedToRef for GGLWECompressed { + fn to_ref(&self) -> GGLWECompressed<&[u8]> { + GGLWECompressed { + k: self.k(), + base2k: self.base2k(), + dsize: self.dsize(), + seed: self.seed.clone(), + rank_out: self.rank_out, + data: self.data.to_ref(), } - - let rank_in: usize = self.rank_in().into(); - let dnum: usize = self.dnum().into(); - - (0..rank_in).for_each(|col_i| { - (0..dnum).for_each(|row_i| { - self.at_mut(row_i, col_i) - .decompress(module, &other.at(row_i, col_i)); - }); - }); } } diff --git a/poulpy-core/src/layouts/compressed/gglwe_ksk.rs b/poulpy-core/src/layouts/compressed/gglwe_ksk.rs index 60d9316..fb963c5 100644 --- a/poulpy-core/src/layouts/compressed/gglwe_ksk.rs +++ b/poulpy-core/src/layouts/compressed/gglwe_ksk.rs @@ -1,25 +1,25 @@ use poulpy_hal::{ - api::{VecZnxCopy, VecZnxFillUniform}, layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo}, source::Source, }; use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWESwitchingKey, GLWEInfos, LWEInfos, Rank, TorusPrecision, - compressed::{Decompress, GGLWECiphertextCompressed}, + Base2K, Dnum, Dsize, GGLWEInfos, GLWEInfos, GLWESwitchingKey, GLWESwitchingKeySetMetaData, GLWESwitchingKeyToMut, LWEInfos, + Rank, RingDegree, TorusPrecision, + compressed::{GGLWECompressed, GGLWECompressedAlloc, GGLWECompressedToMut, GGLWECompressedToRef, GGLWEDecompress}, }; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::fmt; #[derive(PartialEq, Eq, Clone)] -pub struct GGLWESwitchingKeyCompressed { - pub(crate) key: GGLWECiphertextCompressed, +pub struct GLWESwitchingKeyCompressed { + pub(crate) key: GGLWECompressed, pub(crate) sk_in_n: usize, // Degree of sk_in pub(crate) sk_out_n: usize, // Degree of sk_out } -impl LWEInfos for GGLWESwitchingKeyCompressed { - fn n(&self) -> Degree { +impl LWEInfos for GLWESwitchingKeyCompressed { + fn n(&self) -> RingDegree { self.key.n() } @@ -35,13 +35,13 @@ impl LWEInfos for GGLWESwitchingKeyCompressed { self.key.size() } } -impl GLWEInfos for GGLWESwitchingKeyCompressed { +impl GLWEInfos for GLWESwitchingKeyCompressed { fn rank(&self) -> Rank { self.rank_out() } } -impl GGLWEInfos for GGLWESwitchingKeyCompressed { +impl GGLWEInfos for GLWESwitchingKeyCompressed { fn rank_in(&self) -> Rank { self.key.rank_in() } @@ -59,19 +59,19 @@ impl GGLWEInfos for GGLWESwitchingKeyCompressed { } } -impl fmt::Debug for GGLWESwitchingKeyCompressed { +impl fmt::Debug for GLWESwitchingKeyCompressed { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{self}") } } -impl FillUniform for GGLWESwitchingKeyCompressed { +impl FillUniform for GLWESwitchingKeyCompressed { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.key.fill_uniform(log_bound, source); } } -impl fmt::Display for GGLWESwitchingKeyCompressed { +impl fmt::Display for GLWESwitchingKeyCompressed { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, @@ -81,47 +81,100 @@ impl fmt::Display for GGLWESwitchingKeyCompressed { } } -impl GGLWESwitchingKeyCompressed> { - pub fn alloc(infos: &A) -> Self - where - A: GGLWEInfos, - { - GGLWESwitchingKeyCompressed { - key: GGLWECiphertextCompressed::alloc(infos), - sk_in_n: 0, - sk_out_n: 0, - } - } - - pub fn alloc_with( - n: Degree, +pub trait GLWESwitchingKeyCompressedAlloc +where + Self: GGLWECompressedAlloc, +{ + fn alloc_glwe_switching_key_compressed( + &self, base2k: Base2K, k: TorusPrecision, rank_in: Rank, rank_out: Rank, dnum: Dnum, dsize: Dsize, - ) -> Self { - GGLWESwitchingKeyCompressed { - key: GGLWECiphertextCompressed::alloc_with(n, base2k, k, rank_in, rank_out, dnum, dsize), + ) -> GLWESwitchingKeyCompressed> { + GLWESwitchingKeyCompressed { + key: self.alloc_gglwe_compressed(base2k, k, rank_in, rank_out, dnum, dsize), sk_in_n: 0, sk_out_n: 0, } } - pub fn alloc_bytes(infos: &A) -> usize + fn alloc_glwe_switching_key_compressed_from_infos(&self, infos: &A) -> GLWESwitchingKeyCompressed> where A: GGLWEInfos, { - GGLWECiphertextCompressed::alloc_bytes(infos) + self.alloc_glwe_switching_key_compressed( + infos.base2k(), + infos.k(), + infos.rank_in(), + infos.rank_out(), + infos.dnum(), + infos.dsize(), + ) } - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum, dsize: Dsize) -> usize { - GGLWECiphertextCompressed::alloc_bytes_with(n, base2k, k, rank_in, dnum, dsize) + fn bytes_of_glwe_switching_key_compressed( + &self, + base2k: Base2K, + k: TorusPrecision, + rank_in: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> usize { + self.bytes_of_gglwe_compressed(base2k, k, rank_in, dnum, dsize) + } + + fn bytes_of_glwe_switching_key_compressed_from_infos(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + self.bytes_of_gglwe_compressed_from_infos(infos) } } -impl ReaderFrom for GGLWESwitchingKeyCompressed { +impl GLWESwitchingKeyCompressed> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GGLWEInfos, + M: GLWESwitchingKeyCompressedAlloc, + { + module.alloc_glwe_switching_key_compressed_from_infos(infos) + } + + pub fn alloc( + module: &M, + base2k: Base2K, + k: TorusPrecision, + rank_in: Rank, + rank_out: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> Self + where + M: GLWESwitchingKeyCompressedAlloc, + { + module.alloc_glwe_switching_key_compressed(base2k, k, rank_in, rank_out, dnum, dsize) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: GLWESwitchingKeyCompressedAlloc, + { + module.bytes_of_glwe_switching_key_compressed_from_infos(infos) + } + + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum, dsize: Dsize) -> usize + where + M: GLWESwitchingKeyCompressedAlloc, + { + module.bytes_of_glwe_switching_key_compressed(base2k, k, rank_in, dnum, dsize) + } +} + +impl ReaderFrom for GLWESwitchingKeyCompressed { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.sk_in_n = reader.read_u64::()? as usize; self.sk_out_n = reader.read_u64::()? as usize; @@ -129,7 +182,7 @@ impl ReaderFrom for GGLWESwitchingKeyCompressed { } } -impl WriterTo for GGLWESwitchingKeyCompressed { +impl WriterTo for GLWESwitchingKeyCompressed { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { writer.write_u64::(self.sk_in_n as u64)?; writer.write_u64::(self.sk_out_n as u64)?; @@ -137,13 +190,64 @@ impl WriterTo for GGLWESwitchingKeyCompressed { } } -impl Decompress> for GGLWESwitchingKey +pub trait GLWESwitchingKeyDecompress where - Module: VecZnxFillUniform + VecZnxCopy, + Self: GGLWEDecompress, { - fn decompress(&mut self, module: &Module, other: &GGLWESwitchingKeyCompressed) { - self.key.decompress(module, &other.key); - self.sk_in_n = other.sk_in_n; - self.sk_out_n = other.sk_out_n; + fn decompress_glwe_switching_key(&self, res: &mut R, other: &O) + where + R: GLWESwitchingKeyToMut + GLWESwitchingKeySetMetaData, + O: GLWESwitchingKeyCompressedToRef, + { + let other: &GLWESwitchingKeyCompressed<&[u8]> = &other.to_ref(); + self.decompress_gglwe(&mut res.to_mut().key, &other.key); + res.set_sk_in_n(other.sk_in_n); + res.set_sk_out_n(other.sk_out_n); + } +} + +impl GLWESwitchingKeyDecompress for Module where Self: GGLWEDecompress {} + +impl GLWESwitchingKey { + pub fn decompress(&mut self, module: &M, other: &O) + where + O: GLWESwitchingKeyCompressedToRef, + M: GLWESwitchingKeyDecompress, + { + module.decompress_glwe_switching_key(self, other); + } +} + +pub trait GLWESwitchingKeyCompressedToMut { + fn to_mut(&mut self) -> GLWESwitchingKeyCompressed<&mut [u8]>; +} + +impl GLWESwitchingKeyCompressedToMut for GLWESwitchingKeyCompressed +where + GGLWECompressed: GGLWECompressedToMut, +{ + fn to_mut(&mut self) -> GLWESwitchingKeyCompressed<&mut [u8]> { + GLWESwitchingKeyCompressed { + sk_in_n: self.sk_in_n, + sk_out_n: self.sk_out_n, + key: self.key.to_mut(), + } + } +} + +pub trait GLWESwitchingKeyCompressedToRef { + fn to_ref(&self) -> GLWESwitchingKeyCompressed<&[u8]>; +} + +impl GLWESwitchingKeyCompressedToRef for GLWESwitchingKeyCompressed +where + GGLWECompressed: GGLWECompressedToRef, +{ + fn to_ref(&self) -> GLWESwitchingKeyCompressed<&[u8]> { + GLWESwitchingKeyCompressed { + sk_in_n: self.sk_in_n, + sk_out_n: self.sk_out_n, + key: self.key.to_ref(), + } } } diff --git a/poulpy-core/src/layouts/compressed/gglwe_tsk.rs b/poulpy-core/src/layouts/compressed/gglwe_tsk.rs index fef4647..ff2b8b8 100644 --- a/poulpy-core/src/layouts/compressed/gglwe_tsk.rs +++ b/poulpy-core/src/layouts/compressed/gglwe_tsk.rs @@ -1,23 +1,25 @@ use poulpy_hal::{ - api::{VecZnxCopy, VecZnxFillUniform}, layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo}, source::Source, }; use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWETensorKey, GLWEInfos, LWEInfos, Rank, TorusPrecision, - compressed::{Decompress, GGLWESwitchingKeyCompressed}, + Base2K, Dnum, Dsize, GGLWEInfos, GLWEInfos, LWEInfos, Rank, RingDegree, TensorKey, TensorKeyToMut, TorusPrecision, + compressed::{ + GLWESwitchingKeyCompressed, GLWESwitchingKeyCompressedAlloc, GLWESwitchingKeyCompressedToMut, + GLWESwitchingKeyCompressedToRef, GLWESwitchingKeyDecompress, + }, }; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::fmt; #[derive(PartialEq, Eq, Clone)] -pub struct GGLWETensorKeyCompressed { - pub(crate) keys: Vec>, +pub struct TensorKeyCompressed { + pub(crate) keys: Vec>, } -impl LWEInfos for GGLWETensorKeyCompressed { - fn n(&self) -> Degree { +impl LWEInfos for TensorKeyCompressed { + fn n(&self) -> RingDegree { self.keys[0].n() } @@ -32,13 +34,13 @@ impl LWEInfos for GGLWETensorKeyCompressed { self.keys[0].size() } } -impl GLWEInfos for GGLWETensorKeyCompressed { +impl GLWEInfos for TensorKeyCompressed { fn rank(&self) -> Rank { self.rank_out() } } -impl GGLWEInfos for GGLWETensorKeyCompressed { +impl GGLWEInfos for TensorKeyCompressed { fn rank_in(&self) -> Rank { self.rank_out() } @@ -56,21 +58,21 @@ impl GGLWEInfos for GGLWETensorKeyCompressed { } } -impl fmt::Debug for GGLWETensorKeyCompressed { +impl fmt::Debug for TensorKeyCompressed { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{self}") } } -impl FillUniform for GGLWETensorKeyCompressed { +impl FillUniform for TensorKeyCompressed { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.keys .iter_mut() - .for_each(|key: &mut GGLWESwitchingKeyCompressed| key.fill_uniform(log_bound, source)) + .for_each(|key: &mut GLWESwitchingKeyCompressed| key.fill_uniform(log_bound, source)) } } -impl fmt::Display for GGLWETensorKeyCompressed { +impl fmt::Display for TensorKeyCompressed { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!(f, "(GLWETensorKeyCompressed)",)?; for (i, key) in self.keys.iter().enumerate() { @@ -80,8 +82,27 @@ impl fmt::Display for GGLWETensorKeyCompressed { } } -impl GGLWETensorKeyCompressed> { - pub fn alloc(infos: &A) -> Self +pub trait TensorKeyCompressedAlloc +where + Self: GLWESwitchingKeyCompressedAlloc, +{ + fn alloc_tensor_key_compressed( + &self, + base2k: Base2K, + k: TorusPrecision, + rank: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> TensorKeyCompressed> { + let pairs: u32 = (((rank.as_u32() + 1) * rank.as_u32()) >> 1).max(1); + TensorKeyCompressed { + keys: (0..pairs) + .map(|_| self.alloc_glwe_switching_key_compressed(base2k, k, Rank(1), rank, dnum, dsize)) + .collect(), + } + } + + fn alloc_tensor_key_compressed_from_infos(&self, infos: &A) -> TensorKeyCompressed> where A: GGLWEInfos, { @@ -90,62 +111,67 @@ impl GGLWETensorKeyCompressed> { infos.rank_out(), "rank_in != rank_out is not supported for GGLWETensorKeyCompressed" ); - Self::alloc_with( - infos.n(), + self.alloc_tensor_key_compressed( infos.base2k(), infos.k(), - infos.rank_out(), + infos.rank(), infos.dnum(), infos.dsize(), ) } - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self { - let mut keys: Vec>> = Vec::new(); - let pairs: u32 = (((rank.0 + 1) * rank.0) >> 1).max(1); - (0..pairs).for_each(|_| { - keys.push(GGLWESwitchingKeyCompressed::alloc_with( - n, - base2k, - k, - Rank(1), - rank, - dnum, - dsize, - )); - }); - Self { keys } + fn bytes_of_tensor_key_compressed(&self, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { + let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize; + pairs * self.bytes_of_glwe_switching_key_compressed(base2k, k, Rank(1), dnum, dsize) } - pub fn alloc_bytes(infos: &A) -> usize + fn bytes_of_tensor_key_compressed_from_infos(&self, infos: &A) -> usize where A: GGLWEInfos, { - assert_eq!( - infos.rank_in(), - infos.rank_out(), - "rank_in != rank_out is not supported for GGLWETensorKeyCompressed" - ); - let rank_out: usize = infos.rank_out().into(); - let pairs: usize = (((rank_out + 1) * rank_out) >> 1).max(1); - pairs - * GGLWESwitchingKeyCompressed::alloc_bytes_with( - infos.n(), - infos.base2k(), - infos.k(), - Rank(1), - infos.dnum(), - infos.dsize(), - ) - } - - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { - let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize; - pairs * GGLWESwitchingKeyCompressed::alloc_bytes_with(n, base2k, k, Rank(1), dnum, dsize) + self.bytes_of_tensor_key_compressed( + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) } } -impl ReaderFrom for GGLWETensorKeyCompressed { +impl TensorKeyCompressed> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GGLWEInfos, + M: TensorKeyCompressedAlloc, + { + module.alloc_tensor_key_compressed_from_infos(infos) + } + + pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self + where + M: TensorKeyCompressedAlloc, + { + module.alloc_tensor_key_compressed(base2k, k, rank, dnum, dsize) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: TensorKeyCompressedAlloc, + { + module.bytes_of_tensor_key_compressed_from_infos(infos) + } + + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize + where + M: TensorKeyCompressedAlloc, + { + module.bytes_of_tensor_key_compressed(base2k, k, rank, dnum, dsize) + } +} + +impl ReaderFrom for TensorKeyCompressed { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { let len: usize = reader.read_u64::()? as usize; if self.keys.len() != len { @@ -161,7 +187,7 @@ impl ReaderFrom for GGLWETensorKeyCompressed { } } -impl WriterTo for GGLWETensorKeyCompressed { +impl WriterTo for TensorKeyCompressed { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { writer.write_u64::(self.keys.len() as u64)?; for key in &self.keys { @@ -171,8 +197,8 @@ impl WriterTo for GGLWETensorKeyCompressed { } } -impl GGLWETensorKeyCompressed { - pub(crate) fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GGLWESwitchingKeyCompressed { +impl TensorKeyCompressed { + pub(crate) fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GLWESwitchingKeyCompressed { if i > j { std::mem::swap(&mut i, &mut j); }; @@ -181,27 +207,70 @@ impl GGLWETensorKeyCompressed { } } -impl Decompress> for GGLWETensorKey +pub trait TensorKeyDecompress where - Module: VecZnxFillUniform + VecZnxCopy, + Self: GLWESwitchingKeyDecompress, { - fn decompress(&mut self, module: &Module, other: &GGLWETensorKeyCompressed) { - #[cfg(debug_assertions)] - { - assert_eq!( - self.keys.len(), - other.keys.len(), - "invalid receiver: self.keys.len()={} != other.keys.len()={}", - self.keys.len(), - other.keys.len() - ); - } + fn decompress_tensor_key(&self, res: &mut R, other: &O) + where + R: TensorKeyToMut, + O: TensorKeyCompressedToRef, + { + let res: &mut TensorKey<&mut [u8]> = &mut res.to_mut(); + let other: &TensorKeyCompressed<&[u8]> = &other.to_ref(); - self.keys - .iter_mut() - .zip(other.keys.iter()) - .for_each(|(a, b)| { - a.decompress(module, b); - }); + assert_eq!( + res.keys.len(), + other.keys.len(), + "invalid receiver: res.keys.len()={} != other.keys.len()={}", + res.keys.len(), + other.keys.len() + ); + + for (a, b) in res.keys.iter_mut().zip(other.keys.iter()) { + self.decompress_glwe_switching_key(a, b); + } + } +} + +impl TensorKeyDecompress for Module where Self: GLWESwitchingKeyDecompress {} + +impl TensorKey { + pub fn decompress(&mut self, module: &M, other: &O) + where + O: TensorKeyCompressedToRef, + M: TensorKeyDecompress, + { + module.decompress_tensor_key(self, other); + } +} + +pub trait TensorKeyCompressedToMut { + fn to_mut(&mut self) -> TensorKeyCompressed<&mut [u8]>; +} + +impl TensorKeyCompressedToMut for TensorKeyCompressed +where + GLWESwitchingKeyCompressed: GLWESwitchingKeyCompressedToMut, +{ + fn to_mut(&mut self) -> TensorKeyCompressed<&mut [u8]> { + TensorKeyCompressed { + keys: self.keys.iter_mut().map(|c| c.to_mut()).collect(), + } + } +} + +pub trait TensorKeyCompressedToRef { + fn to_ref(&self) -> TensorKeyCompressed<&[u8]>; +} + +impl TensorKeyCompressedToRef for TensorKeyCompressed +where + GLWESwitchingKeyCompressed: GLWESwitchingKeyCompressedToRef, +{ + fn to_ref(&self) -> TensorKeyCompressed<&[u8]> { + TensorKeyCompressed { + keys: self.keys.iter().map(|c| c.to_ref()).collect(), + } } } diff --git a/poulpy-core/src/layouts/compressed/ggsw_ct.rs b/poulpy-core/src/layouts/compressed/ggsw_ct.rs index f0a62cc..adad621 100644 --- a/poulpy-core/src/layouts/compressed/ggsw_ct.rs +++ b/poulpy-core/src/layouts/compressed/ggsw_ct.rs @@ -1,18 +1,19 @@ use poulpy_hal::{ - api::{VecZnxCopy, VecZnxFillUniform}, - layouts::{Backend, Data, DataMut, DataRef, FillUniform, MatZnx, Module, ReaderFrom, WriterTo, ZnxInfos}, + layouts::{ + Backend, Data, DataMut, DataRef, FillUniform, MatZnx, MatZnxToMut, MatZnxToRef, Module, ReaderFrom, WriterTo, ZnxInfos, + }, source::Source, }; use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGSWCiphertext, GGSWInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision, - compressed::{Decompress, GLWECiphertextCompressed}, + Base2K, Dnum, Dsize, GGSW, GGSWInfos, GGSWToMut, GLWEInfos, GetRingDegree, LWEInfos, Rank, RingDegree, TorusPrecision, + compressed::{GLWECompressed, GLWEDecompress}, }; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::fmt; #[derive(PartialEq, Eq, Clone)] -pub struct GGSWCiphertextCompressed { +pub struct GGSWCompressed { pub(crate) data: MatZnx, pub(crate) k: TorusPrecision, pub(crate) base2k: Base2K, @@ -21,9 +22,9 @@ pub struct GGSWCiphertextCompressed { pub(crate) seed: Vec<[u8; 32]>, } -impl LWEInfos for GGSWCiphertextCompressed { - fn n(&self) -> Degree { - Degree(self.data.n() as u32) +impl LWEInfos for GGSWCompressed { + fn n(&self) -> RingDegree { + RingDegree(self.data.n() as u32) } fn base2k(&self) -> Base2K { @@ -37,13 +38,13 @@ impl LWEInfos for GGSWCiphertextCompressed { self.data.size() } } -impl GLWEInfos for GGSWCiphertextCompressed { +impl GLWEInfos for GGSWCompressed { fn rank(&self) -> Rank { self.rank } } -impl GGSWInfos for GGSWCiphertextCompressed { +impl GGSWInfos for GGSWCompressed { fn dsize(&self) -> Dsize { self.dsize } @@ -53,46 +54,42 @@ impl GGSWInfos for GGSWCiphertextCompressed { } } -impl fmt::Debug for GGSWCiphertextCompressed { +impl fmt::Debug for GGSWCompressed { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", self.data) } } -impl fmt::Display for GGSWCiphertextCompressed { +impl fmt::Display for GGSWCompressed { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "(GGSWCiphertextCompressed: base2k={} k={} dsize={}) {}", + "(GGSWCompressed: base2k={} k={} dsize={}) {}", self.base2k, self.k, self.dsize, self.data ) } } -impl FillUniform for GGSWCiphertextCompressed { +impl FillUniform for GGSWCompressed { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.data.fill_uniform(log_bound, source); } } -impl GGSWCiphertextCompressed> { - pub fn alloc(infos: &A) -> Self - where - A: GGSWInfos, - { - Self::alloc_with( - infos.n(), - infos.base2k(), - infos.k(), - infos.rank(), - infos.dnum(), - infos.dsize(), - ) - } - - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self { +pub trait GGSWCompressedAlloc +where + Self: GetRingDegree, +{ + fn alloc_ggsw_compressed( + &self, + base2k: Base2K, + k: TorusPrecision, + rank: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> GGSWCompressed> { let size: usize = k.0.div_ceil(base2k.0) as usize; - debug_assert!( + assert!( size as u32 > dsize.0, "invalid ggsw: ceil(k/base2k): {size} <= dsize: {}", dsize.0 @@ -105,9 +102,9 @@ impl GGSWCiphertextCompressed> { dsize.0, ); - Self { + GGSWCompressed { data: MatZnx::alloc( - n.into(), + self.ring_degree().into(), dnum.into(), (rank + 1).into(), 1, @@ -121,12 +118,11 @@ impl GGSWCiphertextCompressed> { } } - pub fn alloc_bytes(infos: &A) -> usize + fn alloc_ggsw_compressed_from_infos(&self, infos: &A) -> GGSWCompressed> where A: GGSWInfos, { - Self::alloc_bytes_with( - infos.n(), + self.alloc_ggsw_compressed( infos.base2k(), infos.k(), infos.rank(), @@ -135,9 +131,9 @@ impl GGSWCiphertextCompressed> { ) } - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { + fn bytes_of_ggsw_compressed(&self, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { let size: usize = k.0.div_ceil(base2k.0) as usize; - debug_assert!( + assert!( size as u32 > dsize.0, "invalid ggsw: ceil(k/base2k): {size} <= dsize: {}", dsize.0 @@ -150,20 +146,65 @@ impl GGSWCiphertextCompressed> { dsize.0, ); - MatZnx::alloc_bytes( - n.into(), + MatZnx::bytes_of( + self.ring_degree().into(), dnum.into(), (rank + 1).into(), 1, k.0.div_ceil(base2k.0) as usize, ) } + + fn bytes_of_ggsw_compressed_key_from_infos(&self, infos: &A) -> usize + where + A: GGSWInfos, + { + self.bytes_of_ggsw_compressed( + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) + } } -impl GGSWCiphertextCompressed { - pub fn at(&self, row: usize, col: usize) -> GLWECiphertextCompressed<&[u8]> { +impl GGSWCompressed> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GGSWInfos, + M: GGSWCompressedAlloc, + { + module.alloc_ggsw_compressed_from_infos(infos) + } + + pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self + where + M: GGSWCompressedAlloc, + { + module.alloc_ggsw_compressed(base2k, k, rank, dnum, dsize) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GGSWInfos, + M: GGSWCompressedAlloc, + { + module.bytes_of_ggsw_compressed_key_from_infos(infos) + } + + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize + where + M: GGSWCompressedAlloc, + { + module.bytes_of_ggsw_compressed(base2k, k, rank, dnum, dsize) + } +} + +impl GGSWCompressed { + pub fn at(&self, row: usize, col: usize) -> GLWECompressed<&[u8]> { let rank: usize = self.rank().into(); - GLWECiphertextCompressed { + GLWECompressed { data: self.data.at(row, col), k: self.k, base2k: self.base2k, @@ -173,10 +214,10 @@ impl GGSWCiphertextCompressed { } } -impl GGSWCiphertextCompressed { - pub fn at_mut(&mut self, row: usize, col: usize) -> GLWECiphertextCompressed<&mut [u8]> { +impl GGSWCompressed { + pub fn at_mut(&mut self, row: usize, col: usize) -> GLWECompressed<&mut [u8]> { let rank: usize = self.rank().into(); - GLWECiphertextCompressed { + GLWECompressed { data: self.data.at_mut(row, col), k: self.k, base2k: self.base2k, @@ -186,7 +227,7 @@ impl GGSWCiphertextCompressed { } } -impl ReaderFrom for GGSWCiphertextCompressed { +impl ReaderFrom for GGSWCompressed { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.k = TorusPrecision(reader.read_u32::()?); self.base2k = Base2K(reader.read_u32::()?); @@ -201,7 +242,7 @@ impl ReaderFrom for GGSWCiphertextCompressed { } } -impl WriterTo for GGSWCiphertextCompressed { +impl WriterTo for GGSWCompressed { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { writer.write_u32::(self.k.into())?; writer.write_u32::(self.base2k.into())?; @@ -215,23 +256,72 @@ impl WriterTo for GGSWCiphertextCompressed { } } -impl Decompress> for GGSWCiphertext +pub trait GGSWDecompress where - Module: VecZnxFillUniform + VecZnxCopy, + Self: GLWEDecompress, { - fn decompress(&mut self, module: &Module, other: &GGSWCiphertextCompressed) { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank(), other.rank()) - } + fn decompress_ggsw(&self, res: &mut R, other: &O) + where + R: GGSWToMut, + O: GGSWCompressedToRef, + { + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + let other: &GGSWCompressed<&[u8]> = &other.to_ref(); - let dnum: usize = self.dnum().into(); - let rank: usize = self.rank().into(); - (0..dnum).for_each(|row_i| { - (0..rank + 1).for_each(|col_j| { - self.at_mut(row_i, col_j) - .decompress(module, &other.at(row_i, col_j)); - }); - }); + assert_eq!(res.rank(), other.rank()); + let dnum: usize = res.dnum().into(); + let rank: usize = res.rank().into(); + + for row_i in 0..dnum { + for col_j in 0..rank + 1 { + self.decompress_glwe(&mut res.at_mut(row_i, col_j), &other.at(row_i, col_j)); + } + } + } +} + +impl GGSWDecompress for Module where Self: GGSWDecompress {} + +impl GGSW { + pub fn decompress(&mut self, module: &M, other: &O) + where + O: GGSWCompressedToRef, + M: GGSWDecompress, + { + module.decompress_ggsw(self, other); + } +} + +pub trait GGSWCompressedToMut { + fn to_mut(&mut self) -> GGSWCompressed<&mut [u8]>; +} + +impl GGSWCompressedToMut for GGSWCompressed { + fn to_mut(&mut self) -> GGSWCompressed<&mut [u8]> { + GGSWCompressed { + k: self.k(), + base2k: self.base2k(), + dsize: self.dsize(), + rank: self.rank(), + seed: self.seed.clone(), + data: self.data.to_mut(), + } + } +} + +pub trait GGSWCompressedToRef { + fn to_ref(&self) -> GGSWCompressed<&[u8]>; +} + +impl GGSWCompressedToRef for GGSWCompressed { + fn to_ref(&self) -> GGSWCompressed<&[u8]> { + GGSWCompressed { + k: self.k(), + base2k: self.base2k(), + dsize: self.dsize(), + rank: self.rank(), + seed: self.seed.clone(), + data: self.data.to_ref(), + } } } diff --git a/poulpy-core/src/layouts/compressed/glwe_ct.rs b/poulpy-core/src/layouts/compressed/glwe_ct.rs index 30a3733..e558cf7 100644 --- a/poulpy-core/src/layouts/compressed/glwe_ct.rs +++ b/poulpy-core/src/layouts/compressed/glwe_ct.rs @@ -1,15 +1,19 @@ use poulpy_hal::{ api::{VecZnxCopy, VecZnxFillUniform}, - layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, VecZnx, WriterTo, ZnxInfos}, + layouts::{ + Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, VecZnx, VecZnxToMut, VecZnxToRef, WriterTo, ZnxInfos, + }, source::Source, }; -use crate::layouts::{Base2K, Degree, GLWECiphertext, GLWEInfos, LWEInfos, Rank, TorusPrecision, compressed::Decompress}; +use crate::layouts::{ + Base2K, GLWE, GLWEInfos, GLWEToMut, GetRingDegree, LWEInfos, Rank, RingDegree, SetGLWEInfos, TorusPrecision, +}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::fmt; #[derive(PartialEq, Eq, Clone)] -pub struct GLWECiphertextCompressed { +pub struct GLWECompressed { pub(crate) data: VecZnx, pub(crate) base2k: Base2K, pub(crate) k: TorusPrecision, @@ -17,7 +21,7 @@ pub struct GLWECiphertextCompressed { pub(crate) seed: [u8; 32], } -impl LWEInfos for GLWECiphertextCompressed { +impl LWEInfos for GLWECompressed { fn base2k(&self) -> Base2K { self.base2k } @@ -30,27 +34,27 @@ impl LWEInfos for GLWECiphertextCompressed { self.data.size() } - fn n(&self) -> Degree { - Degree(self.data.n() as u32) + fn n(&self) -> RingDegree { + RingDegree(self.data.n() as u32) } } -impl GLWEInfos for GLWECiphertextCompressed { +impl GLWEInfos for GLWECompressed { fn rank(&self) -> Rank { self.rank } } -impl fmt::Debug for GLWECiphertextCompressed { +impl fmt::Debug for GLWECompressed { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{self}") } } -impl fmt::Display for GLWECiphertextCompressed { +impl fmt::Display for GLWECompressed { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "GLWECiphertextCompressed: base2k={} k={} rank={} seed={:?}: {}", + "GLWECompressed: base2k={} k={} rank={} seed={:?}: {}", self.base2k(), self.k(), self.rank(), @@ -60,23 +64,23 @@ impl fmt::Display for GLWECiphertextCompressed { } } -impl FillUniform for GLWECiphertextCompressed { +impl FillUniform for GLWECompressed { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.data.fill_uniform(log_bound, source); } } -impl GLWECiphertextCompressed> { - pub fn alloc(infos: &A) -> Self - where - A: GLWEInfos, - { - Self::alloc_with(infos.n(), infos.base2k(), infos.k(), infos.rank()) - } - - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self { - Self { - data: VecZnx::alloc(n.into(), 1, k.0.div_ceil(base2k.0) as usize), +pub trait GLWECompressedAlloc +where + Self: GetRingDegree, +{ + fn alloc_glwe_compressed(&self, base2k: Base2K, k: TorusPrecision, rank: Rank) -> GLWECompressed> { + GLWECompressed { + data: VecZnx::alloc( + self.ring_degree().into(), + 1, + k.0.div_ceil(base2k.0) as usize, + ), base2k, k, rank, @@ -84,19 +88,66 @@ impl GLWECiphertextCompressed> { } } - pub fn alloc_bytes(infos: &A) -> usize + fn alloc_glwe_compressed_from_infos(&self, infos: &A) -> GLWECompressed> where A: GLWEInfos, { - Self::alloc_bytes_with(infos.n(), infos.base2k(), infos.k()) + assert_eq!(self.ring_degree(), infos.n()); + self.alloc_glwe_compressed(infos.base2k(), infos.k(), infos.rank()) } - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision) -> usize { - VecZnx::alloc_bytes(n.into(), 1, k.0.div_ceil(base2k.0) as usize) + fn bytes_of_glwe_compressed(&self, base2k: Base2K, k: TorusPrecision) -> usize { + VecZnx::bytes_of( + self.ring_degree().into(), + 1, + k.0.div_ceil(base2k.0) as usize, + ) + } + + fn bytes_of_glwe_compressed_from_infos(&self, infos: &A) -> usize + where + A: GLWEInfos, + { + assert_eq!(self.ring_degree(), infos.n()); + self.bytes_of_glwe_compressed(infos.base2k(), infos.k()) } } -impl ReaderFrom for GLWECiphertextCompressed { +impl GLWECompressedAlloc for Module where Self: GetRingDegree {} + +impl GLWECompressed> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GLWEInfos, + M: GLWECompressedAlloc, + { + module.alloc_glwe_compressed_from_infos(infos) + } + + pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self + where + M: GLWECompressedAlloc, + { + module.alloc_glwe_compressed(base2k, k, rank) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GLWEInfos, + M: GLWECompressedAlloc, + { + module.bytes_of_glwe_compressed_from_infos(infos) + } + + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision) -> usize + where + M: GLWECompressedAlloc, + { + module.bytes_of_glwe_compressed(base2k, k) + } +} + +impl ReaderFrom for GLWECompressed { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.k = TorusPrecision(reader.read_u32::()?); self.base2k = Base2K(reader.read_u32::()?); @@ -106,7 +157,7 @@ impl ReaderFrom for GLWECiphertextCompressed { } } -impl WriterTo for GLWECiphertextCompressed { +impl WriterTo for GLWECompressed { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { writer.write_u32::(self.k.into())?; writer.write_u32::(self.base2k.into())?; @@ -116,63 +167,82 @@ impl WriterTo for GLWECiphertextCompressed { } } -impl Decompress> for GLWECiphertext +pub trait GLWEDecompress where - Module: VecZnxFillUniform + VecZnxCopy, + Self: GetRingDegree + VecZnxFillUniform + VecZnxCopy, { - fn decompress(&mut self, module: &Module, other: &GLWECiphertextCompressed) { - #[cfg(debug_assertions)] - { - assert_eq!( - self.n(), - other.n(), - "invalid receiver: self.n()={} != other.n()={}", - self.n(), - other.n() - ); - assert_eq!( - self.size(), - other.size(), - "invalid receiver: self.size()={} != other.size()={}", - self.size(), - other.size() - ); - assert_eq!( - self.rank(), - other.rank(), - "invalid receiver: self.rank()={} != other.rank()={}", - self.rank(), - other.rank() - ); - } - - let mut source: Source = Source::new(other.seed); - self.decompress_internal(module, other, &mut source); - } -} - -impl GLWECiphertext { - pub(crate) fn decompress_internal( - &mut self, - module: &Module, - other: &GLWECiphertextCompressed, - source: &mut Source, - ) where - DataOther: DataRef, - Module: VecZnxCopy + VecZnxFillUniform, + fn decompress_glwe(&self, res: &mut R, other: &O) + where + R: GLWEToMut + SetGLWEInfos, + O: GLWECompressedToRef + GLWEInfos, { - #[cfg(debug_assertions)] { - assert_eq!(self.rank(), other.rank()); - debug_assert_eq!(self.size(), other.size()); + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let other: &GLWECompressed<&[u8]> = &other.to_ref(); + assert_eq!( + res.n(), + self.ring_degree(), + "invalid receiver: res.n()={} != other.n()={}", + res.n(), + self.ring_degree() + ); + + assert_eq!(res.lwe_layout(), other.lwe_layout()); + assert_eq!(res.glwe_layout(), other.glwe_layout()); + + let mut source: Source = Source::new(other.seed); + + self.vec_znx_copy(&mut res.data, 0, &other.data, 0); + (1..(other.rank() + 1).into()).for_each(|i| { + self.vec_znx_fill_uniform(other.base2k.into(), &mut res.data, i, &mut source); + }); } - module.vec_znx_copy(&mut self.data, 0, &other.data, 0); - (1..(other.rank() + 1).into()).for_each(|i| { - module.vec_znx_fill_uniform(other.base2k.into(), &mut self.data, i, source); - }); - - self.base2k = other.base2k; - self.k = other.k; + res.set_base2k(other.base2k()); + res.set_k(other.k()); + } +} + +impl GLWEDecompress for Module where Self: GetRingDegree + VecZnxFillUniform + VecZnxCopy {} + +impl GLWE { + pub fn decompress(&mut self, module: &M, other: &O) + where + O: GLWECompressedToRef + GLWEInfos, + M: GLWEDecompress, + { + module.decompress_glwe(self, other); + } +} + +pub trait GLWECompressedToRef { + fn to_ref(&self) -> GLWECompressed<&[u8]>; +} + +impl GLWECompressedToRef for GLWECompressed { + fn to_ref(&self) -> GLWECompressed<&[u8]> { + GLWECompressed { + seed: self.seed.clone(), + base2k: self.base2k, + k: self.k, + rank: self.rank, + data: self.data.to_ref(), + } + } +} + +pub trait GLWECompressedToMut { + fn to_mut(&mut self) -> GLWECompressed<&mut [u8]>; +} + +impl GLWECompressedToMut for GLWECompressed { + fn to_mut(&mut self) -> GLWECompressed<&mut [u8]> { + GLWECompressed { + seed: self.seed.clone(), + base2k: self.base2k, + k: self.k, + rank: self.rank, + data: self.data.to_mut(), + } } } diff --git a/poulpy-core/src/layouts/compressed/glwe_to_lwe_ksk.rs b/poulpy-core/src/layouts/compressed/glwe_to_lwe_ksk.rs index 63933e8..5bfa3eb 100644 --- a/poulpy-core/src/layouts/compressed/glwe_to_lwe_ksk.rs +++ b/poulpy-core/src/layouts/compressed/glwe_to_lwe_ksk.rs @@ -1,16 +1,21 @@ use std::fmt; use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo}, + layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo}, source::Source, }; use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision, compressed::GGLWESwitchingKeyCompressed, + Base2K, Dnum, Dsize, GGLWEInfos, GLWEInfos, GLWEToLWESwitchingKey, GLWEToLWESwitchingKeyToMut, LWEInfos, Rank, RingDegree, + TorusPrecision, + compressed::{ + GLWESwitchingKeyCompressed, GLWESwitchingKeyCompressedAlloc, GLWESwitchingKeyCompressedToMut, + GLWESwitchingKeyCompressedToRef, GLWESwitchingKeyDecompress, + }, }; #[derive(PartialEq, Eq, Clone)] -pub struct GLWEToLWESwitchingKeyCompressed(pub(crate) GGLWESwitchingKeyCompressed); +pub struct GLWEToLWESwitchingKeyCompressed(pub(crate) GLWESwitchingKeyCompressed); impl LWEInfos for GLWEToLWESwitchingKeyCompressed { fn base2k(&self) -> Base2K { @@ -21,7 +26,7 @@ impl LWEInfos for GLWEToLWESwitchingKeyCompressed { self.0.k() } - fn n(&self) -> Degree { + fn n(&self) -> RingDegree { self.0.n() } fn size(&self) -> usize { @@ -83,54 +88,146 @@ impl WriterTo for GLWEToLWESwitchingKeyCompressed { } } -impl GLWEToLWESwitchingKeyCompressed> { - pub fn alloc(infos: &A) -> Self +pub trait GLWEToLWESwitchingKeyCompressedAlloc +where + Self: GLWESwitchingKeyCompressedAlloc, +{ + fn alloc_glwe_to_lwe_switching_key_compressed( + &self, + base2k: Base2K, + k: TorusPrecision, + rank_in: Rank, + dnum: Dnum, + ) -> GLWEToLWESwitchingKeyCompressed> { + GLWEToLWESwitchingKeyCompressed(self.alloc_glwe_switching_key_compressed(base2k, k, rank_in, Rank(1), dnum, Dsize(1))) + } + + fn alloc_glwe_to_lwe_switching_key_compressed_from_infos(&self, infos: &A) -> GLWEToLWESwitchingKeyCompressed> where A: GGLWEInfos, { - debug_assert_eq!( + assert_eq!( infos.rank_out().0, 1, "rank_out > 1 is unsupported for GLWEToLWESwitchingKeyCompressed" ); - debug_assert_eq!( + assert_eq!( infos.dsize().0, 1, "dsize > 1 is unsupported for GLWEToLWESwitchingKeyCompressed" ); - Self(GGLWESwitchingKeyCompressed::alloc(infos)) + self.alloc_glwe_to_lwe_switching_key_compressed(infos.base2k(), infos.k(), infos.rank_in(), infos.dnum()) } - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> Self { - Self(GGLWESwitchingKeyCompressed::alloc_with( - n, - base2k, - k, - rank_in, - Rank(1), - dnum, - Dsize(1), - )) + fn bytes_of_glwe_to_lwe_switching_key_compressed( + &self, + base2k: Base2K, + k: TorusPrecision, + rank_in: Rank, + dnum: Dnum, + ) -> usize { + self.bytes_of_glwe_switching_key_compressed(base2k, k, rank_in, dnum, Dsize(1)) } - pub fn alloc_bytes(infos: &A) -> usize + fn bytes_of_glwe_to_lwe_switching_key_compressed_from_infos(&self, infos: &A) -> usize where A: GGLWEInfos, { - debug_assert_eq!( + assert_eq!( infos.rank_out().0, 1, "rank_out > 1 is unsupported for GLWEToLWESwitchingKeyCompressed" ); - debug_assert_eq!( + assert_eq!( infos.dsize().0, 1, "dsize > 1 is unsupported for GLWEToLWESwitchingKeyCompressed" ); - GGLWESwitchingKeyCompressed::alloc_bytes(infos) - } - - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, dnum: Dnum, rank_in: Rank) -> usize { - GGLWESwitchingKeyCompressed::alloc_bytes_with(n, base2k, k, rank_in, dnum, Dsize(1)) + self.bytes_of_glwe_switching_key_compressed_from_infos(infos) + } +} + +impl GLWEToLWESwitchingKeyCompressedAlloc for Module where Self: GLWESwitchingKeyCompressedAlloc {} + +impl GLWEToLWESwitchingKeyCompressed> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GGLWEInfos, + M: GLWEToLWESwitchingKeyCompressedAlloc, + { + module.alloc_glwe_to_lwe_switching_key_compressed_from_infos(infos) + } + + pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> Self + where + M: GLWEToLWESwitchingKeyCompressedAlloc, + { + module.alloc_glwe_to_lwe_switching_key_compressed(base2k, k, rank_in, dnum) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: GLWEToLWESwitchingKeyCompressedAlloc, + { + module.bytes_of_glwe_to_lwe_switching_key_compressed_from_infos(infos) + } + + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, dnum: Dnum, rank_in: Rank) -> usize + where + M: GLWEToLWESwitchingKeyCompressedAlloc, + { + module.bytes_of_glwe_to_lwe_switching_key_compressed(base2k, k, rank_in, dnum) + } +} + +pub trait GLWEToLWESwitchingKeyDecompress +where + Self: GLWESwitchingKeyDecompress, +{ + fn decompress_glwe_to_lwe_switching_key(&self, res: &mut R, other: &O) + where + R: GLWEToLWESwitchingKeyToMut, + O: GLWEToLWESwitchingKeyCompressedToRef, + { + self.decompress_glwe_switching_key(&mut res.to_mut().0, &other.to_ref().0); + } +} + +impl GLWEToLWESwitchingKeyDecompress for Module where Self: GLWESwitchingKeyDecompress {} + +impl GLWEToLWESwitchingKey { + pub fn decompress(&mut self, module: &M, other: &O) + where + O: GLWEToLWESwitchingKeyCompressedToRef, + M: GLWEToLWESwitchingKeyDecompress, + { + module.decompress_glwe_to_lwe_switching_key(self, other); + } +} + +pub trait GLWEToLWESwitchingKeyCompressedToRef { + fn to_ref(&self) -> GLWEToLWESwitchingKeyCompressed<&[u8]>; +} + +impl GLWEToLWESwitchingKeyCompressedToRef for GLWEToLWESwitchingKeyCompressed +where + GLWESwitchingKeyCompressed: GLWESwitchingKeyCompressedToRef, +{ + fn to_ref(&self) -> GLWEToLWESwitchingKeyCompressed<&[u8]> { + GLWEToLWESwitchingKeyCompressed(self.0.to_ref()) + } +} + +pub trait GLWEToLWESwitchingKeyCompressedToMut { + fn to_mut(&mut self) -> GLWEToLWESwitchingKeyCompressed<&mut [u8]>; +} + +impl GLWEToLWESwitchingKeyCompressedToMut for GLWEToLWESwitchingKeyCompressed +where + GLWESwitchingKeyCompressed: GLWESwitchingKeyCompressedToMut, +{ + fn to_mut(&mut self) -> GLWEToLWESwitchingKeyCompressed<&mut [u8]> { + GLWEToLWESwitchingKeyCompressed(self.0.to_mut()) } } diff --git a/poulpy-core/src/layouts/compressed/lwe_ct.rs b/poulpy-core/src/layouts/compressed/lwe_ct.rs index e11b3f3..0393758 100644 --- a/poulpy-core/src/layouts/compressed/lwe_ct.rs +++ b/poulpy-core/src/layouts/compressed/lwe_ct.rs @@ -2,21 +2,24 @@ use std::fmt; use poulpy_hal::{ api::ZnFillUniform, - layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo, Zn, ZnxInfos, ZnxView, ZnxViewMut}, + layouts::{ + Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo, Zn, ZnToMut, ZnToRef, ZnxInfos, ZnxView, + ZnxViewMut, + }, source::Source, }; -use crate::layouts::{Base2K, Degree, LWECiphertext, LWEInfos, TorusPrecision, compressed::Decompress}; +use crate::layouts::{Base2K, LWE, LWEInfos, LWEToMut, RingDegree, TorusPrecision}; #[derive(PartialEq, Eq, Clone)] -pub struct LWECiphertextCompressed { +pub struct LWECompressed { pub(crate) data: Zn, pub(crate) k: TorusPrecision, pub(crate) base2k: Base2K, pub(crate) seed: [u8; 32], } -impl LWEInfos for LWECiphertextCompressed { +impl LWEInfos for LWECompressed { fn base2k(&self) -> Base2K { self.base2k } @@ -25,8 +28,8 @@ impl LWEInfos for LWECiphertextCompressed { self.k } - fn n(&self) -> Degree { - Degree(self.data.n() as u32) + fn n(&self) -> RingDegree { + RingDegree(self.data.n() as u32) } fn size(&self) -> usize { @@ -34,17 +37,17 @@ impl LWEInfos for LWECiphertextCompressed { } } -impl fmt::Debug for LWECiphertextCompressed { +impl fmt::Debug for LWECompressed { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{self}") } } -impl fmt::Display for LWECiphertextCompressed { +impl fmt::Display for LWECompressed { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "LWECiphertextCompressed: base2k={} k={} seed={:?}: {}", + "LWECompressed: base2k={} k={} seed={:?}: {}", self.base2k(), self.k(), self.seed, @@ -53,22 +56,15 @@ impl fmt::Display for LWECiphertextCompressed { } } -impl FillUniform for LWECiphertextCompressed { +impl FillUniform for LWECompressed { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.data.fill_uniform(log_bound, source); } } -impl LWECiphertextCompressed> { - pub fn alloc(infos: &A) -> Self - where - A: LWEInfos, - { - Self::alloc_with(infos.base2k(), infos.k()) - } - - pub fn alloc_with(base2k: Base2K, k: TorusPrecision) -> Self { - Self { +pub trait LWECompressedAlloc { + fn alloc_lwe_compressed(&self, base2k: Base2K, k: TorusPrecision) -> LWECompressed> { + LWECompressed { data: Zn::alloc(1, 1, k.0.div_ceil(base2k.0) as usize), k, base2k, @@ -76,21 +72,62 @@ impl LWECiphertextCompressed> { } } - pub fn alloc_bytes(infos: &A) -> usize + fn alloc_lwe_compressed_from_infos(&self, infos: &A) -> LWECompressed> where A: LWEInfos, { - Self::alloc_bytes_with(infos.base2k(), infos.k()) + self.alloc_lwe_compressed(infos.base2k(), infos.k()) } - pub fn alloc_bytes_with(base2k: Base2K, k: TorusPrecision) -> usize { - Zn::alloc_bytes(1, 1, k.0.div_ceil(base2k.0) as usize) + fn bytes_of_lwe_compressed(&self, base2k: Base2K, k: TorusPrecision) -> usize { + Zn::bytes_of(1, 1, k.0.div_ceil(base2k.0) as usize) + } + + fn bytes_of_lwe_compressed_from_infos(&self, infos: &A) -> usize + where + A: LWEInfos, + { + self.bytes_of_lwe_compressed(infos.base2k(), infos.k()) + } +} + +impl LWECompressedAlloc for Module {} + +impl LWECompressed> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: LWEInfos, + M: LWECompressedAlloc, + { + module.alloc_lwe_compressed_from_infos(infos) + } + + pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision) -> Self + where + M: LWECompressedAlloc, + { + module.alloc_lwe_compressed(base2k, k) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: LWEInfos, + M: LWECompressedAlloc, + { + module.bytes_of_lwe_compressed_from_infos(infos) + } + + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision) -> usize + where + M: LWECompressedAlloc, + { + module.bytes_of_lwe_compressed(base2k, k) } } use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; -impl ReaderFrom for LWECiphertextCompressed { +impl ReaderFrom for LWECompressed { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.k = TorusPrecision(reader.read_u32::()?); self.base2k = Base2K(reader.read_u32::()?); @@ -99,7 +136,7 @@ impl ReaderFrom for LWECiphertextCompressed { } } -impl WriterTo for LWECiphertextCompressed { +impl WriterTo for LWECompressed { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { writer.write_u32::(self.k.into())?; writer.write_u32::(self.base2k.into())?; @@ -108,22 +145,72 @@ impl WriterTo for LWECiphertextCompressed { } } -impl Decompress> for LWECiphertext +pub trait LWEDecompress where - Module: ZnFillUniform, + Self: ZnFillUniform, { - fn decompress(&mut self, module: &Module, other: &LWECiphertextCompressed) { - debug_assert_eq!(self.size(), other.size()); + fn decompress_lwe(&self, res: &mut R, other: &O) + where + R: LWEToMut, + O: LWECompressedToRef, + { + let res: &mut LWE<&mut [u8]> = &mut res.to_mut(); + let other: &LWECompressed<&[u8]> = &other.to_ref(); + + assert_eq!(res.lwe_layout(), other.lwe_layout()); + let mut source: Source = Source::new(other.seed); - module.zn_fill_uniform( - self.n().into(), + self.zn_fill_uniform( + res.n().into(), other.base2k().into(), - &mut self.data, + &mut res.data, 0, &mut source, ); - (0..self.size()).for_each(|i| { - self.data.at_mut(0, i)[0] = other.data.at(0, i)[0]; - }); + for i in 0..res.size() { + res.data.at_mut(0, i)[0] = other.data.at(0, i)[0]; + } + } +} + +impl LWEDecompress for Module where Self: ZnFillUniform {} + +impl LWE { + pub fn decompress(&mut self, module: &M, other: &O) + where + O: LWECompressedToRef, + M: LWEDecompress, + { + module.decompress_lwe(self, other); + } +} + +pub trait LWECompressedToRef { + fn to_ref(&self) -> LWECompressed<&[u8]>; +} + +impl LWECompressedToRef for LWECompressed { + fn to_ref(&self) -> LWECompressed<&[u8]> { + LWECompressed { + k: self.k, + base2k: self.base2k, + seed: self.seed, + data: self.data.to_ref(), + } + } +} + +pub trait LWECompressedToMut { + fn to_mut(&mut self) -> LWECompressed<&mut [u8]>; +} + +impl LWECompressedToMut for LWECompressed { + fn to_mut(&mut self) -> LWECompressed<&mut [u8]> { + LWECompressed { + k: self.k, + base2k: self.base2k, + seed: self.seed, + data: self.data.to_mut(), + } } } diff --git a/poulpy-core/src/layouts/compressed/lwe_ksk.rs b/poulpy-core/src/layouts/compressed/lwe_ksk.rs index 480707b..cf6ae9a 100644 --- a/poulpy-core/src/layouts/compressed/lwe_ksk.rs +++ b/poulpy-core/src/layouts/compressed/lwe_ksk.rs @@ -1,17 +1,20 @@ use poulpy_hal::{ - api::{VecZnxCopy, VecZnxFillUniform}, layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo}, source::Source, }; use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEInfos, GLWEInfos, LWEInfos, LWESwitchingKey, Rank, TorusPrecision, - compressed::{Decompress, GGLWESwitchingKeyCompressed}, + Base2K, Dnum, Dsize, GGLWEInfos, GLWEInfos, LWEInfos, LWESwitchingKey, LWESwitchingKeyToMut, Rank, RingDegree, + TorusPrecision, + compressed::{ + GLWESwitchingKeyCompressed, GLWESwitchingKeyCompressedAlloc, GLWESwitchingKeyCompressedToMut, + GLWESwitchingKeyCompressedToRef, GLWESwitchingKeyDecompress, + }, }; use std::fmt; #[derive(PartialEq, Eq, Clone)] -pub struct LWESwitchingKeyCompressed(pub(crate) GGLWESwitchingKeyCompressed); +pub struct LWESwitchingKeyCompressed(pub(crate) GLWESwitchingKeyCompressed); impl LWEInfos for LWESwitchingKeyCompressed { fn base2k(&self) -> Base2K { @@ -22,7 +25,7 @@ impl LWEInfos for LWESwitchingKeyCompressed { self.0.k() } - fn n(&self) -> Degree { + fn n(&self) -> RingDegree { self.0.n() } fn size(&self) -> usize { @@ -83,73 +86,149 @@ impl WriterTo for LWESwitchingKeyCompressed { } } -impl LWESwitchingKeyCompressed> { - pub fn alloc(infos: &A) -> Self +pub trait LWESwitchingKeyCompressedAlloc +where + Self: GLWESwitchingKeyCompressedAlloc, +{ + fn alloc_lwe_switching_key_compressed( + &self, + base2k: Base2K, + k: TorusPrecision, + dnum: Dnum, + ) -> LWESwitchingKeyCompressed> { + LWESwitchingKeyCompressed(self.alloc_glwe_switching_key_compressed(base2k, k, Rank(1), Rank(1), dnum, Dsize(1))) + } + + fn alloc_lwe_switching_key_compressed_from_infos(&self, infos: &A) -> LWESwitchingKeyCompressed> where A: GGLWEInfos, { - debug_assert_eq!( + assert_eq!( infos.dsize().0, 1, "dsize > 1 is not supported for LWESwitchingKeyCompressed" ); - debug_assert_eq!( + assert_eq!( infos.rank_in().0, 1, "rank_in > 1 is not supported for LWESwitchingKeyCompressed" ); - debug_assert_eq!( + assert_eq!( infos.rank_out().0, 1, "rank_out > 1 is not supported for LWESwitchingKeyCompressed" ); - Self(GGLWESwitchingKeyCompressed::alloc(infos)) + self.alloc_lwe_switching_key_compressed(infos.base2k(), infos.k(), infos.dnum()) } - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> Self { - Self(GGLWESwitchingKeyCompressed::alloc_with( - n, - base2k, - k, - Rank(1), - Rank(1), - dnum, - Dsize(1), - )) + fn bytes_of_lwe_switching_key_compressed(&self, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> usize { + self.bytes_of_glwe_switching_key_compressed(base2k, k, Rank(1), dnum, Dsize(1)) } - pub fn alloc_bytes(infos: &A) -> usize + fn bytes_of_lwe_switching_key_compressed_from_infos(&self, infos: &A) -> usize where A: GGLWEInfos, { - debug_assert_eq!( + assert_eq!( infos.dsize().0, 1, - "dsize > 1 is not supported for LWESwitchingKey" + "dsize > 1 is not supported for LWESwitchingKeyCompressed" ); - debug_assert_eq!( + assert_eq!( infos.rank_in().0, 1, - "rank_in > 1 is not supported for LWESwitchingKey" + "rank_in > 1 is not supported for LWESwitchingKeyCompressed" ); - debug_assert_eq!( + assert_eq!( infos.rank_out().0, 1, - "rank_out > 1 is not supported for LWESwitchingKey" + "rank_out > 1 is not supported for LWESwitchingKeyCompressed" ); - GGLWESwitchingKeyCompressed::alloc_bytes(infos) - } - - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> usize { - GGLWESwitchingKeyCompressed::alloc_bytes_with(n, base2k, k, Rank(1), dnum, Dsize(1)) + self.bytes_of_glwe_switching_key_compressed_from_infos(infos) } } -impl Decompress> for LWESwitchingKey +impl LWESwitchingKeyCompressedAlloc for Module where Self: GLWESwitchingKeyCompressedAlloc {} + +impl LWESwitchingKeyCompressed> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GGLWEInfos, + M: LWESwitchingKeyCompressedAlloc, + { + module.alloc_lwe_switching_key_compressed_from_infos(infos) + } + + pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> Self + where + M: LWESwitchingKeyCompressedAlloc, + { + module.alloc_lwe_switching_key_compressed(base2k, k, dnum) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: LWESwitchingKeyCompressedAlloc, + { + module.bytes_of_lwe_switching_key_compressed_from_infos(infos) + } + + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> usize + where + M: LWESwitchingKeyCompressedAlloc, + { + module.bytes_of_lwe_switching_key_compressed(base2k, k, dnum) + } +} + +pub trait LWESwitchingKeyDecompress where - Module: VecZnxFillUniform + VecZnxCopy, + Self: GLWESwitchingKeyDecompress, { - fn decompress(&mut self, module: &Module, other: &LWESwitchingKeyCompressed) { - self.0.decompress(module, &other.0); + fn decompress_lwe_switching_key(&self, res: &mut R, other: &O) + where + R: LWESwitchingKeyToMut, + O: LWESwitchingKeyCompressedToRef, + { + self.decompress_glwe_switching_key(&mut res.to_mut().0, &other.to_ref().0); + } +} + +impl LWESwitchingKeyDecompress for Module where Self: GLWESwitchingKeyDecompress {} + +impl LWESwitchingKey { + pub fn decompress(&mut self, module: &M, other: &O) + where + O: LWESwitchingKeyCompressedToRef, + M: LWESwitchingKeyDecompress, + { + module.decompress_lwe_switching_key(self, other); + } +} + +pub trait LWESwitchingKeyCompressedToRef { + fn to_ref(&self) -> LWESwitchingKeyCompressed<&[u8]>; +} + +impl LWESwitchingKeyCompressedToRef for LWESwitchingKeyCompressed +where + GLWESwitchingKeyCompressed: GLWESwitchingKeyCompressedToRef, +{ + fn to_ref(&self) -> LWESwitchingKeyCompressed<&[u8]> { + LWESwitchingKeyCompressed(self.0.to_ref()) + } +} + +pub trait LWESwitchingKeyCompressedToMut { + fn to_mut(&mut self) -> LWESwitchingKeyCompressed<&mut [u8]>; +} + +impl LWESwitchingKeyCompressedToMut for LWESwitchingKeyCompressed +where + GLWESwitchingKeyCompressed: GLWESwitchingKeyCompressedToMut, +{ + fn to_mut(&mut self) -> LWESwitchingKeyCompressed<&mut [u8]> { + LWESwitchingKeyCompressed(self.0.to_mut()) } } diff --git a/poulpy-core/src/layouts/compressed/lwe_to_glwe_ksk.rs b/poulpy-core/src/layouts/compressed/lwe_to_glwe_ksk.rs index 86c353b..33fcf49 100644 --- a/poulpy-core/src/layouts/compressed/lwe_to_glwe_ksk.rs +++ b/poulpy-core/src/layouts/compressed/lwe_to_glwe_ksk.rs @@ -1,20 +1,23 @@ use poulpy_hal::{ - api::{VecZnxCopy, VecZnxFillUniform}, layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo}, source::Source, }; use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEInfos, GLWEInfos, LWEInfos, LWEToGLWESwitchingKey, Rank, TorusPrecision, - compressed::{Decompress, GGLWESwitchingKeyCompressed}, + Base2K, Dnum, Dsize, GGLWEInfos, GLWEInfos, LWEInfos, LWEToGLWESwitchingKey, LWEToGLWESwitchingKeyToMut, Rank, RingDegree, + TorusPrecision, + compressed::{ + GLWESwitchingKeyCompressed, GLWESwitchingKeyCompressedAlloc, GLWESwitchingKeyCompressedToMut, + GLWESwitchingKeyCompressedToRef, GLWESwitchingKeyDecompress, + }, }; use std::fmt; #[derive(PartialEq, Eq, Clone)] -pub struct LWEToGLWESwitchingKeyCompressed(pub(crate) GGLWESwitchingKeyCompressed); +pub struct LWEToGLWESwitchingKeyCompressed(pub(crate) GLWESwitchingKeyCompressed); impl LWEInfos for LWEToGLWESwitchingKeyCompressed { - fn n(&self) -> Degree { + fn n(&self) -> RingDegree { self.0.n() } @@ -83,63 +86,138 @@ impl WriterTo for LWEToGLWESwitchingKeyCompressed { } } -impl LWEToGLWESwitchingKeyCompressed> { - pub fn alloc(infos: &A) -> Self +pub trait LWEToGLWESwitchingKeyCompressedAlloc +where + Self: GLWESwitchingKeyCompressedAlloc, +{ + fn alloc_lwe_to_glwe_switching_key_compressed( + &self, + base2k: Base2K, + k: TorusPrecision, + rank_out: Rank, + dnum: Dnum, + ) -> LWEToGLWESwitchingKeyCompressed> { + LWEToGLWESwitchingKeyCompressed(self.alloc_glwe_switching_key_compressed(base2k, k, Rank(1), rank_out, dnum, Dsize(1))) + } + + fn alloc_lwe_to_glwe_switching_key_compressed_from_infos(&self, infos: &A) -> LWEToGLWESwitchingKeyCompressed> where A: GGLWEInfos, { - debug_assert_eq!( + assert_eq!( infos.dsize().0, 1, "dsize > 1 is not supported for LWEToGLWESwitchingKeyCompressed" ); - debug_assert_eq!( + assert_eq!( infos.rank_in().0, 1, "rank_in > 1 is not supported for LWEToGLWESwitchingKeyCompressed" ); - Self(GGLWESwitchingKeyCompressed::alloc(infos)) + self.alloc_lwe_to_glwe_switching_key_compressed(infos.base2k(), infos.k(), infos.rank_out(), infos.dnum()) } - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self { - Self(GGLWESwitchingKeyCompressed::alloc_with( - n, - base2k, - k, - Rank(1), - rank_out, - dnum, - Dsize(1), - )) + fn bytes_of_lwe_to_glwe_switching_key_compressed(&self, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> usize { + self.bytes_of_glwe_switching_key_compressed(base2k, k, Rank(1), dnum, Dsize(1)) } - pub fn alloc_bytes(infos: &A) -> usize + fn bytes_of_lwe_to_glwe_switching_key_compressed_from_infos(&self, infos: &A) -> usize where A: GGLWEInfos, { - debug_assert_eq!( - infos.rank_in().0, - 1, - "rank_in > 1 is not supported for LWEToGLWESwitchingKey" - ); - debug_assert_eq!( + assert_eq!( infos.dsize().0, 1, - "dsize > 1 is not supported for LWEToGLWESwitchingKey" + "dsize > 1 is not supported for LWEToGLWESwitchingKeyCompressed" ); - GGLWESwitchingKeyCompressed::alloc_bytes(infos) - } - - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> usize { - GGLWESwitchingKeyCompressed::alloc_bytes_with(n, base2k, k, Rank(1), dnum, Dsize(1)) + assert_eq!( + infos.rank_in().0, + 1, + "rank_in > 1 is not supported for LWEToGLWESwitchingKeyCompressed" + ); + self.bytes_of_lwe_to_glwe_switching_key_compressed(infos.base2k(), infos.k(), infos.dnum()) } } -impl Decompress> for LWEToGLWESwitchingKey +impl LWEToGLWESwitchingKeyCompressed> { + pub fn alloc(module: &M, infos: &A) -> Self + where + A: GGLWEInfos, + M: LWEToGLWESwitchingKeyCompressedAlloc, + { + module.alloc_lwe_to_glwe_switching_key_compressed_from_infos(infos) + } + + pub fn alloc_with(module: &M, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self + where + M: LWEToGLWESwitchingKeyCompressedAlloc, + { + module.alloc_lwe_to_glwe_switching_key_compressed(base2k, k, rank_out, dnum) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: LWEToGLWESwitchingKeyCompressedAlloc, + { + module.bytes_of_lwe_to_glwe_switching_key_compressed_from_infos(infos) + } + + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> usize + where + M: LWEToGLWESwitchingKeyCompressedAlloc, + { + module.bytes_of_lwe_to_glwe_switching_key_compressed(base2k, k, dnum) + } +} + +pub trait LWEToGLWESwitchingKeyDecompress where - Module: VecZnxFillUniform + VecZnxCopy, + Self: GLWESwitchingKeyDecompress, { - fn decompress(&mut self, module: &Module, other: &LWEToGLWESwitchingKeyCompressed) { - self.0.decompress(module, &other.0); + fn decompress_lwe_to_glwe_switching_key(&self, res: &mut R, other: &O) + where + R: LWEToGLWESwitchingKeyToMut, + O: LWEToGLWESwitchingKeyCompressedToRef, + { + self.decompress_glwe_switching_key(&mut res.to_mut().0, &other.to_ref().0); + } +} + +impl LWEToGLWESwitchingKeyDecompress for Module where Self: GLWESwitchingKeyDecompress {} + +impl LWEToGLWESwitchingKey { + pub fn decompress(&mut self, module: &M, other: &O) + where + O: LWEToGLWESwitchingKeyCompressedToRef, + M: LWEToGLWESwitchingKeyDecompress, + { + module.decompress_lwe_to_glwe_switching_key(self, other); + } +} + +pub trait LWEToGLWESwitchingKeyCompressedToRef { + fn to_ref(&self) -> LWEToGLWESwitchingKeyCompressed<&[u8]>; +} + +impl LWEToGLWESwitchingKeyCompressedToRef for LWEToGLWESwitchingKeyCompressed +where + GLWESwitchingKeyCompressed: GLWESwitchingKeyCompressedToRef, +{ + fn to_ref(&self) -> LWEToGLWESwitchingKeyCompressed<&[u8]> { + LWEToGLWESwitchingKeyCompressed(self.0.to_ref()) + } +} + +pub trait LWEToGLWESwitchingKeyCompressedToMut { + fn to_mut(&mut self) -> LWEToGLWESwitchingKeyCompressed<&mut [u8]>; +} + +impl LWEToGLWESwitchingKeyCompressedToMut for LWEToGLWESwitchingKeyCompressed +where + GLWESwitchingKeyCompressed: GLWESwitchingKeyCompressedToMut, +{ + fn to_mut(&mut self) -> LWEToGLWESwitchingKeyCompressed<&mut [u8]> { + LWEToGLWESwitchingKeyCompressed(self.0.to_mut()) } } diff --git a/poulpy-core/src/layouts/compressed/mod.rs b/poulpy-core/src/layouts/compressed/mod.rs index c1fcacf..cd7c459 100644 --- a/poulpy-core/src/layouts/compressed/mod.rs +++ b/poulpy-core/src/layouts/compressed/mod.rs @@ -19,9 +19,3 @@ pub use glwe_to_lwe_ksk::*; pub use lwe_ct::*; pub use lwe_ksk::*; pub use lwe_to_glwe_ksk::*; - -use poulpy_hal::layouts::{Backend, Module}; - -pub trait Decompress { - fn decompress(&mut self, module: &Module, other: &C); -} diff --git a/poulpy-core/src/layouts/gglwe_atk.rs b/poulpy-core/src/layouts/gglwe_atk.rs index 5c786d2..eb93bf4 100644 --- a/poulpy-core/src/layouts/gglwe_atk.rs +++ b/poulpy-core/src/layouts/gglwe_atk.rs @@ -1,18 +1,19 @@ use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo}, + layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo}, source::Source, }; use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWESwitchingKey, GLWECiphertext, GLWEInfos, LWEInfos, Rank, TorusPrecision, + Base2K, Dnum, Dsize, GGLWEInfos, GLWE, GLWEInfos, GLWESwitchingKey, GLWESwitchingKeyAlloc, GLWESwitchingKeyToMut, + GLWESwitchingKeyToRef, LWEInfos, Rank, RingDegree, TorusPrecision, }; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::fmt; #[derive(PartialEq, Eq, Copy, Clone, Debug)] -pub struct GGLWEAutomorphismKeyLayout { - pub n: Degree, +pub struct AutomorphismKeyLayout { + pub n: RingDegree, pub base2k: Base2K, pub k: TorusPrecision, pub rank: Rank, @@ -21,19 +22,19 @@ pub struct GGLWEAutomorphismKeyLayout { } #[derive(PartialEq, Eq, Clone)] -pub struct GGLWEAutomorphismKey { - pub(crate) key: GGLWESwitchingKey, +pub struct AutomorphismKey { + pub(crate) key: GLWESwitchingKey, pub(crate) p: i64, } -impl GGLWEAutomorphismKey { +impl AutomorphismKey { pub fn p(&self) -> i64 { self.p } } -impl LWEInfos for GGLWEAutomorphismKey { - fn n(&self) -> Degree { +impl LWEInfos for AutomorphismKey { + fn n(&self) -> RingDegree { self.key.n() } @@ -50,13 +51,13 @@ impl LWEInfos for GGLWEAutomorphismKey { } } -impl GLWEInfos for GGLWEAutomorphismKey { +impl GLWEInfos for AutomorphismKey { fn rank(&self) -> Rank { self.rank_out() } } -impl GGLWEInfos for GGLWEAutomorphismKey { +impl GGLWEInfos for AutomorphismKey { fn rank_in(&self) -> Rank { self.key.rank_in() } @@ -74,7 +75,7 @@ impl GGLWEInfos for GGLWEAutomorphismKey { } } -impl LWEInfos for GGLWEAutomorphismKeyLayout { +impl LWEInfos for AutomorphismKeyLayout { fn base2k(&self) -> Base2K { self.base2k } @@ -83,18 +84,18 @@ impl LWEInfos for GGLWEAutomorphismKeyLayout { self.k } - fn n(&self) -> Degree { + fn n(&self) -> RingDegree { self.n } } -impl GLWEInfos for GGLWEAutomorphismKeyLayout { +impl GLWEInfos for AutomorphismKeyLayout { fn rank(&self) -> Rank { self.rank } } -impl GGLWEInfos for GGLWEAutomorphismKeyLayout { +impl GGLWEInfos for AutomorphismKeyLayout { fn rank_in(&self) -> Rank { self.rank } @@ -112,84 +113,164 @@ impl GGLWEInfos for GGLWEAutomorphismKeyLayout { } } -impl fmt::Debug for GGLWEAutomorphismKey { +impl fmt::Debug for AutomorphismKey { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{self}") } } -impl FillUniform for GGLWEAutomorphismKey { +impl FillUniform for AutomorphismKey { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.key.fill_uniform(log_bound, source); } } -impl fmt::Display for GGLWEAutomorphismKey { +impl fmt::Display for AutomorphismKey { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "(AutomorphismKey: p={}) {}", self.p, self.key) } } -impl GGLWEAutomorphismKey> { - pub fn alloc(infos: &A) -> Self +impl AutomorphismKeyAlloc for Module where Self: GLWESwitchingKeyAlloc {} + +pub trait AutomorphismKeyAlloc +where + Self: GLWESwitchingKeyAlloc, +{ + fn alloc_automorphism_key( + &self, + base2k: Base2K, + k: TorusPrecision, + rank: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> AutomorphismKey> { + AutomorphismKey { + key: self.alloc_glwe_switching_key(base2k, k, rank, rank, dnum, dsize), + p: 0, + } + } + + fn alloc_automorphism_key_from_infos(&self, infos: &A) -> AutomorphismKey> + where + A: GGLWEInfos, + { + self.alloc_automorphism_key( + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) + } + + fn bytes_of_automorphism_key(&self, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { + self.bytes_of_glwe_switching_key(base2k, k, rank, rank, dnum, dsize) + } + + fn bytes_of_automorphism_key_from_infos(&self, infos: &A) -> usize where A: GGLWEInfos, { assert_eq!( infos.rank_in(), infos.rank_out(), - "rank_in != rank_out is not supported for GGLWEAutomorphismKey" + "rank_in != rank_out is not supported for AutomorphismKey" ); - GGLWEAutomorphismKey { - key: GGLWESwitchingKey::alloc(infos), - p: 0, - } - } - - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self { - GGLWEAutomorphismKey { - key: GGLWESwitchingKey::alloc_with(n, base2k, k, rank, rank, dnum, dsize), - p: 0, - } - } - - pub fn alloc_bytes(infos: &A) -> usize - where - A: GGLWEInfos, - { - assert_eq!( - infos.rank_in(), - infos.rank_out(), - "rank_in != rank_out is not supported for GGLWEAutomorphismKey" - ); - GGLWESwitchingKey::alloc_bytes(infos) - } - - pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { - GGLWESwitchingKey::alloc_bytes_with(n, base2k, k, rank, rank, dnum, dsize) + self.bytes_of_automorphism_key( + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) } } -impl GGLWEAutomorphismKey { - pub fn at(&self, row: usize, col: usize) -> GLWECiphertext<&[u8]> { +impl AutomorphismKey> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GGLWEInfos, + M: AutomorphismKeyAlloc, + { + module.alloc_automorphism_key_from_infos(infos) + } + + pub fn alloc_with(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self + where + M: AutomorphismKeyAlloc, + { + module.alloc_automorphism_key(base2k, k, rank, dnum, dsize) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: AutomorphismKeyAlloc, + { + module.bytes_of_automorphism_key_from_infos(infos) + } + + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize + where + M: AutomorphismKeyAlloc, + { + module.bytes_of_automorphism_key(base2k, k, rank, dnum, dsize) + } +} + +pub trait AutomorphismKeyToMut { + fn to_mut(&mut self) -> AutomorphismKey<&mut [u8]>; +} + +impl AutomorphismKeyToMut for AutomorphismKey +where + GLWESwitchingKey: GLWESwitchingKeyToMut, +{ + fn to_mut(&mut self) -> AutomorphismKey<&mut [u8]> { + AutomorphismKey { + key: self.key.to_mut(), + p: self.p, + } + } +} + +pub trait AutomorphismKeyToRef { + fn to_ref(&self) -> AutomorphismKey<&[u8]>; +} + +impl AutomorphismKeyToRef for AutomorphismKey +where + GLWESwitchingKey: GLWESwitchingKeyToRef, +{ + fn to_ref(&self) -> AutomorphismKey<&[u8]> { + AutomorphismKey { + p: self.p, + key: self.key.to_ref(), + } + } +} + +impl AutomorphismKey { + pub fn at(&self, row: usize, col: usize) -> GLWE<&[u8]> { self.key.at(row, col) } } -impl GGLWEAutomorphismKey { - pub fn at_mut(&mut self, row: usize, col: usize) -> GLWECiphertext<&mut [u8]> { +impl AutomorphismKey { + pub fn at_mut(&mut self, row: usize, col: usize) -> GLWE<&mut [u8]> { self.key.at_mut(row, col) } } -impl ReaderFrom for GGLWEAutomorphismKey { +impl ReaderFrom for AutomorphismKey { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.p = reader.read_u64::()? as i64; self.key.read_from(reader) } } -impl WriterTo for GGLWEAutomorphismKey { +impl WriterTo for AutomorphismKey { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { writer.write_u64::(self.p as u64)?; self.key.write_to(writer) diff --git a/poulpy-core/src/layouts/gglwe_ct.rs b/poulpy-core/src/layouts/gglwe_ct.rs index ca8236c..3b95d63 100644 --- a/poulpy-core/src/layouts/gglwe_ct.rs +++ b/poulpy-core/src/layouts/gglwe_ct.rs @@ -1,9 +1,11 @@ use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, FillUniform, MatZnx, ReaderFrom, WriterTo, ZnxInfos}, + layouts::{ + Backend, Data, DataMut, DataRef, FillUniform, MatZnx, MatZnxToMut, MatZnxToRef, Module, ReaderFrom, WriterTo, ZnxInfos, + }, source::Source, }; -use crate::layouts::{Base2K, BuildError, Degree, Dnum, Dsize, GLWECiphertext, GLWEInfos, LWEInfos, Rank, TorusPrecision}; +use crate::layouts::{Base2K, Dnum, Dsize, GLWE, GLWEInfos, GetRingDegree, LWEInfos, Rank, RingDegree, TorusPrecision}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::fmt; @@ -16,8 +18,8 @@ where fn dsize(&self) -> Dsize; fn rank_in(&self) -> Rank; fn rank_out(&self) -> Rank; - fn layout(&self) -> GGLWECiphertextLayout { - GGLWECiphertextLayout { + fn gglwe_layout(&self) -> GGLWELayout { + GGLWELayout { n: self.n(), base2k: self.base2k(), k: self.k(), @@ -29,9 +31,13 @@ where } } +pub trait SetGGLWEInfos { + fn set_dsize(&mut self, dsize: usize); +} + #[derive(PartialEq, Eq, Copy, Clone, Debug)] -pub struct GGLWECiphertextLayout { - pub n: Degree, +pub struct GGLWELayout { + pub n: RingDegree, pub base2k: Base2K, pub k: TorusPrecision, pub rank_in: Rank, @@ -40,7 +46,7 @@ pub struct GGLWECiphertextLayout { pub dsize: Dsize, } -impl LWEInfos for GGLWECiphertextLayout { +impl LWEInfos for GGLWELayout { fn base2k(&self) -> Base2K { self.base2k } @@ -49,18 +55,18 @@ impl LWEInfos for GGLWECiphertextLayout { self.k } - fn n(&self) -> Degree { + fn n(&self) -> RingDegree { self.n } } -impl GLWEInfos for GGLWECiphertextLayout { +impl GLWEInfos for GGLWELayout { fn rank(&self) -> Rank { self.rank_out } } -impl GGLWEInfos for GGLWECiphertextLayout { +impl GGLWEInfos for GGLWELayout { fn rank_in(&self) -> Rank { self.rank_in } @@ -79,14 +85,14 @@ impl GGLWEInfos for GGLWECiphertextLayout { } #[derive(PartialEq, Eq, Clone)] -pub struct GGLWECiphertext { +pub struct GGLWE { pub(crate) data: MatZnx, pub(crate) k: TorusPrecision, pub(crate) base2k: Base2K, pub(crate) dsize: Dsize, } -impl LWEInfos for GGLWECiphertext { +impl LWEInfos for GGLWE { fn base2k(&self) -> Base2K { self.base2k } @@ -95,8 +101,8 @@ impl LWEInfos for GGLWECiphertext { self.k } - fn n(&self) -> Degree { - Degree(self.data.n() as u32) + fn n(&self) -> RingDegree { + RingDegree(self.data.n() as u32) } fn size(&self) -> usize { @@ -104,13 +110,13 @@ impl LWEInfos for GGLWECiphertext { } } -impl GLWEInfos for GGLWECiphertext { +impl GLWEInfos for GGLWE { fn rank(&self) -> Rank { self.rank_out() } } -impl GGLWEInfos for GGLWECiphertext { +impl GGLWEInfos for GGLWE { fn rank_in(&self) -> Rank { Rank(self.data.cols_in() as u32) } @@ -128,136 +134,35 @@ impl GGLWEInfos for GGLWECiphertext { } } -pub struct GGLWECiphertextBuilder { - data: Option>, - base2k: Option, - k: Option, - dsize: Option, -} - -impl GGLWECiphertext { - #[inline] - pub fn builder() -> GGLWECiphertextBuilder { - GGLWECiphertextBuilder { - data: None, - base2k: None, - k: None, - dsize: None, - } - } -} - -impl GGLWECiphertextBuilder> { - #[inline] - pub fn layout(mut self, infos: &A) -> Self - where - A: GGLWEInfos, - { - self.data = Some(MatZnx::alloc( - infos.n().into(), - infos.dnum().into(), - infos.rank_in().into(), - (infos.rank_out() + 1).into(), - infos.size(), - )); - self.base2k = Some(infos.base2k()); - self.k = Some(infos.k()); - self.dsize = Some(infos.dsize()); - self - } -} - -impl GGLWECiphertextBuilder { - #[inline] - pub fn data(mut self, data: MatZnx) -> Self { - self.data = Some(data); - self - } - #[inline] - pub fn base2k(mut self, base2k: Base2K) -> Self { - self.base2k = Some(base2k); - self - } - #[inline] - pub fn k(mut self, k: TorusPrecision) -> Self { - self.k = Some(k); - self - } - - #[inline] - pub fn dsize(mut self, dsize: Dsize) -> Self { - self.dsize = Some(dsize); - self - } - - pub fn build(self) -> Result, BuildError> { - let data: MatZnx = self.data.ok_or(BuildError::MissingData)?; - let base2k: Base2K = self.base2k.ok_or(BuildError::MissingBase2K)?; - let k: TorusPrecision = self.k.ok_or(BuildError::MissingK)?; - let dsize: Dsize = self.dsize.ok_or(BuildError::MissingDigits)?; - - if base2k == 0_u32 { - return Err(BuildError::ZeroBase2K); - } - - if dsize == 0_u32 { - return Err(BuildError::ZeroBase2K); - } - - if k == 0_u32 { - return Err(BuildError::ZeroTorusPrecision); - } - - if data.n() == 0 { - return Err(BuildError::ZeroDegree); - } - - if data.cols() == 0 { - return Err(BuildError::ZeroCols); - } - - if data.size() == 0 { - return Err(BuildError::ZeroLimbs); - } - - Ok(GGLWECiphertext { - data, - base2k, - k, - dsize, - }) - } -} - -impl GGLWECiphertext { +impl GGLWE { pub fn data(&self) -> &MatZnx { &self.data } } -impl GGLWECiphertext { +impl GGLWE { pub fn data_mut(&mut self) -> &mut MatZnx { &mut self.data } } -impl fmt::Debug for GGLWECiphertext { +impl fmt::Debug for GGLWE { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{self}") } } -impl FillUniform for GGLWECiphertext { +impl FillUniform for GGLWE { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.data.fill_uniform(log_bound, source); } } -impl fmt::Display for GGLWECiphertext { +impl fmt::Display for GGLWE { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "(GGLWECiphertext: k={} base2k={} dsize={}) {}", + "(GGLWE: k={} base2k={} dsize={}) {}", self.k().0, self.base2k().0, self.dsize().0, @@ -266,53 +171,39 @@ impl fmt::Display for GGLWECiphertext { } } -impl GGLWECiphertext { - pub fn at(&self, row: usize, col: usize) -> GLWECiphertext<&[u8]> { - GLWECiphertext::builder() - .data(self.data.at(row, col)) - .base2k(self.base2k()) - .k(self.k()) - .build() - .unwrap() +impl GGLWE { + pub fn at(&self, row: usize, col: usize) -> GLWE<&[u8]> { + GLWE { + k: self.k, + base2k: self.base2k, + data: self.data.at(row, col), + } } } -impl GGLWECiphertext { - pub fn at_mut(&mut self, row: usize, col: usize) -> GLWECiphertext<&mut [u8]> { - GLWECiphertext::builder() - .base2k(self.base2k()) - .k(self.k()) - .data(self.data.at_mut(row, col)) - .build() - .unwrap() +impl GGLWE { + pub fn at_mut(&mut self, row: usize, col: usize) -> GLWE<&mut [u8]> { + GLWE { + k: self.k, + base2k: self.base2k, + data: self.data.at_mut(row, col), + } } } -impl GGLWECiphertext> { - pub fn alloc(infos: &A) -> Self - where - A: GGLWEInfos, - { - Self::alloc_with( - infos.n(), - infos.base2k(), - infos.k(), - infos.rank_in(), - infos.rank_out(), - infos.dnum(), - infos.dsize(), - ) - } - - pub fn alloc_with( - n: Degree, +pub trait GGLWEAlloc +where + Self: GetRingDegree, +{ + fn alloc_gglwe( + &self, base2k: Base2K, k: TorusPrecision, rank_in: Rank, rank_out: Rank, dnum: Dnum, dsize: Dsize, - ) -> Self { + ) -> GGLWE> { let size: usize = k.0.div_ceil(base2k.0) as usize; debug_assert!( size as u32 > dsize.0, @@ -327,9 +218,9 @@ impl GGLWECiphertext> { dsize.0, ); - Self { + GGLWE { data: MatZnx::alloc( - n.into(), + self.ring_degree().into(), dnum.into(), rank_in.into(), (rank_out + 1).into(), @@ -341,12 +232,11 @@ impl GGLWECiphertext> { } } - pub fn alloc_bytes(infos: &A) -> usize + fn alloc_glwe_from_infos(&self, infos: &A) -> GGLWE> where A: GGLWEInfos, { - Self::alloc_bytes_with( - infos.n(), + self.alloc_gglwe( infos.base2k(), infos.k(), infos.rank_in(), @@ -356,8 +246,8 @@ impl GGLWECiphertext> { ) } - pub fn alloc_bytes_with( - n: Degree, + fn bytes_of_gglwe( + &self, base2k: Base2K, k: TorusPrecision, rank_in: Rank, @@ -379,17 +269,111 @@ impl GGLWECiphertext> { dsize.0, ); - MatZnx::alloc_bytes( - n.into(), + MatZnx::bytes_of( + self.ring_degree().into(), dnum.into(), rank_in.into(), (rank_out + 1).into(), k.0.div_ceil(base2k.0) as usize, ) } + + fn bytes_of_gglwe_from_infos(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + self.bytes_of_gglwe( + infos.base2k(), + infos.k(), + infos.rank_in(), + infos.rank_out(), + infos.dnum(), + infos.dsize(), + ) + } } -impl ReaderFrom for GGLWECiphertext { +impl GGLWEAlloc for Module where Self: GetRingDegree {} + +impl GGLWE> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GGLWEInfos, + M: GGLWEAlloc, + { + module.alloc_glwe_from_infos(infos) + } + + pub fn alloc( + module: &M, + base2k: Base2K, + k: TorusPrecision, + rank_in: Rank, + rank_out: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> Self + where + M: GGLWEAlloc, + { + module.alloc_gglwe(base2k, k, rank_in, rank_out, dnum, dsize) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: GGLWEAlloc, + { + module.bytes_of_gglwe_from_infos(infos) + } + + pub fn bytes_of( + module: &M, + base2k: Base2K, + k: TorusPrecision, + rank_in: Rank, + rank_out: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> usize + where + M: GGLWEAlloc, + { + module.bytes_of_gglwe(base2k, k, rank_in, rank_out, dnum, dsize) + } +} + +pub trait GGLWEToMut { + fn to_mut(&mut self) -> GGLWE<&mut [u8]>; +} + +impl GGLWEToMut for GGLWE { + fn to_mut(&mut self) -> GGLWE<&mut [u8]> { + GGLWE { + k: self.k(), + base2k: self.base2k(), + dsize: self.dsize(), + data: self.data.to_mut(), + } + } +} + +pub trait GGLWEToRef { + fn to_ref(&self) -> GGLWE<&[u8]>; +} + +impl GGLWEToRef for GGLWE { + fn to_ref(&self) -> GGLWE<&[u8]> { + GGLWE { + k: self.k(), + base2k: self.base2k(), + dsize: self.dsize(), + data: self.data.to_ref(), + } + } +} + +impl ReaderFrom for GGLWE { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.k = TorusPrecision(reader.read_u32::()?); self.base2k = Base2K(reader.read_u32::()?); @@ -398,7 +382,7 @@ impl ReaderFrom for GGLWECiphertext { } } -impl WriterTo for GGLWECiphertext { +impl WriterTo for GGLWE { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { writer.write_u32::(self.k.0)?; writer.write_u32::(self.base2k.0)?; diff --git a/poulpy-core/src/layouts/gglwe_ksk.rs b/poulpy-core/src/layouts/gglwe_ksk.rs index 31a483b..ddb4de7 100644 --- a/poulpy-core/src/layouts/gglwe_ksk.rs +++ b/poulpy-core/src/layouts/gglwe_ksk.rs @@ -1,18 +1,19 @@ use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo}, + layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo}, source::Source, }; use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWECiphertext, GGLWEInfos, GLWECiphertext, GLWEInfos, LWEInfos, Rank, TorusPrecision, + Base2K, Dnum, Dsize, GGLWE, GGLWEAlloc, GGLWEInfos, GGLWEToMut, GGLWEToRef, GLWE, GLWEInfos, LWEInfos, Rank, RingDegree, + TorusPrecision, }; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::fmt; #[derive(PartialEq, Eq, Copy, Clone, Debug)] -pub struct GGLWESwitchingKeyLayout { - pub n: Degree, +pub struct GLWESwitchingKeyLayout { + pub n: RingDegree, pub base2k: Base2K, pub k: TorusPrecision, pub rank_in: Rank, @@ -21,8 +22,8 @@ pub struct GGLWESwitchingKeyLayout { pub dsize: Dsize, } -impl LWEInfos for GGLWESwitchingKeyLayout { - fn n(&self) -> Degree { +impl LWEInfos for GLWESwitchingKeyLayout { + fn n(&self) -> RingDegree { self.n } @@ -35,13 +36,13 @@ impl LWEInfos for GGLWESwitchingKeyLayout { } } -impl GLWEInfos for GGLWESwitchingKeyLayout { +impl GLWEInfos for GLWESwitchingKeyLayout { fn rank(&self) -> Rank { self.rank_out() } } -impl GGLWEInfos for GGLWESwitchingKeyLayout { +impl GGLWEInfos for GLWESwitchingKeyLayout { fn rank_in(&self) -> Rank { self.rank_in } @@ -60,14 +61,44 @@ impl GGLWEInfos for GGLWESwitchingKeyLayout { } #[derive(PartialEq, Eq, Clone)] -pub struct GGLWESwitchingKey { - pub(crate) key: GGLWECiphertext, +pub struct GLWESwitchingKey { + pub(crate) key: GGLWE, pub(crate) sk_in_n: usize, // Degree of sk_in pub(crate) sk_out_n: usize, // Degree of sk_out } -impl LWEInfos for GGLWESwitchingKey { - fn n(&self) -> Degree { +pub(crate) trait GLWESwitchingKeySetMetaData { + fn set_sk_in_n(&mut self, sk_in_n: usize); + fn set_sk_out_n(&mut self, sk_out_n: usize); +} + +impl GLWESwitchingKeySetMetaData for GLWESwitchingKey { + fn set_sk_in_n(&mut self, sk_in_n: usize) { + self.sk_in_n = sk_in_n + } + + fn set_sk_out_n(&mut self, sk_out_n: usize) { + self.sk_out_n = sk_out_n + } +} + +pub(crate) trait GLWESwtichingKeyGetMetaData { + fn sk_in_n(&self) -> usize; + fn sk_out_n(&self) -> usize; +} + +impl GLWESwtichingKeyGetMetaData for GLWESwitchingKey { + fn sk_in_n(&self) -> usize { + self.sk_in_n + } + + fn sk_out_n(&self) -> usize { + self.sk_out_n + } +} + +impl LWEInfos for GLWESwitchingKey { + fn n(&self) -> RingDegree { self.key.n() } @@ -84,13 +115,13 @@ impl LWEInfos for GGLWESwitchingKey { } } -impl GLWEInfos for GGLWESwitchingKey { +impl GLWEInfos for GLWESwitchingKey { fn rank(&self) -> Rank { self.rank_out() } } -impl GGLWEInfos for GGLWESwitchingKey { +impl GGLWEInfos for GLWESwitchingKey { fn rank_in(&self) -> Rank { self.key.rank_in() } @@ -108,13 +139,13 @@ impl GGLWEInfos for GGLWESwitchingKey { } } -impl fmt::Debug for GGLWESwitchingKey { +impl fmt::Debug for GLWESwitchingKey { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{self}") } } -impl fmt::Display for GGLWESwitchingKey { +impl fmt::Display for GLWESwitchingKey { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, @@ -126,49 +157,48 @@ impl fmt::Display for GGLWESwitchingKey { } } -impl FillUniform for GGLWESwitchingKey { +impl FillUniform for GLWESwitchingKey { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.key.fill_uniform(log_bound, source); } } -impl GGLWESwitchingKey> { - pub fn alloc(infos: &A) -> Self - where - A: GGLWEInfos, - { - GGLWESwitchingKey { - key: GGLWECiphertext::alloc(infos), - sk_in_n: 0, - sk_out_n: 0, - } - } - - pub fn alloc_with( - n: Degree, +pub trait GLWESwitchingKeyAlloc +where + Self: GGLWEAlloc, +{ + fn alloc_glwe_switching_key( + &self, base2k: Base2K, k: TorusPrecision, rank_in: Rank, rank_out: Rank, dnum: Dnum, dsize: Dsize, - ) -> Self { - GGLWESwitchingKey { - key: GGLWECiphertext::alloc_with(n, base2k, k, rank_in, rank_out, dnum, dsize), + ) -> GLWESwitchingKey> { + GLWESwitchingKey { + key: self.alloc_gglwe(base2k, k, rank_in, rank_out, dnum, dsize), sk_in_n: 0, sk_out_n: 0, } } - pub fn alloc_bytes(infos: &A) -> usize + fn alloc_glwe_switching_key_from_infos(&self, infos: &A) -> GLWESwitchingKey> where A: GGLWEInfos, { - GGLWECiphertext::alloc_bytes(infos) + self.alloc_glwe_switching_key( + infos.base2k(), + infos.k(), + infos.rank_in(), + infos.rank_out(), + infos.dnum(), + infos.dsize(), + ) } - pub fn alloc_bytes_with( - n: Degree, + fn bytes_of_glwe_switching_key( + &self, base2k: Base2K, k: TorusPrecision, rank_in: Rank, @@ -176,23 +206,121 @@ impl GGLWESwitchingKey> { dnum: Dnum, dsize: Dsize, ) -> usize { - GGLWECiphertext::alloc_bytes_with(n, base2k, k, rank_in, rank_out, dnum, dsize) + self.bytes_of_gglwe(base2k, k, rank_in, rank_out, dnum, dsize) + } + + fn bytes_of_glwe_switching_key_from_infos(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + self.bytes_of_glwe_switching_key( + infos.base2k(), + infos.k(), + infos.rank_in(), + infos.rank_out(), + infos.dnum(), + infos.dsize(), + ) } } -impl GGLWESwitchingKey { - pub fn at(&self, row: usize, col: usize) -> GLWECiphertext<&[u8]> { +impl GLWESwitchingKeyAlloc for Module where Self: GGLWEAlloc {} + +impl GLWESwitchingKey> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GGLWEInfos, + M: GLWESwitchingKeyAlloc, + { + module.alloc_glwe_switching_key_from_infos(infos) + } + + pub fn alloc( + module: &M, + base2k: Base2K, + k: TorusPrecision, + rank_in: Rank, + rank_out: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> Self + where + M: GLWESwitchingKeyAlloc, + { + module.alloc_glwe_switching_key(base2k, k, rank_in, rank_out, dnum, dsize) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: GLWESwitchingKeyAlloc, + { + module.bytes_of_glwe_switching_key_from_infos(infos) + } + + pub fn bytes_of( + module: &M, + base2k: Base2K, + k: TorusPrecision, + rank_in: Rank, + rank_out: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> usize + where + M: GLWESwitchingKeyAlloc, + { + module.bytes_of_glwe_switching_key(base2k, k, rank_in, rank_out, dnum, dsize) + } +} + +pub trait GLWESwitchingKeyToMut { + fn to_mut(&mut self) -> GLWESwitchingKey<&mut [u8]>; +} + +impl GLWESwitchingKeyToMut for GLWESwitchingKey +where + GGLWE: GGLWEToMut, +{ + fn to_mut(&mut self) -> GLWESwitchingKey<&mut [u8]> { + GLWESwitchingKey { + key: self.key.to_mut(), + sk_in_n: self.sk_in_n, + sk_out_n: self.sk_out_n, + } + } +} + +pub trait GLWESwitchingKeyToRef { + fn to_ref(&self) -> GLWESwitchingKey<&[u8]>; +} + +impl GLWESwitchingKeyToRef for GLWESwitchingKey +where + GGLWE: GGLWEToRef, +{ + fn to_ref(&self) -> GLWESwitchingKey<&[u8]> { + GLWESwitchingKey { + key: self.key.to_ref(), + sk_in_n: self.sk_in_n, + sk_out_n: self.sk_out_n, + } + } +} + +impl GLWESwitchingKey { + pub fn at(&self, row: usize, col: usize) -> GLWE<&[u8]> { self.key.at(row, col) } } -impl GGLWESwitchingKey { - pub fn at_mut(&mut self, row: usize, col: usize) -> GLWECiphertext<&mut [u8]> { +impl GLWESwitchingKey { + pub fn at_mut(&mut self, row: usize, col: usize) -> GLWE<&mut [u8]> { self.key.at_mut(row, col) } } -impl ReaderFrom for GGLWESwitchingKey { +impl ReaderFrom for GLWESwitchingKey { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.sk_in_n = reader.read_u64::()? as usize; self.sk_out_n = reader.read_u64::()? as usize; @@ -200,7 +328,7 @@ impl ReaderFrom for GGLWESwitchingKey { } } -impl WriterTo for GGLWESwitchingKey { +impl WriterTo for GLWESwitchingKey { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { writer.write_u64::(self.sk_in_n as u64)?; writer.write_u64::(self.sk_out_n as u64)?; diff --git a/poulpy-core/src/layouts/gglwe_tsk.rs b/poulpy-core/src/layouts/gglwe_tsk.rs index a949b7e..6ef67b8 100644 --- a/poulpy-core/src/layouts/gglwe_tsk.rs +++ b/poulpy-core/src/layouts/gglwe_tsk.rs @@ -1,16 +1,19 @@ use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo}, + layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo}, source::Source, }; -use crate::layouts::{Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWESwitchingKey, GLWEInfos, LWEInfos, Rank, TorusPrecision}; +use crate::layouts::{ + Base2K, Dnum, Dsize, GGLWEInfos, GLWEInfos, GLWESwitchingKey, GLWESwitchingKeyAlloc, GLWESwitchingKeyToMut, + GLWESwitchingKeyToRef, LWEInfos, Rank, RingDegree, TorusPrecision, +}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::fmt; #[derive(PartialEq, Eq, Copy, Clone, Debug)] -pub struct GGLWETensorKeyLayout { - pub n: Degree, +pub struct TensorKeyLayout { + pub n: RingDegree, pub base2k: Base2K, pub k: TorusPrecision, pub rank: Rank, @@ -19,12 +22,12 @@ pub struct GGLWETensorKeyLayout { } #[derive(PartialEq, Eq, Clone)] -pub struct GGLWETensorKey { - pub(crate) keys: Vec>, +pub struct TensorKey { + pub(crate) keys: Vec>, } -impl LWEInfos for GGLWETensorKey { - fn n(&self) -> Degree { +impl LWEInfos for TensorKey { + fn n(&self) -> RingDegree { self.keys[0].n() } @@ -41,13 +44,13 @@ impl LWEInfos for GGLWETensorKey { } } -impl GLWEInfos for GGLWETensorKey { +impl GLWEInfos for TensorKey { fn rank(&self) -> Rank { self.keys[0].rank_out() } } -impl GGLWEInfos for GGLWETensorKey { +impl GGLWEInfos for TensorKey { fn rank_in(&self) -> Rank { self.rank_out() } @@ -65,8 +68,8 @@ impl GGLWEInfos for GGLWETensorKey { } } -impl LWEInfos for GGLWETensorKeyLayout { - fn n(&self) -> Degree { +impl LWEInfos for TensorKeyLayout { + fn n(&self) -> RingDegree { self.n } @@ -79,13 +82,13 @@ impl LWEInfos for GGLWETensorKeyLayout { } } -impl GLWEInfos for GGLWETensorKeyLayout { +impl GLWEInfos for TensorKeyLayout { fn rank(&self) -> Rank { self.rank_out() } } -impl GGLWEInfos for GGLWETensorKeyLayout { +impl GGLWEInfos for TensorKeyLayout { fn rank_in(&self) -> Rank { self.rank } @@ -103,21 +106,21 @@ impl GGLWEInfos for GGLWETensorKeyLayout { } } -impl fmt::Debug for GGLWETensorKey { +impl fmt::Debug for TensorKey { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{self}") } } -impl FillUniform for GGLWETensorKey { +impl FillUniform for TensorKey { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.keys .iter_mut() - .for_each(|key: &mut GGLWESwitchingKey| key.fill_uniform(log_bound, source)) + .for_each(|key: &mut GLWESwitchingKey| key.fill_uniform(log_bound, source)) } } -impl fmt::Display for GGLWETensorKey { +impl fmt::Display for TensorKey { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!(f, "(GLWETensorKey)",)?; for (i, key) in self.keys.iter().enumerate() { @@ -127,8 +130,20 @@ impl fmt::Display for GGLWETensorKey { } } -impl GGLWETensorKey> { - pub fn alloc(infos: &A) -> Self +pub trait TensorKeyAlloc +where + Self: GLWESwitchingKeyAlloc, +{ + fn alloc_tensor_key(&self, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> TensorKey> { + let pairs: u32 = (((rank.0 + 1) * rank.0) >> 1).max(1); + TensorKey { + keys: (0..pairs) + .map(|_| self.alloc_glwe_switching_key(base2k, k, Rank(1), rank, dnum, dsize)) + .collect(), + } + } + + fn alloc_tensor_key_from_infos(&self, infos: &A) -> TensorKey> where A: GGLWEInfos, { @@ -137,34 +152,21 @@ impl GGLWETensorKey> { infos.rank_out(), "rank_in != rank_out is not supported for GGLWETensorKey" ); - Self::alloc_with( - infos.n(), + self.alloc_tensor_key( infos.base2k(), infos.k(), - infos.rank_out(), + infos.rank(), infos.dnum(), infos.dsize(), ) } - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self { - let mut keys: Vec>> = Vec::new(); - let pairs: u32 = (((rank.0 + 1) * rank.0) >> 1).max(1); - (0..pairs).for_each(|_| { - keys.push(GGLWESwitchingKey::alloc_with( - n, - base2k, - k, - Rank(1), - rank, - dnum, - dsize, - )); - }); - Self { keys } + fn bytes_of_tensor_key(&self, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { + let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize; + pairs * self.bytes_of_glwe_switching_key(base2k, k, Rank(1), rank, dnum, dsize) } - pub fn alloc_bytes(infos: &A) -> usize + fn bytes_of_tensor_key_from_infos(&self, infos: &A) -> usize where A: GGLWEInfos, { @@ -173,29 +175,53 @@ impl GGLWETensorKey> { infos.rank_out(), "rank_in != rank_out is not supported for GGLWETensorKey" ); - let rank_out: usize = infos.rank_out().into(); - let pairs: usize = (((rank_out + 1) * rank_out) >> 1).max(1); - pairs - * GGLWESwitchingKey::alloc_bytes_with( - infos.n(), - infos.base2k(), - infos.k(), - Rank(1), - infos.rank_out(), - infos.dnum(), - infos.dsize(), - ) - } - - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { - let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize; - pairs * GGLWESwitchingKey::alloc_bytes_with(n, base2k, k, Rank(1), rank, dnum, dsize) + self.bytes_of_tensor_key( + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) } } -impl GGLWETensorKey { +impl TensorKeyAlloc for Module where Self: GLWESwitchingKeyAlloc {} + +impl TensorKey> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GGLWEInfos, + M: TensorKeyAlloc, + { + module.alloc_tensor_key_from_infos(infos) + } + + pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self + where + M: TensorKeyAlloc, + { + module.alloc_tensor_key(base2k, k, rank, dnum, dsize) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: TensorKeyAlloc, + { + module.bytes_of_tensor_key_from_infos(infos) + } + + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize + where + M: TensorKeyAlloc, + { + module.bytes_of_tensor_key(base2k, k, rank, dnum, dsize) + } +} + +impl TensorKey { // Returns a mutable reference to GLWESwitchingKey_{s}(s[i] * s[j]) - pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GGLWESwitchingKey { + pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GLWESwitchingKey { if i > j { std::mem::swap(&mut i, &mut j); }; @@ -204,9 +230,9 @@ impl GGLWETensorKey { } } -impl GGLWETensorKey { +impl TensorKey { // Returns a reference to GLWESwitchingKey_{s}(s[i] * s[j]) - pub fn at(&self, mut i: usize, mut j: usize) -> &GGLWESwitchingKey { + pub fn at(&self, mut i: usize, mut j: usize) -> &GLWESwitchingKey { if i > j { std::mem::swap(&mut i, &mut j); }; @@ -215,7 +241,7 @@ impl GGLWETensorKey { } } -impl ReaderFrom for GGLWETensorKey { +impl ReaderFrom for TensorKey { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { let len: usize = reader.read_u64::()? as usize; if self.keys.len() != len { @@ -231,7 +257,7 @@ impl ReaderFrom for GGLWETensorKey { } } -impl WriterTo for GGLWETensorKey { +impl WriterTo for TensorKey { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { writer.write_u64::(self.keys.len() as u64)?; for key in &self.keys { @@ -240,3 +266,33 @@ impl WriterTo for GGLWETensorKey { Ok(()) } } + +pub trait TensorKeyToRef { + fn to_ref(&self) -> TensorKey<&[u8]>; +} + +impl TensorKeyToRef for TensorKey +where + GLWESwitchingKey: GLWESwitchingKeyToRef, +{ + fn to_ref(&self) -> TensorKey<&[u8]> { + TensorKey { + keys: self.keys.iter().map(|c| c.to_ref()).collect(), + } + } +} + +pub trait TensorKeyToMut { + fn to_mut(&mut self) -> TensorKey<&mut [u8]>; +} + +impl TensorKeyToMut for TensorKey +where + GLWESwitchingKey: GLWESwitchingKeyToMut, +{ + fn to_mut(&mut self) -> TensorKey<&mut [u8]> { + TensorKey { + keys: self.keys.iter_mut().map(|c| c.to_mut()).collect(), + } + } +} diff --git a/poulpy-core/src/layouts/ggsw_ct.rs b/poulpy-core/src/layouts/ggsw_ct.rs index f1bb228..aac4a20 100644 --- a/poulpy-core/src/layouts/ggsw_ct.rs +++ b/poulpy-core/src/layouts/ggsw_ct.rs @@ -1,10 +1,12 @@ use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, FillUniform, MatZnx, ReaderFrom, WriterTo, ZnxInfos}, + layouts::{ + Backend, Data, DataMut, DataRef, FillUniform, MatZnx, MatZnxToMut, MatZnxToRef, Module, ReaderFrom, WriterTo, ZnxInfos, + }, source::Source, }; use std::fmt; -use crate::layouts::{Base2K, BuildError, Degree, Dnum, Dsize, GLWECiphertext, GLWEInfos, LWEInfos, Rank, TorusPrecision}; +use crate::layouts::{Base2K, Dnum, Dsize, GLWE, GLWEInfos, GetRingDegree, LWEInfos, Rank, RingDegree, TorusPrecision}; pub trait GGSWInfos where @@ -12,8 +14,8 @@ where { fn dnum(&self) -> Dnum; fn dsize(&self) -> Dsize; - fn ggsw_layout(&self) -> GGSWCiphertextLayout { - GGSWCiphertextLayout { + fn ggsw_layout(&self) -> GGSWLayout { + GGSWLayout { n: self.n(), base2k: self.base2k(), k: self.k(), @@ -25,8 +27,8 @@ where } #[derive(PartialEq, Eq, Copy, Clone, Debug)] -pub struct GGSWCiphertextLayout { - pub n: Degree, +pub struct GGSWLayout { + pub n: RingDegree, pub base2k: Base2K, pub k: TorusPrecision, pub rank: Rank, @@ -34,7 +36,7 @@ pub struct GGSWCiphertextLayout { pub dsize: Dsize, } -impl LWEInfos for GGSWCiphertextLayout { +impl LWEInfos for GGSWLayout { fn base2k(&self) -> Base2K { self.base2k } @@ -43,17 +45,17 @@ impl LWEInfos for GGSWCiphertextLayout { self.k } - fn n(&self) -> Degree { + fn n(&self) -> RingDegree { self.n } } -impl GLWEInfos for GGSWCiphertextLayout { +impl GLWEInfos for GGSWLayout { fn rank(&self) -> Rank { self.rank } } -impl GGSWInfos for GGSWCiphertextLayout { +impl GGSWInfos for GGSWLayout { fn dsize(&self) -> Dsize { self.dsize } @@ -64,16 +66,16 @@ impl GGSWInfos for GGSWCiphertextLayout { } #[derive(PartialEq, Eq, Clone)] -pub struct GGSWCiphertext { +pub struct GGSW { pub(crate) data: MatZnx, pub(crate) k: TorusPrecision, pub(crate) base2k: Base2K, pub(crate) dsize: Dsize, } -impl LWEInfos for GGSWCiphertext { - fn n(&self) -> Degree { - Degree(self.data.n() as u32) +impl LWEInfos for GGSW { + fn n(&self) -> RingDegree { + RingDegree(self.data.n() as u32) } fn base2k(&self) -> Base2K { @@ -89,13 +91,13 @@ impl LWEInfos for GGSWCiphertext { } } -impl GLWEInfos for GGSWCiphertext { +impl GLWEInfos for GGSW { fn rank(&self) -> Rank { Rank(self.data.cols_out() as u32 - 1) } } -impl GGSWInfos for GGSWCiphertext { +impl GGSWInfos for GGSW { fn dsize(&self) -> Dsize { self.dsize } @@ -105,133 +107,17 @@ impl GGSWInfos for GGSWCiphertext { } } -pub struct GGSWCiphertextBuilder { - data: Option>, - base2k: Option, - k: Option, - dsize: Option, -} - -impl GGSWCiphertext { - #[inline] - pub fn builder() -> GGSWCiphertextBuilder { - GGSWCiphertextBuilder { - data: None, - base2k: None, - k: None, - dsize: None, - } - } -} - -impl GGSWCiphertextBuilder> { - #[inline] - pub fn layout(mut self, infos: &A) -> Self - where - A: GGSWInfos, - { - debug_assert!( - infos.size() as u32 > infos.dsize().0, - "invalid ggsw: ceil(k/base2k): {} <= dsize: {}", - infos.size(), - infos.dsize() - ); - - assert!( - infos.dnum().0 * infos.dsize().0 <= infos.size() as u32, - "invalid ggsw: dnum: {} * dsize:{} > ceil(k/base2k): {}", - infos.dnum(), - infos.dsize(), - infos.size(), - ); - - self.data = Some(MatZnx::alloc( - infos.n().into(), - infos.dnum().into(), - (infos.rank() + 1).into(), - (infos.rank() + 1).into(), - infos.size(), - )); - self.base2k = Some(infos.base2k()); - self.k = Some(infos.k()); - self.dsize = Some(infos.dsize()); - self - } -} - -impl GGSWCiphertextBuilder { - #[inline] - pub fn data(mut self, data: MatZnx) -> Self { - self.data = Some(data); - self - } - #[inline] - pub fn base2k(mut self, base2k: Base2K) -> Self { - self.base2k = Some(base2k); - self - } - #[inline] - pub fn k(mut self, k: TorusPrecision) -> Self { - self.k = Some(k); - self - } - - #[inline] - pub fn dsize(mut self, dsize: Dsize) -> Self { - self.dsize = Some(dsize); - self - } - - pub fn build(self) -> Result, BuildError> { - let data: MatZnx = self.data.ok_or(BuildError::MissingData)?; - let base2k: Base2K = self.base2k.ok_or(BuildError::MissingBase2K)?; - let k: TorusPrecision = self.k.ok_or(BuildError::MissingK)?; - let dsize: Dsize = self.dsize.ok_or(BuildError::MissingDigits)?; - - if base2k == 0_u32 { - return Err(BuildError::ZeroBase2K); - } - - if dsize == 0_u32 { - return Err(BuildError::ZeroBase2K); - } - - if k == 0_u32 { - return Err(BuildError::ZeroTorusPrecision); - } - - if data.n() == 0 { - return Err(BuildError::ZeroDegree); - } - - if data.cols() == 0 { - return Err(BuildError::ZeroCols); - } - - if data.size() == 0 { - return Err(BuildError::ZeroLimbs); - } - - Ok(GGSWCiphertext { - data, - base2k, - k, - dsize, - }) - } -} - -impl fmt::Debug for GGSWCiphertext { +impl fmt::Debug for GGSW { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", self.data) } } -impl fmt::Display for GGSWCiphertext { +impl fmt::Display for GGSW { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "(GGSWCiphertext: k: {} base2k: {} dsize: {}) {}", + "(GGSW: k: {} base2k: {} dsize: {}) {}", self.k().0, self.base2k().0, self.dsize().0, @@ -240,50 +126,39 @@ impl fmt::Display for GGSWCiphertext { } } -impl FillUniform for GGSWCiphertext { +impl FillUniform for GGSW { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.data.fill_uniform(log_bound, source); } } -impl GGSWCiphertext { - pub fn at(&self, row: usize, col: usize) -> GLWECiphertext<&[u8]> { - GLWECiphertext::builder() - .data(self.data.at(row, col)) - .base2k(self.base2k()) - .k(self.k()) - .build() - .unwrap() +impl GGSW { + pub fn at(&self, row: usize, col: usize) -> GLWE<&[u8]> { + GLWE { + k: self.k, + base2k: self.base2k, + data: self.data.at(row, col), + } } } -impl GGSWCiphertext { - pub fn at_mut(&mut self, row: usize, col: usize) -> GLWECiphertext<&mut [u8]> { - GLWECiphertext::builder() - .base2k(self.base2k()) - .k(self.k()) - .data(self.data.at_mut(row, col)) - .build() - .unwrap() +impl GGSW { + pub fn at_mut(&mut self, row: usize, col: usize) -> GLWE<&mut [u8]> { + GLWE { + k: self.k, + base2k: self.base2k, + data: self.data.at_mut(row, col), + } } } -impl GGSWCiphertext> { - pub fn alloc(infos: &A) -> Self - where - A: GGSWInfos, - { - Self::alloc_with( - infos.n(), - infos.base2k(), - infos.k(), - infos.rank(), - infos.dnum(), - infos.dsize(), - ) - } +impl GGSWAlloc for Module where Self: GetRingDegree {} - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self { +pub trait GGSWAlloc +where + Self: GetRingDegree, +{ + fn alloc_ggsw(&self, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> GGSW> { let size: usize = k.0.div_ceil(base2k.0) as usize; debug_assert!( size as u32 > dsize.0, @@ -298,9 +173,9 @@ impl GGSWCiphertext> { dsize.0, ); - Self { + GGSW { data: MatZnx::alloc( - n.into(), + self.ring_degree().into(), dnum.into(), (rank + 1).into(), (rank + 1).into(), @@ -312,12 +187,11 @@ impl GGSWCiphertext> { } } - pub fn alloc_bytes(infos: &A) -> usize + fn alloc_ggsw_from_infos(&self, infos: &A) -> GGSW> where A: GGSWInfos, { - Self::alloc_bytes_with( - infos.n(), + self.alloc_ggsw( infos.base2k(), infos.k(), infos.rank(), @@ -326,7 +200,7 @@ impl GGSWCiphertext> { ) } - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { + fn bytes_of_ggsw(&self, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { let size: usize = k.0.div_ceil(base2k.0) as usize; debug_assert!( size as u32 > dsize.0, @@ -341,19 +215,64 @@ impl GGSWCiphertext> { dsize.0, ); - MatZnx::alloc_bytes( - n.into(), + MatZnx::bytes_of( + self.ring_degree().into(), dnum.into(), (rank + 1).into(), (rank + 1).into(), k.0.div_ceil(base2k.0) as usize, ) } + + fn bytes_of_ggsw_from_infos(&self, infos: &A) -> usize + where + A: GGSWInfos, + { + self.bytes_of_ggsw( + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) + } +} + +impl GGSW> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GGSWInfos, + M: GGSWAlloc, + { + module.alloc_ggsw_from_infos(infos) + } + + pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self + where + M: GGSWAlloc, + { + module.alloc_ggsw(base2k, k, rank, dnum, dsize) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GGSWInfos, + M: GGSWAlloc, + { + module.bytes_of_ggsw_from_infos(infos) + } + + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize + where + M: GGSWAlloc, + { + module.bytes_of_ggsw(base2k, k, rank, dnum, dsize) + } } use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; -impl ReaderFrom for GGSWCiphertext { +impl ReaderFrom for GGSW { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.k = TorusPrecision(reader.read_u32::()?); self.base2k = Base2K(reader.read_u32::()?); @@ -362,7 +281,7 @@ impl ReaderFrom for GGSWCiphertext { } } -impl WriterTo for GGSWCiphertext { +impl WriterTo for GGSW { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { writer.write_u32::(self.k.into())?; writer.write_u32::(self.base2k.into())?; @@ -370,3 +289,33 @@ impl WriterTo for GGSWCiphertext { self.data.write_to(writer) } } + +pub trait GGSWToMut { + fn to_mut(&mut self) -> GGSW<&mut [u8]>; +} + +impl GGSWToMut for GGSW { + fn to_mut(&mut self) -> GGSW<&mut [u8]> { + GGSW { + dsize: self.dsize, + k: self.k, + base2k: self.base2k, + data: self.data.to_mut(), + } + } +} + +pub trait GGSWToRef { + fn to_ref(&self) -> GGSW<&[u8]>; +} + +impl GGSWToRef for GGSW { + fn to_ref(&self) -> GGSW<&[u8]> { + GGSW { + dsize: self.dsize, + k: self.k, + base2k: self.base2k, + data: self.data.to_ref(), + } + } +} diff --git a/poulpy-core/src/layouts/glwe_ct.rs b/poulpy-core/src/layouts/glwe_ct.rs index 23b6ef9..2361e3c 100644 --- a/poulpy-core/src/layouts/glwe_ct.rs +++ b/poulpy-core/src/layouts/glwe_ct.rs @@ -1,11 +1,12 @@ use poulpy_hal::{ layouts::{ - Data, DataMut, DataRef, FillUniform, ReaderFrom, ToOwnedDeep, VecZnx, VecZnxToMut, VecZnxToRef, WriterTo, ZnxInfos, + Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, ToOwnedDeep, VecZnx, VecZnxToMut, VecZnxToRef, + WriterTo, ZnxInfos, }, source::Source, }; -use crate::layouts::{Base2K, BuildError, Degree, LWEInfos, Rank, TorusPrecision}; +use crate::layouts::{Base2K, GetRingDegree, LWEInfos, Rank, RingDegree, TorusPrecision}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::fmt; @@ -14,8 +15,8 @@ where Self: LWEInfos, { fn rank(&self) -> Rank; - fn glwe_layout(&self) -> GLWECiphertextLayout { - GLWECiphertextLayout { + fn glwe_layout(&self) -> GLWELayout { + GLWELayout { n: self.n(), base2k: self.base2k(), k: self.k(), @@ -24,21 +25,21 @@ where } } -pub trait GLWELayoutSet { +pub trait SetGLWEInfos { fn set_k(&mut self, k: TorusPrecision); - fn set_basek(&mut self, base2k: Base2K); + fn set_base2k(&mut self, base2k: Base2K); } #[derive(PartialEq, Eq, Copy, Clone, Debug)] -pub struct GLWECiphertextLayout { - pub n: Degree, +pub struct GLWELayout { + pub n: RingDegree, pub base2k: Base2K, pub k: TorusPrecision, pub rank: Rank, } -impl LWEInfos for GLWECiphertextLayout { - fn n(&self) -> Degree { +impl LWEInfos for GLWELayout { + fn n(&self) -> RingDegree { self.n } @@ -51,21 +52,21 @@ impl LWEInfos for GLWECiphertextLayout { } } -impl GLWEInfos for GLWECiphertextLayout { +impl GLWEInfos for GLWELayout { fn rank(&self) -> Rank { self.rank } } #[derive(PartialEq, Eq, Clone)] -pub struct GLWECiphertext { +pub struct GLWE { pub(crate) data: VecZnx, pub(crate) base2k: Base2K, pub(crate) k: TorusPrecision, } -impl GLWELayoutSet for GLWECiphertext { - fn set_basek(&mut self, base2k: Base2K) { +impl SetGLWEInfos for GLWE { + fn set_base2k(&mut self, base2k: Base2K) { self.base2k = base2k } @@ -74,99 +75,19 @@ impl GLWELayoutSet for GLWECiphertext { } } -impl GLWECiphertext { +impl GLWE { pub fn data(&self) -> &VecZnx { &self.data } } -impl GLWECiphertext { +impl GLWE { pub fn data_mut(&mut self) -> &mut VecZnx { &mut self.data } } -pub struct GLWECiphertextBuilder { - data: Option>, - base2k: Option, - k: Option, -} - -impl GLWECiphertext { - #[inline] - pub fn builder() -> GLWECiphertextBuilder { - GLWECiphertextBuilder { - data: None, - base2k: None, - k: None, - } - } -} - -impl GLWECiphertextBuilder> { - #[inline] - pub fn layout(mut self, layout: &A) -> Self - where - A: GLWEInfos, - { - self.data = Some(VecZnx::alloc( - layout.n().into(), - (layout.rank() + 1).into(), - layout.size(), - )); - self.base2k = Some(layout.base2k()); - self.k = Some(layout.k()); - self - } -} - -impl GLWECiphertextBuilder { - #[inline] - pub fn data(mut self, data: VecZnx) -> Self { - self.data = Some(data); - self - } - #[inline] - pub fn base2k(mut self, base2k: Base2K) -> Self { - self.base2k = Some(base2k); - self - } - #[inline] - pub fn k(mut self, k: TorusPrecision) -> Self { - self.k = Some(k); - self - } - - pub fn build(self) -> Result, BuildError> { - let data: VecZnx = self.data.ok_or(BuildError::MissingData)?; - let base2k: Base2K = self.base2k.ok_or(BuildError::MissingBase2K)?; - let k: TorusPrecision = self.k.ok_or(BuildError::MissingK)?; - - if base2k == 0_u32 { - return Err(BuildError::ZeroBase2K); - } - - if k == 0_u32 { - return Err(BuildError::ZeroTorusPrecision); - } - - if data.n() == 0 { - return Err(BuildError::ZeroDegree); - } - - if data.cols() == 0 { - return Err(BuildError::ZeroCols); - } - - if data.size() == 0 { - return Err(BuildError::ZeroLimbs); - } - - Ok(GLWECiphertext { data, base2k, k }) - } -} - -impl LWEInfos for GLWECiphertext { +impl LWEInfos for GLWE { fn base2k(&self) -> Base2K { self.base2k } @@ -175,8 +96,8 @@ impl LWEInfos for GLWECiphertext { self.k } - fn n(&self) -> Degree { - Degree(self.data.n() as u32) + fn n(&self) -> RingDegree { + RingDegree(self.data.n() as u32) } fn size(&self) -> usize { @@ -184,16 +105,16 @@ impl LWEInfos for GLWECiphertext { } } -impl GLWEInfos for GLWECiphertext { +impl GLWEInfos for GLWE { fn rank(&self) -> Rank { Rank(self.data.cols() as u32 - 1) } } -impl ToOwnedDeep for GLWECiphertext { - type Owned = GLWECiphertext>; +impl ToOwnedDeep for GLWE { + type Owned = GLWE>; fn to_owned_deep(&self) -> Self::Owned { - GLWECiphertext { + GLWE { data: self.data.to_owned_deep(), k: self.k, base2k: self.base2k, @@ -201,17 +122,17 @@ impl ToOwnedDeep for GLWECiphertext { } } -impl fmt::Debug for GLWECiphertext { +impl fmt::Debug for GLWE { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{self}") } } -impl fmt::Display for GLWECiphertext { +impl fmt::Display for GLWE { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "GLWECiphertext: base2k={} k={}: {}", + "GLWE: base2k={} k={}: {}", self.base2k().0, self.k().0, self.data @@ -219,71 +140,86 @@ impl fmt::Display for GLWECiphertext { } } -impl FillUniform for GLWECiphertext { +impl FillUniform for GLWE { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.data.fill_uniform(log_bound, source); } } -impl GLWECiphertext> { - pub fn alloc(infos: &A) -> Self - where - A: GLWEInfos, - { - Self::alloc_with(infos.n(), infos.base2k(), infos.k(), infos.rank()) - } - - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self { - Self { - data: VecZnx::alloc(n.into(), (rank + 1).into(), k.0.div_ceil(base2k.0) as usize), +pub trait GLWEAlloc +where + Self: GetRingDegree, +{ + fn alloc_glwe(&self, base2k: Base2K, k: TorusPrecision, rank: Rank) -> GLWE> { + GLWE { + data: VecZnx::alloc( + self.ring_degree().into(), + (rank + 1).into(), + k.0.div_ceil(base2k.0) as usize, + ), base2k, k, } } - pub fn alloc_bytes(infos: &A) -> usize + fn alloc_glwe_from_infos(&self, infos: &A) -> GLWE> where A: GLWEInfos, { - Self::alloc_bytes_with(infos.n(), infos.base2k(), infos.k(), infos.rank()) + self.alloc_glwe(infos.base2k(), infos.k(), infos.rank()) } - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize { - VecZnx::alloc_bytes(n.into(), (rank + 1).into(), k.0.div_ceil(base2k.0) as usize) + fn bytes_of_glwe(&self, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize { + VecZnx::bytes_of( + self.ring_degree().into(), + (rank + 1).into(), + k.0.div_ceil(base2k.0) as usize, + ) + } + + fn bytes_of_glwe_from_infos(&self, infos: &A) -> usize + where + A: GLWEInfos, + { + self.bytes_of_glwe(infos.base2k(), infos.k(), infos.rank()) } } -pub trait GLWECiphertextToRef { - fn to_ref(&self) -> GLWECiphertext<&[u8]>; -} +impl GLWEAlloc for Module where Self: GetRingDegree {} -impl GLWECiphertextToRef for GLWECiphertext { - fn to_ref(&self) -> GLWECiphertext<&[u8]> { - GLWECiphertext::builder() - .k(self.k()) - .base2k(self.base2k()) - .data(self.data.to_ref()) - .build() - .unwrap() +impl GLWE> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GLWEInfos, + M: GLWEAlloc, + { + module.alloc_glwe_from_infos(infos) + } + + pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self + where + M: GLWEAlloc, + { + module.alloc_glwe(base2k, k, rank) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GLWEInfos, + M: GLWEAlloc, + { + module.bytes_of_glwe_from_infos(infos) + } + + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize + where + M: GLWEAlloc, + { + module.bytes_of_glwe(base2k, k, rank) } } -pub trait GLWECiphertextToMut { - fn to_mut(&mut self) -> GLWECiphertext<&mut [u8]>; -} - -impl GLWECiphertextToMut for GLWECiphertext { - fn to_mut(&mut self) -> GLWECiphertext<&mut [u8]> { - GLWECiphertext::builder() - .k(self.k()) - .base2k(self.base2k()) - .data(self.data.to_mut()) - .build() - .unwrap() - } -} - -impl ReaderFrom for GLWECiphertext { +impl ReaderFrom for GLWE { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.k = TorusPrecision(reader.read_u32::()?); self.base2k = Base2K(reader.read_u32::()?); @@ -291,10 +227,38 @@ impl ReaderFrom for GLWECiphertext { } } -impl WriterTo for GLWECiphertext { +impl WriterTo for GLWE { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { writer.write_u32::(self.k.0)?; writer.write_u32::(self.base2k.0)?; self.data.write_to(writer) } } + +pub trait GLWEToRef { + fn to_ref(&self) -> GLWE<&[u8]>; +} + +impl GLWEToRef for GLWE { + fn to_ref(&self) -> GLWE<&[u8]> { + GLWE { + k: self.k, + base2k: self.base2k, + data: self.data.to_ref(), + } + } +} + +pub trait GLWEToMut { + fn to_mut(&mut self) -> GLWE<&mut [u8]>; +} + +impl GLWEToMut for GLWE { + fn to_mut(&mut self) -> GLWE<&mut [u8]> { + GLWE { + k: self.k, + base2k: self.base2k, + data: self.data.to_mut(), + } + } +} diff --git a/poulpy-core/src/layouts/glwe_pk.rs b/poulpy-core/src/layouts/glwe_pk.rs index fc4b0fa..16fc6ea 100644 --- a/poulpy-core/src/layouts/glwe_pk.rs +++ b/poulpy-core/src/layouts/glwe_pk.rs @@ -1,8 +1,10 @@ -use poulpy_hal::layouts::{Data, DataMut, DataRef, ReaderFrom, VecZnx, WriterTo, ZnxInfos}; +use poulpy_hal::layouts::{ + Backend, Data, DataMut, DataRef, Module, ReaderFrom, VecZnx, VecZnxToMut, VecZnxToRef, WriterTo, ZnxInfos, +}; use crate::{ dist::Distribution, - layouts::{Base2K, BuildError, Degree, GLWEInfos, LWEInfos, Rank, TorusPrecision}, + layouts::{Base2K, GLWEInfos, GetRingDegree, LWEInfos, Rank, RingDegree, TorusPrecision}, }; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; @@ -16,12 +18,22 @@ pub struct GLWEPublicKey { #[derive(PartialEq, Eq, Copy, Clone, Debug)] pub struct GLWEPublicKeyLayout { - pub n: Degree, + pub n: RingDegree, pub base2k: Base2K, pub k: TorusPrecision, pub rank: Rank, } +pub trait GetDist { + fn get_dist(&self) -> Distribution; +} + +impl GetDist for GLWEPublicKey { + fn get_dist(&self) -> Distribution { + self.dist + } +} + impl LWEInfos for GLWEPublicKey { fn base2k(&self) -> Base2K { self.base2k @@ -31,8 +43,8 @@ impl LWEInfos for GLWEPublicKey { self.k } - fn n(&self) -> Degree { - Degree(self.data.n() as u32) + fn n(&self) -> RingDegree { + RingDegree(self.data.n() as u32) } fn size(&self) -> usize { @@ -55,7 +67,7 @@ impl LWEInfos for GLWEPublicKeyLayout { self.k } - fn n(&self) -> Degree { + fn n(&self) -> RingDegree { self.n } @@ -70,117 +82,77 @@ impl GLWEInfos for GLWEPublicKeyLayout { } } -pub struct GLWEPublicKeyBuilder { - data: Option>, - base2k: Option, - k: Option, -} - -impl GLWEPublicKey { - #[inline] - pub fn builder() -> GLWEPublicKeyBuilder { - GLWEPublicKeyBuilder { - data: None, - base2k: None, - k: None, - } - } -} - -impl GLWEPublicKeyBuilder> { - #[inline] - pub fn layout(mut self, layout: &A) -> Self - where - A: GLWEInfos, - { - self.data = Some(VecZnx::alloc( - layout.n().into(), - (layout.rank() + 1).into(), - layout.size(), - )); - self.base2k = Some(layout.base2k()); - self.k = Some(layout.k()); - self - } -} - -impl GLWEPublicKeyBuilder { - #[inline] - pub fn data(mut self, data: VecZnx) -> Self { - self.data = Some(data); - self - } - #[inline] - pub fn base2k(mut self, base2k: Base2K) -> Self { - self.base2k = Some(base2k); - self - } - #[inline] - pub fn k(mut self, k: TorusPrecision) -> Self { - self.k = Some(k); - self - } - - pub fn build(self) -> Result, BuildError> { - let data: VecZnx = self.data.ok_or(BuildError::MissingData)?; - let base2k: Base2K = self.base2k.ok_or(BuildError::MissingBase2K)?; - let k: TorusPrecision = self.k.ok_or(BuildError::MissingK)?; - - if base2k == 0_u32 { - return Err(BuildError::ZeroBase2K); - } - - if k == 0_u32 { - return Err(BuildError::ZeroTorusPrecision); - } - - if data.n() == 0 { - return Err(BuildError::ZeroDegree); - } - - if data.cols() == 0 { - return Err(BuildError::ZeroCols); - } - - if data.size() == 0 { - return Err(BuildError::ZeroLimbs); - } - - Ok(GLWEPublicKey { - data, +pub trait GLWEPublicKeyAlloc +where + Self: GetRingDegree, +{ + fn alloc_glwe_public_key(&self, base2k: Base2K, k: TorusPrecision, rank: Rank) -> GLWEPublicKey> { + GLWEPublicKey { + data: VecZnx::alloc( + self.ring_degree().into(), + (rank + 1).into(), + k.0.div_ceil(base2k.0) as usize, + ), base2k, k, dist: Distribution::NONE, - }) + } + } + + fn alloc_glwe_public_key_from_infos(&self, infos: &A) -> GLWEPublicKey> + where + A: GLWEInfos, + { + self.alloc_glwe_public_key(infos.base2k(), infos.k(), infos.rank()) + } + + fn bytes_of_glwe_public_key(&self, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize { + VecZnx::bytes_of( + self.ring_degree().into(), + (rank + 1).into(), + k.0.div_ceil(base2k.0) as usize, + ) + } + + fn bytes_of_glwe_public_key_from_infos(&self, infos: &A) -> usize + where + A: GLWEInfos, + { + self.bytes_of_glwe_public_key(infos.base2k(), infos.k(), infos.rank()) } } +impl GLWEPublicKeyAlloc for Module where Self: GetRingDegree {} + impl GLWEPublicKey> { - pub fn alloc(infos: &A) -> Self + pub fn alloc_from_infos(module: &M, infos: &A) -> Self where A: GLWEInfos, + M: GLWEPublicKeyAlloc, { - Self::alloc_with(infos.n(), infos.base2k(), infos.k(), infos.rank()) + module.alloc_glwe_public_key_from_infos(infos) } - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self { - Self { - data: VecZnx::alloc(n.into(), (rank + 1).into(), k.0.div_ceil(base2k.0) as usize), - base2k, - k, - dist: Distribution::NONE, - } + pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self + where + M: GLWEPublicKeyAlloc, + { + module.alloc_glwe_public_key(base2k, k, rank) } - pub fn alloc_bytes(infos: &A) -> usize + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize where A: GLWEInfos, + M: GLWEPublicKeyAlloc, { - Self::alloc_bytes_with(infos.n(), infos.base2k(), infos.k(), infos.rank()) + module.bytes_of_glwe_public_key_from_infos(infos) } - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize { - VecZnx::alloc_bytes(n.into(), (rank + 1).into(), k.0.div_ceil(base2k.0) as usize) + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize + where + M: GLWEPublicKeyAlloc, + { + module.bytes_of_glwe_public_key(base2k, k, rank) } } @@ -207,3 +179,33 @@ impl WriterTo for GLWEPublicKey { self.data.write_to(writer) } } + +pub trait GLWEPublicKeyToRef { + fn to_ref(&self) -> GLWEPublicKey<&[u8]>; +} + +impl GLWEPublicKeyToRef for GLWEPublicKey { + fn to_ref(&self) -> GLWEPublicKey<&[u8]> { + GLWEPublicKey { + data: self.data.to_ref(), + base2k: self.base2k, + k: self.k, + dist: self.dist, + } + } +} + +pub trait GLWEPublicKeyToMut { + fn to_mut(&mut self) -> GLWEPublicKey<&mut [u8]>; +} + +impl GLWEPublicKeyToMut for GLWEPublicKey { + fn to_mut(&mut self) -> GLWEPublicKey<&mut [u8]> { + GLWEPublicKey { + base2k: self.base2k, + k: self.k, + dist: self.dist, + data: self.data.to_mut(), + } + } +} diff --git a/poulpy-core/src/layouts/glwe_pt.rs b/poulpy-core/src/layouts/glwe_pt.rs index b565055..3d5ce86 100644 --- a/poulpy-core/src/layouts/glwe_pt.rs +++ b/poulpy-core/src/layouts/glwe_pt.rs @@ -1,15 +1,14 @@ use std::fmt; -use poulpy_hal::layouts::{Data, DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos}; +use poulpy_hal::layouts::{Backend, Data, DataMut, DataRef, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos}; use crate::layouts::{ - Base2K, BuildError, Degree, GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEInfos, GLWELayoutSet, LWEInfos, - Rank, TorusPrecision, + Base2K, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, GetRingDegree, LWEInfos, Rank, RingDegree, SetGLWEInfos, TorusPrecision, }; #[derive(PartialEq, Eq, Copy, Clone, Debug)] pub struct GLWEPlaintextLayout { - pub n: Degree, + pub n: RingDegree, pub base2k: Base2K, pub k: TorusPrecision, } @@ -23,7 +22,7 @@ impl LWEInfos for GLWEPlaintextLayout { self.k } - fn n(&self) -> Degree { + fn n(&self) -> RingDegree { self.n } } @@ -40,8 +39,8 @@ pub struct GLWEPlaintext { pub k: TorusPrecision, } -impl GLWELayoutSet for GLWEPlaintext { - fn set_basek(&mut self, base2k: Base2K) { +impl SetGLWEInfos for GLWEPlaintext { + fn set_base2k(&mut self, base2k: Base2K) { self.base2k = base2k } @@ -63,8 +62,8 @@ impl LWEInfos for GLWEPlaintext { self.data.size() } - fn n(&self) -> Degree { - Degree(self.data.n() as u32) + fn n(&self) -> RingDegree { + RingDegree(self.data.n() as u32) } } @@ -74,69 +73,6 @@ impl GLWEInfos for GLWEPlaintext { } } -pub struct GLWEPlaintextBuilder { - data: Option>, - base2k: Option, - k: Option, -} - -impl GLWEPlaintext { - #[inline] - pub fn builder() -> GLWEPlaintextBuilder { - GLWEPlaintextBuilder { - data: None, - base2k: None, - k: None, - } - } -} - -impl GLWEPlaintextBuilder { - #[inline] - pub fn data(mut self, data: VecZnx) -> Self { - self.data = Some(data); - self - } - #[inline] - pub fn base2k(mut self, base2k: Base2K) -> Self { - self.base2k = Some(base2k); - self - } - #[inline] - pub fn k(mut self, k: TorusPrecision) -> Self { - self.k = Some(k); - self - } - - pub fn build(self) -> Result, BuildError> { - let data: VecZnx = self.data.ok_or(BuildError::MissingData)?; - let base2k: Base2K = self.base2k.ok_or(BuildError::MissingBase2K)?; - let k: TorusPrecision = self.k.ok_or(BuildError::MissingK)?; - - if base2k.0 == 0 { - return Err(BuildError::ZeroBase2K); - } - - if k.0 == 0 { - return Err(BuildError::ZeroTorusPrecision); - } - - if data.n() == 0 { - return Err(BuildError::ZeroDegree); - } - - if data.cols() != 1 { - return Err(BuildError::ZeroCols); - } - - if data.size() == 0 { - return Err(BuildError::ZeroLimbs); - } - - Ok(GLWEPlaintext { data, base2k, k }) - } -} - impl fmt::Display for GLWEPlaintext { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( @@ -149,54 +85,123 @@ impl fmt::Display for GLWEPlaintext { } } -impl GLWEPlaintext> { - pub fn alloc(infos: &A) -> Self - where - A: GLWEInfos, - { - Self::alloc_with(infos.n(), infos.base2k(), infos.k(), Rank(0)) - } - - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self { - debug_assert!(rank.0 == 0); - Self { - data: VecZnx::alloc(n.into(), (rank + 1).into(), k.0.div_ceil(base2k.0) as usize), +pub trait GLWEPlaintextAlloc +where + Self: GetRingDegree, +{ + fn alloc_glwe_plaintext(&self, base2k: Base2K, k: TorusPrecision) -> GLWEPlaintext> { + GLWEPlaintext { + data: VecZnx::alloc( + self.ring_degree().into(), + 1, + k.0.div_ceil(base2k.0) as usize, + ), base2k, k, } } - pub fn alloc_bytes(infos: &A) -> usize + fn alloc_glwe_plaintext_from_infos(&self, infos: &A) -> GLWEPlaintext> where A: GLWEInfos, { - Self::alloc_bytes_with(infos.n(), infos.base2k(), infos.k(), Rank(0)) + self.alloc_glwe_plaintext(infos.base2k(), infos.k()) } - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize { - debug_assert!(rank.0 == 0); - VecZnx::alloc_bytes(n.into(), (rank + 1).into(), k.0.div_ceil(base2k.0) as usize) + fn bytes_of_glwe_plaintext(&self, base2k: Base2K, k: TorusPrecision) -> usize { + VecZnx::bytes_of( + self.ring_degree().into(), + 1, + k.0.div_ceil(base2k.0) as usize, + ) + } + + fn bytes_of_glwe_plaintext_from_infos(&self, infos: &A) -> usize + where + A: GLWEInfos, + { + self.bytes_of_glwe_plaintext(infos.base2k(), infos.k()) } } -impl GLWECiphertextToRef for GLWEPlaintext { - fn to_ref(&self) -> GLWECiphertext<&[u8]> { - GLWECiphertext::builder() - .data(self.data.to_ref()) - .k(self.k()) - .base2k(self.base2k()) - .build() - .unwrap() +impl GLWEPlaintextAlloc for Module where Self: GetRingDegree {} + +impl GLWEPlaintext> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GLWEInfos, + M: GLWEPlaintextAlloc, + { + module.alloc_glwe_plaintext_from_infos(infos) + } + + pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision) -> Self + where + M: GLWEPlaintextAlloc, + { + module.alloc_glwe_plaintext(base2k, k) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GLWEInfos, + M: GLWEPlaintextAlloc, + { + module.bytes_of_glwe_plaintext_from_infos(infos) + } + + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision) -> usize + where + M: GLWEPlaintextAlloc, + { + module.bytes_of_glwe_plaintext(base2k, k) } } -impl GLWECiphertextToMut for GLWEPlaintext { - fn to_mut(&mut self) -> GLWECiphertext<&mut [u8]> { - GLWECiphertext::builder() - .k(self.k()) - .base2k(self.base2k()) - .data(self.data.to_mut()) - .build() - .unwrap() +impl GLWEToRef for GLWEPlaintext { + fn to_ref(&self) -> GLWE<&[u8]> { + GLWE { + k: self.k, + base2k: self.base2k, + data: self.data.to_ref(), + } + } +} + +impl GLWEToMut for GLWEPlaintext { + fn to_mut(&mut self) -> GLWE<&mut [u8]> { + GLWE { + k: self.k, + base2k: self.base2k, + data: self.data.to_mut(), + } + } +} + +pub trait GLWEPlaintextToRef { + fn to_ref(&self) -> GLWEPlaintext<&[u8]>; +} + +impl GLWEPlaintextToRef for GLWEPlaintext { + fn to_ref(&self) -> GLWEPlaintext<&[u8]> { + GLWEPlaintext { + data: self.data.to_ref(), + base2k: self.base2k, + k: self.k, + } + } +} + +pub trait GLWEPlaintextToMut { + fn to_ref(&mut self) -> GLWEPlaintext<&mut [u8]>; +} + +impl GLWEPlaintextToMut for GLWEPlaintext { + fn to_ref(&mut self) -> GLWEPlaintext<&mut [u8]> { + GLWEPlaintext { + base2k: self.base2k, + k: self.k, + data: self.data.to_mut(), + } } } diff --git a/poulpy-core/src/layouts/glwe_sk.rs b/poulpy-core/src/layouts/glwe_sk.rs index 8870d35..9166388 100644 --- a/poulpy-core/src/layouts/glwe_sk.rs +++ b/poulpy-core/src/layouts/glwe_sk.rs @@ -1,16 +1,19 @@ use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, ReaderFrom, ScalarZnx, WriterTo, ZnxInfos, ZnxZero}, + layouts::{ + Backend, Data, DataMut, DataRef, Module, ReaderFrom, ScalarZnx, ScalarZnxToMut, ScalarZnxToRef, WriterTo, ZnxInfos, + ZnxZero, + }, source::Source, }; use crate::{ dist::Distribution, - layouts::{Base2K, Degree, GLWEInfos, LWEInfos, Rank, TorusPrecision}, + layouts::{Base2K, GLWEInfos, GetRingDegree, LWEInfos, Rank, RingDegree, TorusPrecision}, }; #[derive(PartialEq, Eq, Copy, Clone, Debug)] pub struct GLWESecretLayout { - pub n: Degree, + pub n: RingDegree, pub rank: Rank, } @@ -23,7 +26,7 @@ impl LWEInfos for GLWESecretLayout { TorusPrecision(0) } - fn n(&self) -> Degree { + fn n(&self) -> RingDegree { self.n } @@ -52,8 +55,8 @@ impl LWEInfos for GLWESecret { TorusPrecision(0) } - fn n(&self) -> Degree { - Degree(self.data.n() as u32) + fn n(&self) -> RingDegree { + RingDegree(self.data.n() as u32) } fn size(&self) -> usize { @@ -67,30 +70,67 @@ impl GLWEInfos for GLWESecret { } } -impl GLWESecret> { - pub fn alloc(infos: &A) -> Self - where - A: GLWEInfos, - { - Self::alloc_with(infos.n(), infos.rank()) - } - - pub fn alloc_with(n: Degree, rank: Rank) -> Self { - Self { - data: ScalarZnx::alloc(n.into(), rank.into()), +pub trait GLWESecretAlloc +where + Self: GetRingDegree, +{ + fn alloc_glwe_secret(&self, rank: Rank) -> GLWESecret> { + GLWESecret { + data: ScalarZnx::alloc(self.ring_degree().into(), rank.into()), dist: Distribution::NONE, } } - pub fn alloc_bytes(infos: &A) -> usize + fn alloc_glwe_secret_from_infos(&self, infos: &A) -> GLWESecret> where A: GLWEInfos, { - Self::alloc_bytes_with(infos.n(), infos.rank()) + self.alloc_glwe_secret(infos.rank()) } - pub fn alloc_bytes_with(n: Degree, rank: Rank) -> usize { - ScalarZnx::alloc_bytes(n.into(), rank.into()) + fn bytes_of_glwe_secret(&self, rank: Rank) -> usize { + ScalarZnx::bytes_of(self.ring_degree().into(), rank.into()) + } + + fn bytes_of_glwe_secret_from_infos(&self, infos: &A) -> usize + where + A: GLWEInfos, + { + self.bytes_of_glwe_secret(infos.rank()) + } +} + +impl GLWESecretAlloc for Module where Self: GetRingDegree {} + +impl GLWESecret> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GLWEInfos, + M: GLWESecretAlloc, + { + module.alloc_glwe_secret_from_infos(infos) + } + + pub fn alloc(module: &M, rank: Rank) -> Self + where + M: GLWESecretAlloc, + { + module.alloc_glwe_secret(rank) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GLWEInfos, + M: GLWESecretAlloc, + { + module.bytes_of_glwe_secret_from_infos(infos) + } + + pub fn bytes_of(module: &M, rank: Rank) -> usize + where + M: GLWESecretAlloc, + { + module.bytes_of_glwe_secret(rank) } } @@ -136,6 +176,32 @@ impl GLWESecret { } } +pub trait GLWESecretToMut { + fn to_mut(&mut self) -> GLWESecret<&mut [u8]>; +} + +impl GLWESecretToMut for GLWESecret { + fn to_mut(&mut self) -> GLWESecret<&mut [u8]> { + GLWESecret { + dist: self.dist, + data: self.data.to_mut(), + } + } +} + +pub trait GLWESecretToRef { + fn to_ref(&self) -> GLWESecret<&[u8]>; +} + +impl GLWESecretToRef for GLWESecret { + fn to_ref(&self) -> GLWESecret<&[u8]> { + GLWESecret { + data: self.data.to_ref(), + dist: self.dist, + } + } +} + impl ReaderFrom for GLWESecret { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { match Distribution::read_from(reader) { diff --git a/poulpy-core/src/layouts/glwe_to_lwe_ksk.rs b/poulpy-core/src/layouts/glwe_to_lwe_ksk.rs index f227c9c..671f018 100644 --- a/poulpy-core/src/layouts/glwe_to_lwe_ksk.rs +++ b/poulpy-core/src/layouts/glwe_to_lwe_ksk.rs @@ -1,15 +1,18 @@ use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo}, + layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo}, source::Source, }; -use crate::layouts::{Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWESwitchingKey, GLWEInfos, LWEInfos, Rank, TorusPrecision}; +use crate::layouts::{ + Base2K, Dnum, Dsize, GGLWEInfos, GLWEInfos, GLWESwitchingKey, GLWESwitchingKeyAlloc, GLWESwitchingKeyToMut, + GLWESwitchingKeyToRef, LWEInfos, Rank, RingDegree, TorusPrecision, +}; use std::fmt; #[derive(PartialEq, Eq, Copy, Clone, Debug)] pub struct GLWEToLWEKeyLayout { - pub n: Degree, + pub n: RingDegree, pub base2k: Base2K, pub k: TorusPrecision, pub rank_in: Rank, @@ -17,7 +20,7 @@ pub struct GLWEToLWEKeyLayout { } impl LWEInfos for GLWEToLWEKeyLayout { - fn n(&self) -> Degree { + fn n(&self) -> RingDegree { self.n } @@ -54,11 +57,11 @@ impl GGLWEInfos for GLWEToLWEKeyLayout { } } -/// A special [GLWESwitchingKey] required to for the conversion from [GLWECiphertext] to [LWECiphertext]. +/// A special [GLWESwitchingKey] required to for the conversion from [GLWE] to [LWE]. #[derive(PartialEq, Eq, Clone)] -pub struct GLWEToLWEKey(pub(crate) GGLWESwitchingKey); +pub struct GLWEToLWESwitchingKey(pub(crate) GLWESwitchingKey); -impl LWEInfos for GLWEToLWEKey { +impl LWEInfos for GLWEToLWESwitchingKey { fn base2k(&self) -> Base2K { self.0.base2k() } @@ -67,7 +70,7 @@ impl LWEInfos for GLWEToLWEKey { self.0.k() } - fn n(&self) -> Degree { + fn n(&self) -> RingDegree { self.0.n() } @@ -76,12 +79,12 @@ impl LWEInfos for GLWEToLWEKey { } } -impl GLWEInfos for GLWEToLWEKey { +impl GLWEInfos for GLWEToLWESwitchingKey { fn rank(&self) -> Rank { self.rank_out() } } -impl GGLWEInfos for GLWEToLWEKey { +impl GGLWEInfos for GLWEToLWESwitchingKey { fn rank_in(&self) -> Rank { self.0.rank_in() } @@ -99,84 +102,145 @@ impl GGLWEInfos for GLWEToLWEKey { } } -impl fmt::Debug for GLWEToLWEKey { +impl fmt::Debug for GLWEToLWESwitchingKey { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{self}") } } -impl FillUniform for GLWEToLWEKey { +impl FillUniform for GLWEToLWESwitchingKey { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.0.fill_uniform(log_bound, source); } } -impl fmt::Display for GLWEToLWEKey { +impl fmt::Display for GLWEToLWESwitchingKey { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "(GLWEToLWESwitchingKey) {}", self.0) } } -impl ReaderFrom for GLWEToLWEKey { +impl ReaderFrom for GLWEToLWESwitchingKey { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.0.read_from(reader) } } -impl WriterTo for GLWEToLWEKey { +impl WriterTo for GLWEToLWESwitchingKey { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { self.0.write_to(writer) } } -impl GLWEToLWEKey> { - pub fn alloc(infos: &A) -> Self +pub trait GLWEToLWESwitchingKeyAlloc +where + Self: GLWESwitchingKeyAlloc, +{ + fn alloc_glwe_to_lwe_switching_key( + &self, + base2k: Base2K, + k: TorusPrecision, + rank_in: Rank, + dnum: Dnum, + ) -> GLWEToLWESwitchingKey> { + GLWEToLWESwitchingKey(self.alloc_glwe_switching_key(base2k, k, rank_in, Rank(1), dnum, Dsize(1))) + } + + fn alloc_glwe_to_lwe_switching_key_from_infos(&self, infos: &A) -> GLWEToLWESwitchingKey> where A: GGLWEInfos, { - debug_assert_eq!( + assert_eq!( infos.rank_out().0, 1, "rank_out > 1 is not supported for GLWEToLWESwitchingKey" ); - debug_assert_eq!( + assert_eq!( infos.dsize().0, 1, "dsize > 1 is not supported for GLWEToLWESwitchingKey" ); - Self(GGLWESwitchingKey::alloc(infos)) + self.alloc_glwe_to_lwe_switching_key(infos.base2k(), infos.k(), infos.rank_in(), infos.dnum()) } - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> Self { - Self(GGLWESwitchingKey::alloc_with( - n, - base2k, - k, - rank_in, - Rank(1), - dnum, - Dsize(1), - )) + fn bytes_of_glwe_to_lwe_switching_key(&self, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> usize { + self.bytes_of_glwe_switching_key(base2k, k, rank_in, Rank(1), dnum, Dsize(1)) } - pub fn alloc_bytes(infos: &A) -> usize + fn bytes_of_glwe_to_lwe_switching_key_from_infos(&self, infos: &A) -> usize where A: GGLWEInfos, { - debug_assert_eq!( + assert_eq!( infos.rank_out().0, 1, "rank_out > 1 is not supported for GLWEToLWESwitchingKey" ); - debug_assert_eq!( + assert_eq!( infos.dsize().0, 1, "dsize > 1 is not supported for GLWEToLWESwitchingKey" ); - GGLWESwitchingKey::alloc_bytes(infos) - } - - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> usize { - GGLWESwitchingKey::alloc_bytes_with(n, base2k, k, rank_in, Rank(1), dnum, Dsize(1)) + self.bytes_of_glwe_to_lwe_switching_key(infos.base2k(), infos.k(), infos.rank_in(), infos.dnum()) + } +} + +impl GLWEToLWESwitchingKeyAlloc for Module where Self: GLWESwitchingKeyAlloc {} + +impl GLWEToLWESwitchingKey> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GGLWEInfos, + M: GLWEToLWESwitchingKeyAlloc, + { + module.alloc_glwe_to_lwe_switching_key_from_infos(infos) + } + + pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> Self + where + M: GLWEToLWESwitchingKeyAlloc, + { + module.alloc_glwe_to_lwe_switching_key(base2k, k, rank_in, dnum) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: GLWEToLWESwitchingKeyAlloc, + { + module.bytes_of_glwe_to_lwe_switching_key_from_infos(infos) + } + + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> usize + where + M: GLWEToLWESwitchingKeyAlloc, + { + module.bytes_of_glwe_to_lwe_switching_key(base2k, k, rank_in, dnum) + } +} + +pub trait GLWEToLWESwitchingKeyToRef { + fn to_ref(&self) -> GLWEToLWESwitchingKey<&[u8]>; +} + +impl GLWEToLWESwitchingKeyToRef for GLWEToLWESwitchingKey +where + GLWESwitchingKey: GLWESwitchingKeyToRef, +{ + fn to_ref(&self) -> GLWEToLWESwitchingKey<&[u8]> { + GLWEToLWESwitchingKey(self.0.to_ref()) + } +} + +pub trait GLWEToLWESwitchingKeyToMut { + fn to_mut(&mut self) -> GLWEToLWESwitchingKey<&mut [u8]>; +} + +impl GLWEToLWESwitchingKeyToMut for GLWEToLWESwitchingKey +where + GLWESwitchingKey: GLWESwitchingKeyToMut, +{ + fn to_mut(&mut self) -> GLWEToLWESwitchingKey<&mut [u8]> { + GLWEToLWESwitchingKey(self.0.to_mut()) } } diff --git a/poulpy-core/src/layouts/lwe_ct.rs b/poulpy-core/src/layouts/lwe_ct.rs index 1560ea4..0c8831f 100644 --- a/poulpy-core/src/layouts/lwe_ct.rs +++ b/poulpy-core/src/layouts/lwe_ct.rs @@ -1,15 +1,15 @@ use std::fmt; use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo, Zn, ZnToMut, ZnToRef, ZnxInfos}, + layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo, Zn, ZnToMut, ZnToRef, ZnxInfos}, source::Source, }; -use crate::layouts::{Base2K, BuildError, Degree, TorusPrecision}; +use crate::layouts::{Base2K, RingDegree, TorusPrecision}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; pub trait LWEInfos { - fn n(&self) -> Degree; + fn n(&self) -> RingDegree; fn k(&self) -> TorusPrecision; fn max_k(&self) -> TorusPrecision { TorusPrecision(self.k().0 * self.size() as u32) @@ -18,8 +18,8 @@ pub trait LWEInfos { fn size(&self) -> usize { self.k().0.div_ceil(self.base2k().0) as usize } - fn lwe_layout(&self) -> LWECiphertextLayout { - LWECiphertextLayout { + fn lwe_layout(&self) -> LWELayout { + LWELayout { n: self.n(), k: self.k(), base2k: self.base2k(), @@ -27,14 +27,19 @@ pub trait LWEInfos { } } +pub trait SetLWEInfos { + fn set_k(&mut self, k: TorusPrecision); + fn set_base2k(&mut self, base2k: Base2K); +} + #[derive(PartialEq, Eq, Copy, Clone, Debug)] -pub struct LWECiphertextLayout { - pub n: Degree, +pub struct LWELayout { + pub n: RingDegree, pub k: TorusPrecision, pub base2k: Base2K, } -impl LWEInfos for LWECiphertextLayout { +impl LWEInfos for LWELayout { fn base2k(&self) -> Base2K { self.base2k } @@ -43,19 +48,18 @@ impl LWEInfos for LWECiphertextLayout { self.k } - fn n(&self) -> Degree { + fn n(&self) -> RingDegree { self.n } } - #[derive(PartialEq, Eq, Clone)] -pub struct LWECiphertext { +pub struct LWE { pub(crate) data: Zn, pub(crate) k: TorusPrecision, pub(crate) base2k: Base2K, } -impl LWEInfos for LWECiphertext { +impl LWEInfos for LWE { fn base2k(&self) -> Base2K { self.base2k } @@ -63,8 +67,8 @@ impl LWEInfos for LWECiphertext { fn k(&self) -> TorusPrecision { self.k } - fn n(&self) -> Degree { - Degree(self.data.n() as u32 - 1) + fn n(&self) -> RingDegree { + RingDegree(self.data.n() as u32 - 1) } fn size(&self) -> usize { @@ -72,29 +76,39 @@ impl LWEInfos for LWECiphertext { } } -impl LWECiphertext { +impl SetLWEInfos for LWE { + fn set_base2k(&mut self, base2k: Base2K) { + self.base2k = base2k + } + + fn set_k(&mut self, k: TorusPrecision) { + self.k = k + } +} + +impl LWE { pub fn data(&self) -> &Zn { &self.data } } -impl LWECiphertext { +impl LWE { pub fn data_mut(&mut self) -> &Zn { &mut self.data } } -impl fmt::Debug for LWECiphertext { +impl fmt::Debug for LWE { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{self}") } } -impl fmt::Display for LWECiphertext { +impl fmt::Display for LWE { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "LWECiphertext: base2k={} k={}: {}", + "LWE: base2k={} k={}: {}", self.base2k().0, self.k().0, self.data @@ -102,7 +116,7 @@ impl fmt::Display for LWECiphertext { } } -impl FillUniform for LWECiphertext +impl FillUniform for LWE where Zn: FillUniform, { @@ -111,142 +125,98 @@ where } } -impl LWECiphertext> { - pub fn alloc(infos: &A) -> Self - where - A: LWEInfos, - { - Self::alloc_with(infos.n(), infos.base2k(), infos.k()) - } - - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision) -> Self { - Self { +pub trait LWEAlloc { + fn alloc_lwe(&self, n: RingDegree, base2k: Base2K, k: TorusPrecision) -> LWE> { + LWE { data: Zn::alloc((n + 1).into(), 1, k.0.div_ceil(base2k.0) as usize), k, base2k, } } - pub fn alloc_bytes(infos: &A) -> usize + fn alloc_lwe_from_infos(&self, infos: &A) -> LWE> where A: LWEInfos, { - Self::alloc_bytes_with(infos.n(), infos.base2k(), infos.k()) + self.alloc_lwe(infos.n(), infos.base2k(), infos.k()) } - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision) -> usize { - Zn::alloc_bytes((n + 1).into(), 1, k.0.div_ceil(base2k.0) as usize) + fn bytes_of_lwe(&self, n: RingDegree, base2k: Base2K, k: TorusPrecision) -> usize { + Zn::bytes_of((n + 1).into(), 1, k.0.div_ceil(base2k.0) as usize) } -} -impl LWECiphertextBuilder> { - #[inline] - pub fn layout(mut self, layout: A) -> Self + fn bytes_of_lwe_from_infos(&self, infos: &A) -> usize where A: LWEInfos, { - self.data = Some(Zn::alloc((layout.n() + 1).into(), 1, layout.size())); - self.base2k = Some(layout.base2k()); - self.k = Some(layout.k()); - self + self.bytes_of_lwe(infos.n(), infos.base2k(), infos.k()) } } -pub struct LWECiphertextBuilder { - data: Option>, - base2k: Option, - k: Option, +impl LWEAlloc for Module {} + +impl LWE> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: LWEInfos, + M: LWEAlloc, + { + module.alloc_lwe_from_infos(infos) + } + + pub fn alloc(module: &M, n: RingDegree, base2k: Base2K, k: TorusPrecision) -> Self + where + M: LWEAlloc, + { + module.alloc_lwe(n, base2k, k) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: LWEInfos, + M: LWEAlloc, + { + module.bytes_of_lwe_from_infos(infos) + } + + pub fn bytes_of(module: &M, n: RingDegree, base2k: Base2K, k: TorusPrecision) -> usize + where + M: LWEAlloc, + { + module.bytes_of_lwe(n, base2k, k) + } } -impl LWECiphertext { - #[inline] - pub fn builder() -> LWECiphertextBuilder { - LWECiphertextBuilder { - data: None, - base2k: None, - k: None, +pub trait LWEToRef { + fn to_ref(&self) -> LWE<&[u8]>; +} + +impl LWEToRef for LWE { + fn to_ref(&self) -> LWE<&[u8]> { + LWE { + k: self.k, + base2k: self.base2k, + data: self.data.to_ref(), } } } -impl LWECiphertextBuilder { - #[inline] - pub fn data(mut self, data: Zn) -> Self { - self.data = Some(data); - self - } - #[inline] - pub fn base2k(mut self, base2k: Base2K) -> Self { - self.base2k = Some(base2k); - self - } - #[inline] - pub fn k(mut self, k: TorusPrecision) -> Self { - self.k = Some(k); - self - } - - pub fn build(self) -> Result, BuildError> { - let data: Zn = self.data.ok_or(BuildError::MissingData)?; - let base2k: Base2K = self.base2k.ok_or(BuildError::MissingBase2K)?; - let k: TorusPrecision = self.k.ok_or(BuildError::MissingK)?; - - if base2k.0 == 0 { - return Err(BuildError::ZeroBase2K); - } - - if k.0 == 0 { - return Err(BuildError::ZeroTorusPrecision); - } - - if data.n() == 0 { - return Err(BuildError::ZeroDegree); - } - - if data.cols() == 0 { - return Err(BuildError::ZeroCols); - } - - if data.size() == 0 { - return Err(BuildError::ZeroLimbs); - } - - Ok(LWECiphertext { data, base2k, k }) - } -} - -pub trait LWECiphertextToRef { - fn to_ref(&self) -> LWECiphertext<&[u8]>; -} - -impl LWECiphertextToRef for LWECiphertext { - fn to_ref(&self) -> LWECiphertext<&[u8]> { - LWECiphertext::builder() - .base2k(self.base2k()) - .k(self.k()) - .data(self.data.to_ref()) - .build() - .unwrap() - } -} - -pub trait LWECiphertextToMut { +pub trait LWEToMut { #[allow(dead_code)] - fn to_mut(&mut self) -> LWECiphertext<&mut [u8]>; + fn to_mut(&mut self) -> LWE<&mut [u8]>; } -impl LWECiphertextToMut for LWECiphertext { - fn to_mut(&mut self) -> LWECiphertext<&mut [u8]> { - LWECiphertext::builder() - .base2k(self.base2k()) - .k(self.k()) - .data(self.data.to_mut()) - .build() - .unwrap() +impl LWEToMut for LWE { + fn to_mut(&mut self) -> LWE<&mut [u8]> { + LWE { + k: self.k, + base2k: self.base2k, + data: self.data.to_mut(), + } } } -impl ReaderFrom for LWECiphertext { +impl ReaderFrom for LWE { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.k = TorusPrecision(reader.read_u32::()?); self.base2k = Base2K(reader.read_u32::()?); @@ -254,7 +224,7 @@ impl ReaderFrom for LWECiphertext { } } -impl WriterTo for LWECiphertext { +impl WriterTo for LWE { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { writer.write_u32::(self.k.into())?; writer.write_u32::(self.base2k.into())?; diff --git a/poulpy-core/src/layouts/lwe_ksk.rs b/poulpy-core/src/layouts/lwe_ksk.rs index 314322c..2ff0391 100644 --- a/poulpy-core/src/layouts/lwe_ksk.rs +++ b/poulpy-core/src/layouts/lwe_ksk.rs @@ -1,22 +1,25 @@ use std::fmt; use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo}, + layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo}, source::Source, }; -use crate::layouts::{Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWESwitchingKey, GLWEInfos, LWEInfos, Rank, TorusPrecision}; +use crate::layouts::{ + Base2K, Dnum, Dsize, GGLWEInfos, GLWEInfos, GLWESwitchingKey, GLWESwitchingKeyAlloc, GLWESwitchingKeyToMut, + GLWESwitchingKeyToRef, LWEInfos, Rank, RingDegree, TorusPrecision, +}; #[derive(PartialEq, Eq, Copy, Clone, Debug)] pub struct LWESwitchingKeyLayout { - pub n: Degree, + pub n: RingDegree, pub base2k: Base2K, pub k: TorusPrecision, pub dnum: Dnum, } impl LWEInfos for LWESwitchingKeyLayout { - fn n(&self) -> Degree { + fn n(&self) -> RingDegree { self.n } @@ -54,7 +57,7 @@ impl GGLWEInfos for LWESwitchingKeyLayout { } #[derive(PartialEq, Eq, Clone)] -pub struct LWESwitchingKey(pub(crate) GGLWESwitchingKey); +pub struct LWESwitchingKey(pub(crate) GLWESwitchingKey); impl LWEInfos for LWESwitchingKey { fn base2k(&self) -> Base2K { @@ -65,7 +68,7 @@ impl LWEInfos for LWESwitchingKey { self.0.k() } - fn n(&self) -> Degree { + fn n(&self) -> RingDegree { self.0.n() } @@ -98,65 +101,94 @@ impl GGLWEInfos for LWESwitchingKey { } } +pub trait LWESwitchingKeyAlloc +where + Self: GLWESwitchingKeyAlloc, +{ + fn alloc_lwe_switching_key(&self, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> LWESwitchingKey> { + LWESwitchingKey(self.alloc_glwe_switching_key(base2k, k, Rank(1), Rank(1), dnum, Dsize(1))) + } + + fn alloc_lwe_switching_key_from_infos(&self, infos: &A) -> LWESwitchingKey> + where + A: GGLWEInfos, + { + assert_eq!( + infos.dsize().0, + 1, + "dsize > 1 is not supported for LWESwitchingKey" + ); + assert_eq!( + infos.rank_in().0, + 1, + "rank_in > 1 is not supported for LWESwitchingKey" + ); + assert_eq!( + infos.rank_out().0, + 1, + "rank_out > 1 is not supported for LWESwitchingKey" + ); + self.alloc_lwe_switching_key(infos.base2k(), infos.k(), infos.dnum()) + } + + fn bytes_of_lwe_switching_key(&self, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> usize { + self.bytes_of_glwe_switching_key(base2k, k, Rank(1), Rank(1), dnum, Dsize(1)) + } + + fn bytes_of_lwe_switching_key_from_infos(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + assert_eq!( + infos.dsize().0, + 1, + "dsize > 1 is not supported for LWESwitchingKey" + ); + assert_eq!( + infos.rank_in().0, + 1, + "rank_in > 1 is not supported for LWESwitchingKey" + ); + assert_eq!( + infos.rank_out().0, + 1, + "rank_out > 1 is not supported for LWESwitchingKey" + ); + self.bytes_of_lwe_switching_key(infos.base2k(), infos.k(), infos.dnum()) + } +} + +impl LWESwitchingKeyAlloc for Module where Self: GLWESwitchingKeyAlloc {} + impl LWESwitchingKey> { - pub fn alloc(infos: &A) -> Self + pub fn alloc_from_infos(module: &M, infos: &A) -> Self where A: GGLWEInfos, + M: LWESwitchingKeyAlloc, { - debug_assert_eq!( - infos.dsize().0, - 1, - "dsize > 1 is not supported for LWESwitchingKey" - ); - debug_assert_eq!( - infos.rank_in().0, - 1, - "rank_in > 1 is not supported for LWESwitchingKey" - ); - debug_assert_eq!( - infos.rank_out().0, - 1, - "rank_out > 1 is not supported for LWESwitchingKey" - ); - Self(GGLWESwitchingKey::alloc(infos)) + module.alloc_lwe_switching_key_from_infos(infos) } - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> Self { - Self(GGLWESwitchingKey::alloc_with( - n, - base2k, - k, - Rank(1), - Rank(1), - dnum, - Dsize(1), - )) + pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> Self + where + M: LWESwitchingKeyAlloc, + { + module.alloc_lwe_switching_key(base2k, k, dnum) } - pub fn alloc_bytes(infos: &A) -> usize + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize where A: GGLWEInfos, + M: LWESwitchingKeyAlloc, { - debug_assert_eq!( - infos.dsize().0, - 1, - "dsize > 1 is not supported for LWESwitchingKey" - ); - debug_assert_eq!( - infos.rank_in().0, - 1, - "rank_in > 1 is not supported for LWESwitchingKey" - ); - debug_assert_eq!( - infos.rank_out().0, - 1, - "rank_out > 1 is not supported for LWESwitchingKey" - ); - GGLWESwitchingKey::alloc_bytes(infos) + module.bytes_of_glwe_switching_key_from_infos(infos) } - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> usize { - GGLWESwitchingKey::alloc_bytes_with(n, base2k, k, Rank(1), Rank(1), dnum, Dsize(1)) + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> usize + where + M: LWESwitchingKeyAlloc, + { + module.bytes_of_lwe_switching_key(base2k, k, dnum) } } @@ -189,3 +221,29 @@ impl WriterTo for LWESwitchingKey { self.0.write_to(writer) } } + +pub trait LWESwitchingKeyToRef { + fn to_ref(&self) -> LWESwitchingKey<&[u8]>; +} + +impl LWESwitchingKeyToRef for LWESwitchingKey +where + GLWESwitchingKey: GLWESwitchingKeyToRef, +{ + fn to_ref(&self) -> LWESwitchingKey<&[u8]> { + LWESwitchingKey(self.0.to_ref()) + } +} + +pub trait LWESwitchingKeyToMut { + fn to_mut(&mut self) -> LWESwitchingKey<&mut [u8]>; +} + +impl LWESwitchingKeyToMut for LWESwitchingKey +where + GLWESwitchingKey: GLWESwitchingKeyToMut, +{ + fn to_mut(&mut self) -> LWESwitchingKey<&mut [u8]> { + LWESwitchingKey(self.0.to_mut()) + } +} diff --git a/poulpy-core/src/layouts/lwe_pt.rs b/poulpy-core/src/layouts/lwe_pt.rs index e739722..6ffd650 100644 --- a/poulpy-core/src/layouts/lwe_pt.rs +++ b/poulpy-core/src/layouts/lwe_pt.rs @@ -1,8 +1,8 @@ use std::fmt; -use poulpy_hal::layouts::{Data, DataMut, DataRef, Zn, ZnToMut, ZnToRef, ZnxInfos}; +use poulpy_hal::layouts::{Backend, Data, DataMut, DataRef, Module, Zn, ZnToMut, ZnToRef, ZnxInfos}; -use crate::layouts::{Base2K, Degree, LWEInfos, TorusPrecision}; +use crate::layouts::{Base2K, LWEInfos, RingDegree, TorusPrecision}; #[derive(PartialEq, Eq, Copy, Clone, Debug)] pub struct LWEPlaintextLayout { @@ -19,8 +19,8 @@ impl LWEInfos for LWEPlaintextLayout { self.k } - fn n(&self) -> Degree { - Degree(0) + fn n(&self) -> RingDegree { + RingDegree(0) } fn size(&self) -> usize { @@ -43,8 +43,8 @@ impl LWEInfos for LWEPlaintext { self.k } - fn n(&self) -> Degree { - Degree(self.data.n() as u32 - 1) + fn n(&self) -> RingDegree { + RingDegree(self.data.n() as u32 - 1) } fn size(&self) -> usize { @@ -52,21 +52,40 @@ impl LWEInfos for LWEPlaintext { } } -impl LWEPlaintext> { - pub fn alloc(infos: &A) -> Self - where - A: LWEInfos, - { - Self::alloc_with(infos.base2k(), infos.k()) - } - - pub fn alloc_with(base2k: Base2K, k: TorusPrecision) -> Self { - Self { +pub trait LWEPlaintextAlloc { + fn alloc_lwe_plaintext(&self, base2k: Base2K, k: TorusPrecision) -> LWEPlaintext> { + LWEPlaintext { data: Zn::alloc(1, 1, k.0.div_ceil(base2k.0) as usize), k, base2k, } } + + fn alloc_lwe_plaintext_from_infos(&self, infos: &A) -> LWEPlaintext> + where + A: LWEInfos, + { + self.alloc_lwe_plaintext(infos.base2k(), infos.k()) + } +} + +impl LWEPlaintextAlloc for Module {} + +impl LWEPlaintext> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: LWEInfos, + M: LWEPlaintextAlloc, + { + module.alloc_lwe_plaintext_from_infos(infos) + } + + pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision) -> Self + where + M: LWEPlaintextAlloc, + { + module.alloc_lwe_plaintext(base2k, k) + } } impl fmt::Display for LWEPlaintext { diff --git a/poulpy-core/src/layouts/lwe_sk.rs b/poulpy-core/src/layouts/lwe_sk.rs index a5b7d4e..d593ad5 100644 --- a/poulpy-core/src/layouts/lwe_sk.rs +++ b/poulpy-core/src/layouts/lwe_sk.rs @@ -1,11 +1,11 @@ use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, ScalarZnx, ZnxInfos, ZnxView, ZnxZero}, + layouts::{Backend, Data, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToMut, ScalarZnxToRef, ZnxInfos, ZnxView, ZnxZero}, source::Source, }; use crate::{ dist::Distribution, - layouts::{Base2K, Degree, LWEInfos, TorusPrecision}, + layouts::{Base2K, LWEInfos, RingDegree, TorusPrecision}, }; pub struct LWESecret { @@ -13,15 +13,26 @@ pub struct LWESecret { pub(crate) dist: Distribution, } -impl LWESecret> { - pub fn alloc(n: Degree) -> Self { - Self { +pub trait LWESecretAlloc { + fn alloc_lwe_secret(&self, n: RingDegree) -> LWESecret> { + LWESecret { data: ScalarZnx::alloc(n.into(), 1), dist: Distribution::NONE, } } } +impl LWESecretAlloc for Module {} + +impl LWESecret> { + pub fn alloc(module: &M, n: RingDegree) -> Self + where + M: LWESecretAlloc, + { + module.alloc_lwe_secret(n) + } +} + impl LWESecret { pub fn raw(&self) -> &[i64] { self.data.at(0, 0) @@ -44,8 +55,8 @@ impl LWEInfos for LWESecret { TorusPrecision(0) } - fn n(&self) -> Degree { - Degree(self.data.n() as u32) + fn n(&self) -> RingDegree { + RingDegree(self.data.n() as u32) } fn size(&self) -> usize { @@ -84,3 +95,29 @@ impl LWESecret { self.dist = Distribution::ZERO; } } + +pub trait LWESecretToRef { + fn to_ref(&self) -> LWESecret<&[u8]>; +} + +impl LWESecretToRef for LWESecret { + fn to_ref(&self) -> LWESecret<&[u8]> { + LWESecret { + dist: self.dist, + data: self.data.to_ref(), + } + } +} + +pub trait LWESecretToMut { + fn to_mut(&mut self) -> LWESecret<&mut [u8]>; +} + +impl LWESecretToMut for LWESecret { + fn to_mut(&mut self) -> LWESecret<&mut [u8]> { + LWESecret { + dist: self.dist, + data: self.data.to_mut(), + } + } +} diff --git a/poulpy-core/src/layouts/lwe_to_glwe_ksk.rs b/poulpy-core/src/layouts/lwe_to_glwe_ksk.rs index b3ba74b..72b7514 100644 --- a/poulpy-core/src/layouts/lwe_to_glwe_ksk.rs +++ b/poulpy-core/src/layouts/lwe_to_glwe_ksk.rs @@ -1,15 +1,18 @@ use std::fmt; use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo}, + layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo}, source::Source, }; -use crate::layouts::{Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWESwitchingKey, GLWEInfos, LWEInfos, Rank, TorusPrecision}; +use crate::layouts::{ + Base2K, Dnum, Dsize, GGLWEInfos, GLWEInfos, GLWESwitchingKey, GLWESwitchingKeyAlloc, GLWESwitchingKeyToMut, + GLWESwitchingKeyToRef, LWEInfos, Rank, RingDegree, TorusPrecision, +}; #[derive(PartialEq, Eq, Copy, Clone, Debug)] pub struct LWEToGLWESwitchingKeyLayout { - pub n: Degree, + pub n: RingDegree, pub base2k: Base2K, pub k: TorusPrecision, pub rank_out: Rank, @@ -25,7 +28,7 @@ impl LWEInfos for LWEToGLWESwitchingKeyLayout { self.k } - fn n(&self) -> Degree { + fn n(&self) -> RingDegree { self.n } } @@ -55,7 +58,7 @@ impl GGLWEInfos for LWEToGLWESwitchingKeyLayout { } #[derive(PartialEq, Eq, Clone)] -pub struct LWEToGLWESwitchingKey(pub(crate) GGLWESwitchingKey); +pub struct LWEToGLWESwitchingKey(pub(crate) GLWESwitchingKey); impl LWEInfos for LWEToGLWESwitchingKey { fn base2k(&self) -> Base2K { @@ -66,7 +69,7 @@ impl LWEInfos for LWEToGLWESwitchingKey { self.0.k() } - fn n(&self) -> Degree { + fn n(&self) -> RingDegree { self.0.n() } @@ -128,54 +131,116 @@ impl WriterTo for LWEToGLWESwitchingKey { } } -impl LWEToGLWESwitchingKey> { - pub fn alloc(infos: &A) -> Self +pub trait LWEToGLWESwitchingKeyAlloc +where + Self: GLWESwitchingKeyAlloc, +{ + fn alloc_lwe_to_glwe_switching_key( + &self, + base2k: Base2K, + k: TorusPrecision, + rank_out: Rank, + dnum: Dnum, + ) -> LWEToGLWESwitchingKey> { + LWEToGLWESwitchingKey(self.alloc_glwe_switching_key(base2k, k, Rank(1), rank_out, dnum, Dsize(1))) + } + + fn alloc_lwe_to_glwe_switching_key_from_infos(&self, infos: &A) -> LWEToGLWESwitchingKey> where A: GGLWEInfos, { - debug_assert_eq!( + assert_eq!( infos.rank_in().0, 1, "rank_in > 1 is not supported for LWEToGLWESwitchingKey" ); - debug_assert_eq!( + assert_eq!( infos.dsize().0, 1, "dsize > 1 is not supported for LWEToGLWESwitchingKey" ); - Self(GGLWESwitchingKey::alloc(infos)) + + self.alloc_lwe_to_glwe_switching_key(infos.base2k(), infos.k(), infos.rank_out(), infos.dnum()) } - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self { - Self(GGLWESwitchingKey::alloc_with( - n, - base2k, - k, - Rank(1), - rank_out, - dnum, - Dsize(1), - )) + fn bytes_of_lwe_to_glwe_switching_key(&self, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> usize { + self.bytes_of_glwe_switching_key(base2k, k, Rank(1), rank_out, dnum, Dsize(1)) } - pub fn alloc_bytes(infos: &A) -> usize + fn bytes_of_lwe_to_glwe_switching_key_from_infos(&self, infos: &A) -> usize where A: GGLWEInfos, { - debug_assert_eq!( + assert_eq!( infos.rank_in().0, 1, "rank_in > 1 is not supported for LWEToGLWESwitchingKey" ); - debug_assert_eq!( + assert_eq!( infos.dsize().0, 1, "dsize > 1 is not supported for LWEToGLWESwitchingKey" ); - GGLWESwitchingKey::alloc_bytes(infos) - } - - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, dnum: Dnum, rank_out: Rank) -> usize { - GGLWESwitchingKey::alloc_bytes_with(n, base2k, k, Rank(1), rank_out, dnum, Dsize(1)) + self.bytes_of_lwe_to_glwe_switching_key(infos.base2k(), infos.k(), infos.rank_out(), infos.dnum()) + } +} + +impl LWEToGLWESwitchingKeyAlloc for Module where Self: GLWESwitchingKeyAlloc {} + +impl LWEToGLWESwitchingKey> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GGLWEInfos, + M: LWEToGLWESwitchingKeyAlloc, + { + module.alloc_lwe_to_glwe_switching_key_from_infos(infos) + } + + pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self + where + M: LWEToGLWESwitchingKeyAlloc, + { + module.alloc_lwe_to_glwe_switching_key(base2k, k, rank_out, dnum) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: LWEToGLWESwitchingKeyAlloc, + { + module.bytes_of_lwe_to_glwe_switching_key_from_infos(infos) + } + + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, dnum: Dnum, rank_out: Rank) -> usize + where + M: LWEToGLWESwitchingKeyAlloc, + { + module.bytes_of_lwe_to_glwe_switching_key(base2k, k, rank_out, dnum) + } +} + +pub trait LWEToGLWESwitchingKeyToRef { + fn to_ref(&self) -> LWEToGLWESwitchingKey<&[u8]>; +} + +impl LWEToGLWESwitchingKeyToRef for LWEToGLWESwitchingKey +where + GLWESwitchingKey: GLWESwitchingKeyToRef, +{ + fn to_ref(&self) -> LWEToGLWESwitchingKey<&[u8]> { + LWEToGLWESwitchingKey(self.0.to_ref()) + } +} + +pub trait LWEToGLWESwitchingKeyToMut { + fn to_mut(&mut self) -> LWEToGLWESwitchingKey<&mut [u8]>; +} + +impl LWEToGLWESwitchingKeyToMut for LWEToGLWESwitchingKey +where + GLWESwitchingKey: GLWESwitchingKeyToMut, +{ + fn to_mut(&mut self) -> LWEToGLWESwitchingKey<&mut [u8]> { + LWEToGLWESwitchingKey(self.0.to_mut()) } } diff --git a/poulpy-core/src/layouts/mod.rs b/poulpy-core/src/layouts/mod.rs index 2b13751..09064f3 100644 --- a/poulpy-core/src/layouts/mod.rs +++ b/poulpy-core/src/layouts/mod.rs @@ -33,20 +33,16 @@ pub use lwe_pt::*; pub use lwe_sk::*; pub use lwe_to_glwe_ksk::*; -#[derive(Debug)] -pub enum BuildError { - MissingData, - MissingBase2K, - MissingK, - MissingDigits, - ZeroDegree, - NonPowerOfTwoDegree, - ZeroBase2K, - ZeroTorusPrecision, - ZeroCols, - ZeroLimbs, - ZeroRank, - ZeroDigits, +use poulpy_hal::layouts::{Backend, Module}; + +pub trait GetRingDegree { + fn ring_degree(&self) -> RingDegree; +} + +impl GetRingDegree for Module { + fn ring_degree(&self) -> RingDegree { + Self::n(&self).into() + } } /// Newtype over `u32` with arithmetic and comparisons against same type and `u32`. @@ -206,14 +202,14 @@ macro_rules! newtype_u32 { }; } -newtype_u32!(Degree); +newtype_u32!(RingDegree); newtype_u32!(TorusPrecision); newtype_u32!(Base2K); newtype_u32!(Dnum); newtype_u32!(Rank); newtype_u32!(Dsize); -impl Degree { +impl RingDegree { pub fn log2(&self) -> usize { let n: usize = self.0 as usize; (usize::BITS - (n - 1).leading_zeros()) as _ diff --git a/poulpy-core/src/layouts/prepared/gglwe_atk.rs b/poulpy-core/src/layouts/prepared/gglwe_atk.rs index 594aa0a..7c8d0b9 100644 --- a/poulpy-core/src/layouts/prepared/gglwe_atk.rs +++ b/poulpy-core/src/layouts/prepared/gglwe_atk.rs @@ -1,27 +1,21 @@ -use poulpy_hal::{ - api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare}, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}, -}; +use poulpy_hal::layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}; use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEAutomorphismKey, GGLWEInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision, - prepared::{GGLWESwitchingKeyPrepared, Prepare, PrepareAlloc, PrepareScratchSpace}, + AutomorphismKeyToRef, Base2K, Dnum, Dsize, GGLWEInfos, GLWEInfos, LWEInfos, Rank, RingDegree, TorusPrecision, + prepared::{ + GLWESwitchingKeyPrepare, GLWESwitchingKeyPrepared, GLWESwitchingKeyPreparedAlloc, GLWESwitchingKeyPreparedToMut, + GLWESwitchingKeyPreparedToRef, + }, }; #[derive(PartialEq, Eq)] -pub struct GGLWEAutomorphismKeyPrepared { - pub(crate) key: GGLWESwitchingKeyPrepared, +pub struct AutomorphismKeyPrepared { + pub(crate) key: GLWESwitchingKeyPrepared, pub(crate) p: i64, } -impl GGLWEAutomorphismKeyPrepared { - pub fn p(&self) -> i64 { - self.p - } -} - -impl LWEInfos for GGLWEAutomorphismKeyPrepared { - fn n(&self) -> Degree { +impl LWEInfos for AutomorphismKeyPrepared { + fn n(&self) -> RingDegree { self.key.n() } @@ -38,13 +32,33 @@ impl LWEInfos for GGLWEAutomorphismKeyPrepared { } } -impl GLWEInfos for GGLWEAutomorphismKeyPrepared { +pub trait GetAutomorphismGaloisElement { + fn p(&self) -> i64; +} + +impl GetAutomorphismGaloisElement for AutomorphismKeyPrepared { + fn p(&self) -> i64 { + self.p + } +} + +pub trait SetAutomorphismGaloisElement { + fn set_p(&mut self, p: i64); +} + +impl SetAutomorphismGaloisElement for AutomorphismKeyPrepared { + fn set_p(&mut self, p: i64) { + self.p = p + } +} + +impl GLWEInfos for AutomorphismKeyPrepared { fn rank(&self) -> Rank { self.rank_out() } } -impl GGLWEInfos for GGLWEAutomorphismKeyPrepared { +impl GGLWEInfos for AutomorphismKeyPrepared { fn rank_in(&self) -> Rank { self.key.rank_in() } @@ -62,80 +76,170 @@ impl GGLWEInfos for GGLWEAutomorphismKeyPrepared { } } -impl GGLWEAutomorphismKeyPrepared, B> { - pub fn alloc(module: &Module, infos: &A) -> Self +pub trait AutomorphismKeyPreparedAlloc +where + Self: GLWESwitchingKeyPreparedAlloc, +{ + fn alloc_automorphism_key_prepared( + &self, + base2k: Base2K, + k: TorusPrecision, + rank: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> AutomorphismKeyPrepared, B> { + AutomorphismKeyPrepared::, B> { + key: self.alloc_glwe_switching_key_prepared(base2k, k, rank, rank, dnum, dsize), + p: 0, + } + } + + fn alloc_automorphism_key_prepared_from_infos(&self, infos: &A) -> AutomorphismKeyPrepared, B> + where + A: GGLWEInfos, + { + assert_eq!( + infos.rank_in(), + infos.rank_out(), + "rank_in != rank_out is not supported for AutomorphismKeyPrepared" + ); + self.alloc_automorphism_key_prepared( + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) + } + + fn bytes_of_automorphism_key_prepared( + &self, + base2k: Base2K, + k: TorusPrecision, + rank: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> usize { + self.bytes_of_glwe_switching_key_prepared(base2k, k, rank, rank, dnum, dsize) + } + + fn bytes_of_automorphism_key_prepared_from_infos(&self, infos: &A) -> usize where A: GGLWEInfos, - Module: VmpPMatAlloc, { assert_eq!( infos.rank_in(), infos.rank_out(), "rank_in != rank_out is not supported for GGLWEAutomorphismKeyPrepared" ); - GGLWEAutomorphismKeyPrepared::, B> { - key: GGLWESwitchingKeyPrepared::alloc(module, infos), - p: 0, - } + self.bytes_of_automorphism_key_prepared( + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) } +} - pub fn alloc_with(module: &Module, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self - where - Module: VmpPMatAlloc, - { - GGLWEAutomorphismKeyPrepared { - key: GGLWESwitchingKeyPrepared::alloc_with(module, base2k, k, rank, rank, dnum, dsize), - p: 0, - } - } +impl AutomorphismKeyPreparedAlloc for Module where Module: GLWESwitchingKeyPreparedAlloc {} - pub fn alloc_bytes(module: &Module, infos: &A) -> usize +impl AutomorphismKeyPrepared, B> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self where A: GGLWEInfos, - Module: VmpPMatAllocBytes, + M: AutomorphismKeyPreparedAlloc, { - assert_eq!( - infos.rank_in(), - infos.rank_out(), - "rank_in != rank_out is not supported for GGLWEAutomorphismKeyPrepared" - ); - GGLWESwitchingKeyPrepared::alloc_bytes(module, infos) + module.alloc_automorphism_key_prepared_from_infos(infos) } - pub fn alloc_bytes_with(module: &Module, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize + pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self where - Module: VmpPMatAllocBytes, + M: AutomorphismKeyPreparedAlloc, { - GGLWESwitchingKeyPrepared::alloc_bytes_with(module, base2k, k, rank, rank, dnum, dsize) + module.alloc_automorphism_key_prepared(base2k, k, rank, dnum, dsize) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: AutomorphismKeyPreparedAlloc, + { + module.bytes_of_automorphism_key_prepared_from_infos(infos) + } + + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize + where + M: AutomorphismKeyPreparedAlloc, + { + module.bytes_of_automorphism_key_prepared(base2k, k, rank, dnum, dsize) } } -impl PrepareScratchSpace for GGLWEAutomorphismKeyPrepared, B> +pub trait PrepareAutomorphismKey where - GGLWESwitchingKeyPrepared, B>: PrepareScratchSpace, + Self: GLWESwitchingKeyPrepare, { - fn prepare_scratch_space(module: &Module, infos: &A) -> usize { - GGLWESwitchingKeyPrepared::prepare_scratch_space(module, infos) + fn prepare_automorphism_key_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + self.prepare_glwe_switching_key_tmp_bytes(infos) + } + + fn prepare_automorphism_key(&self, res: &mut R, other: &O, scratch: &mut Scratch) + where + R: AutomorphismKeyPreparedToMut + SetAutomorphismGaloisElement, + O: AutomorphismKeyToRef + GetAutomorphismGaloisElement, + { + self.prepare_glwe_switching(&mut res.to_mut().key, &other.to_ref().key, scratch); + res.set_p(other.p()); } } -impl Prepare> for GGLWEAutomorphismKeyPrepared -where - Module: VmpPrepare, -{ - fn prepare(&mut self, module: &Module, other: &GGLWEAutomorphismKey, scratch: &mut Scratch) { - self.key.prepare(module, &other.key, scratch); - self.p = other.p; +impl PrepareAutomorphismKey for Module where Module: GLWESwitchingKeyPrepare {} + +impl AutomorphismKeyPrepared, B> { + pub fn prepare_tmp_bytes(&self, module: &M) -> usize + where + M: PrepareAutomorphismKey, + { + module.prepare_automorphism_key_tmp_bytes(self) } } -impl PrepareAlloc, B>> for GGLWEAutomorphismKey -where - Module: VmpPMatAlloc + VmpPrepare, -{ - fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> GGLWEAutomorphismKeyPrepared, B> { - let mut atk_prepared: GGLWEAutomorphismKeyPrepared, B> = GGLWEAutomorphismKeyPrepared::alloc(module, self); - atk_prepared.prepare(module, self, scratch); - atk_prepared +impl AutomorphismKeyPrepared { + pub fn prepare(&mut self, module: &M, other: &O, scratch: &mut Scratch) + where + O: AutomorphismKeyToRef + GetAutomorphismGaloisElement, + M: PrepareAutomorphismKey, + { + module.prepare_automorphism_key(self, other, scratch); + } +} + +pub trait AutomorphismKeyPreparedToMut { + fn to_mut(&mut self) -> AutomorphismKeyPrepared<&mut [u8], B>; +} + +impl AutomorphismKeyPreparedToMut for AutomorphismKeyPrepared { + fn to_mut(&mut self) -> AutomorphismKeyPrepared<&mut [u8], B> { + AutomorphismKeyPrepared { + p: self.p, + key: self.key.to_mut(), + } + } +} + +pub trait AutomorphismKeyPreparedToRef { + fn to_ref(&self) -> AutomorphismKeyPrepared<&[u8], B>; +} + +impl AutomorphismKeyPreparedToRef for AutomorphismKeyPrepared { + fn to_ref(&self) -> AutomorphismKeyPrepared<&[u8], B> { + AutomorphismKeyPrepared { + p: self.p, + key: self.key.to_ref(), + } } } diff --git a/poulpy-core/src/layouts/prepared/gglwe_ct.rs b/poulpy-core/src/layouts/prepared/gglwe_ct.rs index 4f22e6e..fd6033a 100644 --- a/poulpy-core/src/layouts/prepared/gglwe_ct.rs +++ b/poulpy-core/src/layouts/prepared/gglwe_ct.rs @@ -1,25 +1,23 @@ use poulpy_hal::{ - api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare, VmpPrepareTmpBytes}, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat, ZnxInfos}, - oep::VmpPMatAllocBytesImpl, + api::{VmpPMatAlloc, VmpPMatBytesOf, VmpPrepare, VmpPrepareTmpBytes}, + layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat, VmpPMatToMut, VmpPMatToRef, ZnxInfos}, }; use crate::layouts::{ - Base2K, BuildError, Degree, Dnum, Dsize, GGLWECiphertext, GGLWEInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision, - prepared::{Prepare, PrepareAlloc, PrepareScratchSpace}, + Base2K, Dnum, Dsize, GGLWE, GGLWEInfos, GGLWEToRef, GLWEInfos, GetRingDegree, LWEInfos, Rank, RingDegree, TorusPrecision, }; #[derive(PartialEq, Eq)] -pub struct GGLWECiphertextPrepared { +pub struct GGLWEPrepared { pub(crate) data: VmpPMat, pub(crate) k: TorusPrecision, pub(crate) base2k: Base2K, pub(crate) dsize: Dsize, } -impl LWEInfos for GGLWECiphertextPrepared { - fn n(&self) -> Degree { - Degree(self.data.n() as u32) +impl LWEInfos for GGLWEPrepared { + fn n(&self) -> RingDegree { + RingDegree(self.data.n() as u32) } fn base2k(&self) -> Base2K { @@ -35,13 +33,13 @@ impl LWEInfos for GGLWECiphertextPrepared { } } -impl GLWEInfos for GGLWECiphertextPrepared { +impl GLWEInfos for GGLWEPrepared { fn rank(&self) -> Rank { self.rank_out() } } -impl GGLWEInfos for GGLWECiphertextPrepared { +impl GGLWEInfos for GGLWEPrepared { fn rank_in(&self) -> Rank { Rank(self.data.cols_in() as u32) } @@ -59,117 +57,47 @@ impl GGLWEInfos for GGLWECiphertextPrepared { } } -pub struct GGLWECiphertextPreparedBuilder { - data: Option>, - base2k: Option, - k: Option, - dsize: Option, -} +pub trait GGLWEPreparedAlloc +where + Self: GetRingDegree + VmpPMatAlloc + VmpPMatBytesOf, +{ + fn alloc_gglwe_prepared( + &self, + base2k: Base2K, + k: TorusPrecision, + rank_in: Rank, + rank_out: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> GGLWEPrepared, B> { + let size: usize = k.0.div_ceil(base2k.0) as usize; + debug_assert!( + size as u32 > dsize.0, + "invalid gglwe: ceil(k/base2k): {size} <= dsize: {}", + dsize.0 + ); -impl GGLWECiphertextPrepared { - #[inline] - pub fn builder() -> GGLWECiphertextPreparedBuilder { - GGLWECiphertextPreparedBuilder { - data: None, - base2k: None, - k: None, - dsize: None, - } - } -} + assert!( + dnum.0 * dsize.0 <= size as u32, + "invalid gglwe: dnum: {} * dsize:{} > ceil(k/base2k): {size}", + dnum.0, + dsize.0, + ); -impl GGLWECiphertextPreparedBuilder, B> { - #[inline] - pub fn layout(mut self, infos: &A) -> Self - where - A: GGLWEInfos, - B: VmpPMatAllocBytesImpl, - { - self.data = Some(VmpPMat::alloc( - infos.n().into(), - infos.dnum().into(), - infos.rank_in().into(), - (infos.rank_out() + 1).into(), - infos.size(), - )); - self.base2k = Some(infos.base2k()); - self.k = Some(infos.k()); - self.dsize = Some(infos.dsize()); - self - } -} - -impl GGLWECiphertextPreparedBuilder { - #[inline] - pub fn data(mut self, data: VmpPMat) -> Self { - self.data = Some(data); - self - } - #[inline] - pub fn base2k(mut self, base2k: Base2K) -> Self { - self.base2k = Some(base2k); - self - } - #[inline] - pub fn k(mut self, k: TorusPrecision) -> Self { - self.k = Some(k); - self - } - - #[inline] - pub fn dsize(mut self, dsize: Dsize) -> Self { - self.dsize = Some(dsize); - self - } - - pub fn build(self) -> Result, BuildError> { - let data: VmpPMat = self.data.ok_or(BuildError::MissingData)?; - let base2k: Base2K = self.base2k.ok_or(BuildError::MissingBase2K)?; - let k: TorusPrecision = self.k.ok_or(BuildError::MissingK)?; - let dsize: Dsize = self.dsize.ok_or(BuildError::MissingDigits)?; - - if base2k == 0_u32 { - return Err(BuildError::ZeroBase2K); - } - - if dsize == 0_u32 { - return Err(BuildError::ZeroBase2K); - } - - if k == 0_u32 { - return Err(BuildError::ZeroTorusPrecision); - } - - if data.n() == 0 { - return Err(BuildError::ZeroDegree); - } - - if data.cols() == 0 { - return Err(BuildError::ZeroCols); - } - - if data.size() == 0 { - return Err(BuildError::ZeroLimbs); - } - - Ok(GGLWECiphertextPrepared { - data, - base2k, + GGLWEPrepared { + data: self.vmp_pmat_alloc(dnum.into(), rank_in.into(), (rank_out + 1).into(), size), k, + base2k, dsize, - }) + } } -} -impl GGLWECiphertextPrepared, B> { - pub fn alloc(module: &Module, infos: &A) -> Self + fn alloc_gglwe_prepared_from_infos(&self, infos: &A) -> GGLWEPrepared, B> where A: GGLWEInfos, - Module: VmpPMatAlloc, { - debug_assert_eq!(module.n(), infos.n().0 as usize, "module.n() != infos.n()"); - Self::alloc_with( - module, + assert_eq!(self.ring_degree(), infos.n()); + self.alloc_gglwe_prepared( infos.base2k(), infos.k(), infos.rank_in(), @@ -179,8 +107,61 @@ impl GGLWECiphertextPrepared, B> { ) } - pub fn alloc_with( - module: &Module, + fn bytes_of_gglwe_prepared( + &self, + base2k: Base2K, + k: TorusPrecision, + rank_in: Rank, + rank_out: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> usize { + let size: usize = k.0.div_ceil(base2k.0) as usize; + debug_assert!( + size as u32 > dsize.0, + "invalid gglwe: ceil(k/base2k): {size} <= dsize: {}", + dsize.0 + ); + + assert!( + dnum.0 * dsize.0 <= size as u32, + "invalid gglwe: dnum: {} * dsize:{} > ceil(k/base2k): {size}", + dnum.0, + dsize.0, + ); + + self.bytes_of_vmp_pmat(dnum.into(), rank_in.into(), (rank_out + 1).into(), size) + } + + fn bytes_of_gglwe_prepared_from_infos(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + assert_eq!(self.ring_degree(), infos.n()); + self.bytes_of_gglwe_prepared( + infos.base2k(), + infos.k(), + infos.rank_in(), + infos.rank_out(), + infos.dnum(), + infos.dsize(), + ) + } +} + +impl GGLWEPreparedAlloc for Module where Module: GetRingDegree + VmpPMatAlloc + VmpPMatBytesOf {} + +impl GGLWEPrepared, B> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GGLWEInfos, + M: GGLWEPreparedAlloc, + { + module.alloc_gglwe_prepared_from_infos(infos) + } + + pub fn alloc( + module: &M, base2k: Base2K, k: TorusPrecision, rank_in: Rank, @@ -189,49 +170,21 @@ impl GGLWECiphertextPrepared, B> { dsize: Dsize, ) -> Self where - Module: VmpPMatAlloc, + M: GGLWEPreparedAlloc, { - let size: usize = k.0.div_ceil(base2k.0) as usize; - debug_assert!( - size as u32 > dsize.0, - "invalid gglwe: ceil(k/base2k): {size} <= dsize: {}", - dsize.0 - ); - - assert!( - dnum.0 * dsize.0 <= size as u32, - "invalid gglwe: dnum: {} * dsize:{} > ceil(k/base2k): {size}", - dnum.0, - dsize.0, - ); - - Self { - data: module.vmp_pmat_alloc(dnum.into(), rank_in.into(), (rank_out + 1).into(), size), - k, - base2k, - dsize, - } + module.alloc_gglwe_prepared(base2k, k, rank_in, rank_out, dnum, dsize) } - pub fn alloc_bytes(module: &Module, infos: &A) -> usize + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize where A: GGLWEInfos, - Module: VmpPMatAllocBytes, + M: GGLWEPreparedAlloc, { - debug_assert_eq!(module.n(), infos.n().0 as usize, "module.n() != infos.n()"); - Self::alloc_bytes_with( - module, - infos.base2k(), - infos.k(), - infos.rank_in(), - infos.rank_out(), - infos.dnum(), - infos.dsize(), - ) + module.bytes_of_gglwe_prepared_from_infos(infos) } - pub fn alloc_bytes_with( - module: &Module, + pub fn bytes_of( + module: &M, base2k: Base2K, k: TorusPrecision, rank_in: Rank, @@ -240,59 +193,93 @@ impl GGLWECiphertextPrepared, B> { dsize: Dsize, ) -> usize where - Module: VmpPMatAllocBytes, + M: GGLWEPreparedAlloc, { - let size: usize = k.0.div_ceil(base2k.0) as usize; - debug_assert!( - size as u32 > dsize.0, - "invalid gglwe: ceil(k/base2k): {size} <= dsize: {}", - dsize.0 - ); - - assert!( - dnum.0 * dsize.0 <= size as u32, - "invalid gglwe: dnum: {} * dsize:{} > ceil(k/base2k): {size}", - dnum.0, - dsize.0, - ); - - module.vmp_pmat_alloc_bytes(dnum.into(), rank_in.into(), (rank_out + 1).into(), size) + module.bytes_of_gglwe_prepared(base2k, k, rank_in, rank_out, dnum, dsize) } } -impl PrepareScratchSpace for GGLWECiphertextPrepared, B> +pub trait GGLWEPrepare where - Module: VmpPrepareTmpBytes, + Self: GetRingDegree + VmpPrepareTmpBytes + VmpPrepare, { - fn prepare_scratch_space(module: &Module, infos: &A) -> usize { - module.vmp_prepare_tmp_bytes( + fn prepare_gglwe_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + self.vmp_prepare_tmp_bytes( infos.dnum().into(), infos.rank_in().into(), (infos.rank() + 1).into(), infos.size(), ) } -} -impl Prepare> for GGLWECiphertextPrepared -where - Module: VmpPrepare, -{ - fn prepare(&mut self, module: &Module, other: &GGLWECiphertext, scratch: &mut Scratch) { - module.vmp_prepare(&mut self.data, &other.data, scratch); - self.k = other.k; - self.base2k = other.base2k; - self.dsize = other.dsize; + fn prepare_gglwe(&self, res: &mut R, other: &O, scratch: &mut Scratch) + where + R: GGLWEPreparedToMut, + O: GGLWEToRef, + { + let mut res: GGLWEPrepared<&mut [u8], B> = res.to_mut(); + let other: GGLWE<&[u8]> = other.to_ref(); + + assert_eq!(res.n(), self.ring_degree()); + assert_eq!(other.n(), self.ring_degree()); + assert_eq!(res.base2k, other.base2k); + assert_eq!(res.k, other.k); + assert_eq!(res.dsize, other.dsize); + + self.vmp_prepare(&mut res.data, &other.data, scratch); } } -impl PrepareAlloc, B>> for GGLWECiphertext -where - Module: VmpPMatAlloc + VmpPrepare, -{ - fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> GGLWECiphertextPrepared, B> { - let mut atk_prepared: GGLWECiphertextPrepared, B> = GGLWECiphertextPrepared::alloc(module, self); - atk_prepared.prepare(module, self, scratch); - atk_prepared +impl GGLWEPrepare for Module where Self: GetRingDegree + VmpPrepareTmpBytes + VmpPrepare {} + +impl GGLWEPrepared { + pub fn prepare(&mut self, module: &M, other: &O, scratch: &mut Scratch) + where + O: GGLWEToRef, + M: GGLWEPrepare, + { + module.prepare_gglwe(self, other, scratch); + } +} + +impl GGLWEPrepared, B> { + pub fn prepare_tmp_bytes(&self, module: &M) -> usize + where + M: GGLWEPrepare, + { + module.prepare_gglwe_tmp_bytes(self) + } +} + +pub trait GGLWEPreparedToMut { + fn to_mut(&mut self) -> GGLWEPrepared<&mut [u8], B>; +} + +impl GGLWEPreparedToMut for GGLWEPrepared { + fn to_mut(&mut self) -> GGLWEPrepared<&mut [u8], B> { + GGLWEPrepared { + k: self.k, + base2k: self.base2k, + dsize: self.dsize, + data: self.data.to_mut(), + } + } +} + +pub trait GGLWEPreparedToRef { + fn to_ref(&self) -> GGLWEPrepared<&[u8], B>; +} + +impl GGLWEPreparedToRef for GGLWEPrepared { + fn to_ref(&self) -> GGLWEPrepared<&[u8], B> { + GGLWEPrepared { + k: self.k, + base2k: self.base2k, + dsize: self.dsize, + data: self.data.to_ref(), + } } } diff --git a/poulpy-core/src/layouts/prepared/gglwe_ksk.rs b/poulpy-core/src/layouts/prepared/gglwe_ksk.rs index c9110c1..2154b3e 100644 --- a/poulpy-core/src/layouts/prepared/gglwe_ksk.rs +++ b/poulpy-core/src/layouts/prepared/gglwe_ksk.rs @@ -1,22 +1,40 @@ -use poulpy_hal::{ - api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare}, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}, -}; +use poulpy_hal::layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}; use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWESwitchingKey, GLWEInfos, LWEInfos, Rank, TorusPrecision, - prepared::{GGLWECiphertextPrepared, Prepare, PrepareAlloc, PrepareScratchSpace}, + Base2K, Dnum, Dsize, GGLWEInfos, GLWEInfos, GLWESwitchingKeySetMetaData, GLWESwitchingKeyToRef, GLWESwtichingKeyGetMetaData, + LWEInfos, Rank, RingDegree, TorusPrecision, + prepared::{GGLWEPrepare, GGLWEPrepared, GGLWEPreparedAlloc, GGLWEPreparedToMut, GGLWEPreparedToRef}, }; #[derive(PartialEq, Eq)] -pub struct GGLWESwitchingKeyPrepared { - pub(crate) key: GGLWECiphertextPrepared, +pub struct GLWESwitchingKeyPrepared { + pub(crate) key: GGLWEPrepared, pub(crate) sk_in_n: usize, // Degree of sk_in pub(crate) sk_out_n: usize, // Degree of sk_out } -impl LWEInfos for GGLWESwitchingKeyPrepared { - fn n(&self) -> Degree { +impl GLWESwitchingKeySetMetaData for GLWESwitchingKeyPrepared { + fn set_sk_in_n(&mut self, sk_in_n: usize) { + self.sk_in_n = sk_in_n + } + + fn set_sk_out_n(&mut self, sk_out_n: usize) { + self.sk_out_n = sk_out_n + } +} + +impl GLWESwtichingKeyGetMetaData for GLWESwitchingKeyPrepared { + fn sk_in_n(&self) -> usize { + self.sk_in_n + } + + fn sk_out_n(&self) -> usize { + self.sk_out_n + } +} + +impl LWEInfos for GLWESwitchingKeyPrepared { + fn n(&self) -> RingDegree { self.key.n() } @@ -33,13 +51,13 @@ impl LWEInfos for GGLWESwitchingKeyPrepared { } } -impl GLWEInfos for GGLWESwitchingKeyPrepared { +impl GLWEInfos for GLWESwitchingKeyPrepared { fn rank(&self) -> Rank { self.rank_out() } } -impl GGLWEInfos for GGLWESwitchingKeyPrepared { +impl GGLWEInfos for GLWESwitchingKeyPrepared { fn rank_in(&self) -> Rank { self.key.rank_in() } @@ -57,22 +75,80 @@ impl GGLWEInfos for GGLWESwitchingKeyPrepared { } } -impl GGLWESwitchingKeyPrepared, B> { - pub fn alloc(module: &Module, infos: &A) -> Self - where - A: GGLWEInfos, - Module: VmpPMatAlloc, - { - debug_assert_eq!(module.n() as u32, infos.n(), "module.n() != infos.n()"); - GGLWESwitchingKeyPrepared::, B> { - key: GGLWECiphertextPrepared::alloc(module, infos), +pub trait GLWESwitchingKeyPreparedAlloc +where + Self: GGLWEPreparedAlloc, +{ + fn alloc_glwe_switching_key_prepared( + &self, + base2k: Base2K, + k: TorusPrecision, + rank_in: Rank, + rank_out: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> GLWESwitchingKeyPrepared, B> { + GLWESwitchingKeyPrepared::, B> { + key: self.alloc_gglwe_prepared(base2k, k, rank_in, rank_out, dnum, dsize), sk_in_n: 0, sk_out_n: 0, } } - pub fn alloc_with( - module: &Module, + fn alloc_glwe_switching_key_prepared_from_infos(&self, infos: &A) -> GLWESwitchingKeyPrepared, B> + where + A: GGLWEInfos, + { + self.alloc_glwe_switching_key_prepared( + infos.base2k(), + infos.k(), + infos.rank_in(), + infos.rank_out(), + infos.dnum(), + infos.dsize(), + ) + } + + fn bytes_of_glwe_switching_key_prepared( + &self, + base2k: Base2K, + k: TorusPrecision, + rank_in: Rank, + rank_out: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> usize { + self.bytes_of_gglwe_prepared(base2k, k, rank_in, rank_out, dnum, dsize) + } + + fn bytes_of_glwe_switching_key_prepared_from_infos(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + self.bytes_of_glwe_switching_key_prepared( + infos.base2k(), + infos.k(), + infos.rank_in(), + infos.rank_out(), + infos.dnum(), + infos.dsize(), + ) + } +} + +impl GLWESwitchingKeyPreparedAlloc for Module where Self: GGLWEPreparedAlloc {} + +impl GLWESwitchingKeyPrepared, B> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GGLWEInfos, + M: GLWESwitchingKeyPreparedAlloc, + { + module.alloc_glwe_switching_key_prepared_from_infos(infos) + } + + pub fn alloc( + module: &M, base2k: Base2K, k: TorusPrecision, rank_in: Rank, @@ -81,26 +157,21 @@ impl GGLWESwitchingKeyPrepared, B> { dsize: Dsize, ) -> Self where - Module: VmpPMatAlloc, + M: GLWESwitchingKeyPreparedAlloc, { - GGLWESwitchingKeyPrepared::, B> { - key: GGLWECiphertextPrepared::alloc_with(module, base2k, k, rank_in, rank_out, dnum, dsize), - sk_in_n: 0, - sk_out_n: 0, - } + module.alloc_glwe_switching_key_prepared(base2k, k, rank_in, rank_out, dnum, dsize) } - pub fn alloc_bytes(module: &Module, infos: &A) -> usize + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize where A: GGLWEInfos, - Module: VmpPMatAllocBytes, + M: GLWESwitchingKeyPreparedAlloc, { - debug_assert_eq!(module.n() as u32, infos.n(), "module.n() != infos.n()"); - GGLWECiphertextPrepared::alloc_bytes(module, infos) + module.bytes_of_glwe_switching_key_prepared_from_infos(infos) } - pub fn alloc_bytes_with( - module: &Module, + pub fn bytes_of( + module: &M, base2k: Base2K, k: TorusPrecision, rank_in: Rank, @@ -109,39 +180,79 @@ impl GGLWESwitchingKeyPrepared, B> { dsize: Dsize, ) -> usize where - Module: VmpPMatAllocBytes, + M: GLWESwitchingKeyPreparedAlloc, { - GGLWECiphertextPrepared::alloc_bytes_with(module, base2k, k, rank_in, rank_out, dnum, dsize) + module.bytes_of_glwe_switching_key_prepared(base2k, k, rank_in, rank_out, dnum, dsize) } } -impl PrepareScratchSpace for GGLWESwitchingKeyPrepared, B> +pub trait GLWESwitchingKeyPrepare where - GGLWECiphertextPrepared, B>: PrepareScratchSpace, + Self: GGLWEPrepare, { - fn prepare_scratch_space(module: &Module, infos: &A) -> usize { - GGLWECiphertextPrepared::prepare_scratch_space(module, infos) + fn prepare_glwe_switching_key_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + self.prepare_gglwe_tmp_bytes(infos) + } + + fn prepare_glwe_switching(&self, res: &mut R, other: &O, scratch: &mut Scratch) + where + R: GLWESwitchingKeyPreparedToMut + GLWESwitchingKeySetMetaData, + O: GLWESwitchingKeyToRef + GLWESwtichingKeyGetMetaData, + { + self.prepare_gglwe(&mut res.to_mut().key, &other.to_ref().key, scratch); + res.set_sk_in_n(other.sk_in_n()); + res.set_sk_out_n(other.sk_out_n()); } } -impl Prepare> for GGLWESwitchingKeyPrepared -where - Module: VmpPrepare, -{ - fn prepare(&mut self, module: &Module, other: &GGLWESwitchingKey, scratch: &mut Scratch) { - self.key.prepare(module, &other.key, scratch); - self.sk_in_n = other.sk_in_n; - self.sk_out_n = other.sk_out_n; +impl GLWESwitchingKeyPrepare for Module where Self: GGLWEPrepare {} + +impl GLWESwitchingKeyPrepared { + pub fn prepare(&mut self, module: &M, other: &O, scratch: &mut Scratch) + where + O: GLWESwitchingKeyToRef + GLWESwtichingKeyGetMetaData, + M: GLWESwitchingKeyPrepare, + { + module.prepare_glwe_switching(self, other, scratch); } } -impl PrepareAlloc, B>> for GGLWESwitchingKey -where - Module: VmpPMatAlloc + VmpPrepare, -{ - fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> GGLWESwitchingKeyPrepared, B> { - let mut atk_prepared: GGLWESwitchingKeyPrepared, B> = GGLWESwitchingKeyPrepared::alloc(module, self); - atk_prepared.prepare(module, self, scratch); - atk_prepared +impl GLWESwitchingKeyPrepared, B> { + pub fn prepare_tmp_bytes(&self, module: &M) -> usize + where + M: GLWESwitchingKeyPrepare, + { + module.prepare_glwe_switching_key_tmp_bytes(self) + } +} + +pub trait GLWESwitchingKeyPreparedToMut { + fn to_mut(&mut self) -> GLWESwitchingKeyPrepared<&mut [u8], B>; +} + +impl GLWESwitchingKeyPreparedToMut for GLWESwitchingKeyPrepared { + fn to_mut(&mut self) -> GLWESwitchingKeyPrepared<&mut [u8], B> { + GLWESwitchingKeyPrepared { + sk_in_n: self.sk_in_n, + sk_out_n: self.sk_out_n, + key: self.key.to_mut(), + } + } +} + +pub trait GLWESwitchingKeyPreparedToRef { + fn to_ref(&self) -> GLWESwitchingKeyPrepared<&[u8], B>; +} + +impl GLWESwitchingKeyPreparedToRef for GLWESwitchingKeyPrepared { + fn to_ref(&self) -> GLWESwitchingKeyPrepared<&[u8], B> { + GLWESwitchingKeyPrepared { + sk_in_n: self.sk_in_n, + sk_out_n: self.sk_out_n, + key: self.key.to_ref(), + } } } diff --git a/poulpy-core/src/layouts/prepared/gglwe_tsk.rs b/poulpy-core/src/layouts/prepared/gglwe_tsk.rs index 4343e30..2d054c1 100644 --- a/poulpy-core/src/layouts/prepared/gglwe_tsk.rs +++ b/poulpy-core/src/layouts/prepared/gglwe_tsk.rs @@ -1,20 +1,20 @@ -use poulpy_hal::{ - api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare}, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}, -}; +use poulpy_hal::layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}; use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWETensorKey, GLWEInfos, LWEInfos, Rank, TorusPrecision, - prepared::{GGLWESwitchingKeyPrepared, Prepare, PrepareAlloc, PrepareScratchSpace}, + Base2K, Dnum, Dsize, GGLWEInfos, GLWEInfos, LWEInfos, Rank, RingDegree, TensorKey, TensorKeyToRef, TorusPrecision, + prepared::{ + GLWESwitchingKeyPrepare, GLWESwitchingKeyPrepared, GLWESwitchingKeyPreparedAlloc, GLWESwitchingKeyPreparedToMut, + GLWESwitchingKeyPreparedToRef, + }, }; #[derive(PartialEq, Eq)] -pub struct GGLWETensorKeyPrepared { - pub(crate) keys: Vec>, +pub struct TensorKeyPrepared { + pub(crate) keys: Vec>, } -impl LWEInfos for GGLWETensorKeyPrepared { - fn n(&self) -> Degree { +impl LWEInfos for TensorKeyPrepared { + fn n(&self) -> RingDegree { self.keys[0].n() } @@ -31,13 +31,13 @@ impl LWEInfos for GGLWETensorKeyPrepared { } } -impl GLWEInfos for GGLWETensorKeyPrepared { +impl GLWEInfos for TensorKeyPrepared { fn rank(&self) -> Rank { self.rank_out() } } -impl GGLWEInfos for GGLWETensorKeyPrepared { +impl GGLWEInfos for TensorKeyPrepared { fn rank_in(&self) -> Rank { self.rank_out() } @@ -55,19 +55,36 @@ impl GGLWEInfos for GGLWETensorKeyPrepared { } } -impl GGLWETensorKeyPrepared, B> { - pub fn alloc(module: &Module, infos: &A) -> Self +pub trait TensorKeyPreparedAlloc +where + Self: GLWESwitchingKeyPreparedAlloc, +{ + fn alloc_tensor_key_prepared( + &self, + base2k: Base2K, + k: TorusPrecision, + dnum: Dnum, + dsize: Dsize, + rank: Rank, + ) -> TensorKeyPrepared, B> { + let pairs: u32 = (((rank.as_u32() + 1) * rank.as_u32()) >> 1).max(1); + TensorKeyPrepared { + keys: (0..pairs) + .map(|_| self.alloc_glwe_switching_key_prepared(base2k, k, Rank(1), rank, dnum, dsize)) + .collect(), + } + } + + fn alloc_tensor_key_prepared_from_infos(&self, infos: &A) -> TensorKeyPrepared, B> where A: GGLWEInfos, - Module: VmpPMatAlloc, { assert_eq!( infos.rank_in(), infos.rank_out(), "rank_in != rank_out is not supported for GGLWETensorKeyPrepared" ); - Self::alloc_with( - module, + self.alloc_tensor_key_prepared( infos.base2k(), infos.k(), infos.dnum(), @@ -76,62 +93,62 @@ impl GGLWETensorKeyPrepared, B> { ) } - pub fn alloc_with(module: &Module, base2k: Base2K, k: TorusPrecision, dnum: Dnum, dsize: Dsize, rank: Rank) -> Self - where - Module: VmpPMatAlloc, - { - let mut keys: Vec, B>> = Vec::new(); - let pairs: u32 = (((rank.0 + 1) * rank.0) >> 1).max(1); - (0..pairs).for_each(|_| { - keys.push(GGLWESwitchingKeyPrepared::alloc_with( - module, - base2k, - k, - Rank(1), - rank, - dnum, - dsize, - )); - }); - Self { keys } + fn bytes_of_tensor_key_prepared(&self, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { + let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize; + pairs * self.bytes_of_glwe_switching_key_prepared(base2k, k, Rank(1), rank, dnum, dsize) } - pub fn alloc_bytes(module: &Module, infos: &A) -> usize + fn bytes_of_tensor_key_prepared_from_infos(&self, infos: &A) -> usize where A: GGLWEInfos, - Module: VmpPMatAllocBytes, { - assert_eq!( - infos.rank_in(), - infos.rank_out(), - "rank_in != rank_out is not supported for GGLWETensorKey" - ); - let rank_out: usize = infos.rank_out().into(); - let pairs: usize = (((rank_out + 1) * rank_out) >> 1).max(1); - pairs - * GGLWESwitchingKeyPrepared::alloc_bytes_with( - module, - infos.base2k(), - infos.k(), - Rank(1), - infos.rank_out(), - infos.dnum(), - infos.dsize(), - ) - } - - pub fn alloc_bytes_with(module: &Module, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize - where - Module: VmpPMatAllocBytes, - { - let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize; - pairs * GGLWESwitchingKeyPrepared::alloc_bytes_with(module, base2k, k, Rank(1), rank, dnum, dsize) + self.bytes_of_tensor_key_prepared( + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) } } -impl GGLWETensorKeyPrepared { +impl TensorKeyPreparedAlloc for Module where Module: GLWESwitchingKeyPreparedAlloc {} + +impl TensorKeyPrepared, B> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GGLWEInfos, + M: TensorKeyPreparedAlloc, + { + module.alloc_tensor_key_prepared_from_infos(infos) + } + + pub fn alloc_with(module: &M, base2k: Base2K, k: TorusPrecision, dnum: Dnum, dsize: Dsize, rank: Rank) -> Self + where + M: TensorKeyPreparedAlloc, + { + module.alloc_tensor_key_prepared(base2k, k, dnum, dsize, rank) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: TensorKeyPreparedAlloc, + { + module.bytes_of_tensor_key_prepared_from_infos(infos) + } + + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize + where + M: TensorKeyPreparedAlloc, + { + module.bytes_of_tensor_key_prepared(base2k, k, rank, dnum, dsize) + } +} + +impl TensorKeyPrepared { // Returns a mutable reference to GLWESwitchingKey_{s}(s[i] * s[j]) - pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GGLWESwitchingKeyPrepared { + pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GLWESwitchingKeyPrepared { if i > j { std::mem::swap(&mut i, &mut j); }; @@ -140,9 +157,9 @@ impl GGLWETensorKeyPrepared { } } -impl GGLWETensorKeyPrepared { +impl TensorKeyPrepared { // Returns a reference to GLWESwitchingKey_{s}(s[i] * s[j]) - pub fn at(&self, mut i: usize, mut j: usize) -> &GGLWESwitchingKeyPrepared { + pub fn at(&self, mut i: usize, mut j: usize) -> &GLWESwitchingKeyPrepared { if i > j { std::mem::swap(&mut i, &mut j); }; @@ -151,40 +168,81 @@ impl GGLWETensorKeyPrepared { } } -impl PrepareScratchSpace for GGLWETensorKeyPrepared, B> +pub trait TensorKeyPrepare where - GGLWESwitchingKeyPrepared, B>: PrepareScratchSpace, + Self: GLWESwitchingKeyPrepare, { - fn prepare_scratch_space(module: &Module, infos: &A) -> usize { - GGLWESwitchingKeyPrepared::prepare_scratch_space(module, infos) + fn prepare_tensor_key_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + self.prepare_glwe_switching_key_tmp_bytes(infos) } -} -impl Prepare> for GGLWETensorKeyPrepared -where - Module: VmpPrepare, -{ - fn prepare(&mut self, module: &Module, other: &GGLWETensorKey, scratch: &mut Scratch) { - #[cfg(debug_assertions)] - { - assert_eq!(self.keys.len(), other.keys.len()); + fn prepare_tensor_key(&self, res: &mut R, other: &O, scratch: &mut Scratch) + where + R: TensorKeyPreparedToMut, + O: TensorKeyToRef, + { + let mut res: TensorKeyPrepared<&mut [u8], B> = res.to_mut(); + let other: TensorKey<&[u8]> = other.to_ref(); + + assert_eq!(res.keys.len(), other.keys.len()); + + for (a, b) in res.keys.iter_mut().zip(other.keys.iter()) { + self.prepare_glwe_switching(a, b, scratch); } - self.keys - .iter_mut() - .zip(other.keys.iter()) - .for_each(|(a, b)| { - a.prepare(module, b, scratch); - }); } } -impl PrepareAlloc, B>> for GGLWETensorKey -where - Module: VmpPMatAlloc + VmpPrepare, -{ - fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> GGLWETensorKeyPrepared, B> { - let mut tsk_prepared: GGLWETensorKeyPrepared, B> = GGLWETensorKeyPrepared::alloc(module, self); - tsk_prepared.prepare(module, self, scratch); - tsk_prepared +impl TensorKeyPrepare for Module where Self: GLWESwitchingKeyPrepare {} + +impl TensorKeyPrepared, B> { + fn prepare_tmp_bytes(&self, module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: TensorKeyPrepare, + { + module.prepare_tensor_key_tmp_bytes(infos) + } +} + +impl TensorKeyPrepared { + fn prepare(&mut self, module: &M, other: &O, scratch: &mut Scratch) + where + O: TensorKeyToRef, + M: TensorKeyPrepare, + { + module.prepare_tensor_key(self, other, scratch); + } +} + +pub trait TensorKeyPreparedToMut { + fn to_mut(&mut self) -> TensorKeyPrepared<&mut [u8], B>; +} + +impl TensorKeyPreparedToMut for TensorKeyPrepared +where + GLWESwitchingKeyPrepared: GLWESwitchingKeyPreparedToMut, +{ + fn to_mut(&mut self) -> TensorKeyPrepared<&mut [u8], B> { + TensorKeyPrepared { + keys: self.keys.iter_mut().map(|c| c.to_mut()).collect(), + } + } +} + +pub trait TensorKeyPreparedToRef { + fn to_ref(&self) -> TensorKeyPrepared<&[u8], B>; +} + +impl TensorKeyPreparedToRef for TensorKeyPrepared +where + GLWESwitchingKeyPrepared: GLWESwitchingKeyPreparedToRef, +{ + fn to_ref(&self) -> TensorKeyPrepared<&[u8], B> { + TensorKeyPrepared { + keys: self.keys.iter().map(|c| c.to_ref()).collect(), + } } } diff --git a/poulpy-core/src/layouts/prepared/ggsw_ct.rs b/poulpy-core/src/layouts/prepared/ggsw_ct.rs index eb79a5a..6365818 100644 --- a/poulpy-core/src/layouts/prepared/ggsw_ct.rs +++ b/poulpy-core/src/layouts/prepared/ggsw_ct.rs @@ -1,25 +1,23 @@ use poulpy_hal::{ - api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare, VmpPrepareTmpBytes}, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat, VmpPMatToRef, ZnxInfos}, - oep::VmpPMatAllocBytesImpl, + api::{VmpPMatAlloc, VmpPMatBytesOf, VmpPrepare, VmpPrepareTmpBytes}, + layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat, VmpPMatToMut, VmpPMatToRef, ZnxInfos}, }; use crate::layouts::{ - Base2K, BuildError, Degree, Dnum, Dsize, GGSWCiphertext, GGSWInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision, - prepared::{Prepare, PrepareAlloc, PrepareScratchSpace}, + Base2K, Dnum, Dsize, GGSW, GGSWInfos, GGSWToRef, GLWEInfos, GetRingDegree, LWEInfos, Rank, RingDegree, TorusPrecision, }; #[derive(PartialEq, Eq)] -pub struct GGSWCiphertextPrepared { +pub struct GGSWPrepared { pub(crate) data: VmpPMat, pub(crate) k: TorusPrecision, pub(crate) base2k: Base2K, pub(crate) dsize: Dsize, } -impl LWEInfos for GGSWCiphertextPrepared { - fn n(&self) -> Degree { - Degree(self.data.n() as u32) +impl LWEInfos for GGSWPrepared { + fn n(&self) -> RingDegree { + RingDegree(self.data.n() as u32) } fn base2k(&self) -> Base2K { @@ -35,13 +33,13 @@ impl LWEInfos for GGSWCiphertextPrepared { } } -impl GLWEInfos for GGSWCiphertextPrepared { +impl GLWEInfos for GGSWPrepared { fn rank(&self) -> Rank { Rank(self.data.cols_out() as u32 - 1) } } -impl GGSWInfos for GGSWCiphertextPrepared { +impl GGSWInfos for GGSWPrepared { fn dsize(&self) -> Dsize { self.dsize } @@ -51,143 +49,18 @@ impl GGSWInfos for GGSWCiphertextPrepared { } } -pub struct GGSWCiphertextPreparedBuilder { - data: Option>, - base2k: Option, - k: Option, - dsize: Option, -} - -impl GGSWCiphertextPrepared { - #[inline] - pub fn builder() -> GGSWCiphertextPreparedBuilder { - GGSWCiphertextPreparedBuilder { - data: None, - base2k: None, - k: None, - dsize: None, - } - } -} - -impl GGSWCiphertextPreparedBuilder, B> { - #[inline] - pub fn layout(mut self, infos: &A) -> Self - where - A: GGSWInfos, - B: VmpPMatAllocBytesImpl, - { - debug_assert!( - infos.size() as u32 > infos.dsize().0, - "invalid ggsw: ceil(k/base2k): {} <= dsize: {}", - infos.size(), - infos.dsize() - ); - - assert!( - infos.dnum().0 * infos.dsize().0 <= infos.size() as u32, - "invalid ggsw: dnum: {} * dsize:{} > ceil(k/base2k): {}", - infos.dnum(), - infos.dsize(), - infos.size(), - ); - - self.data = Some(VmpPMat::alloc( - infos.n().into(), - infos.dnum().into(), - (infos.rank() + 1).into(), - (infos.rank() + 1).into(), - infos.size(), - )); - self.base2k = Some(infos.base2k()); - self.k = Some(infos.k()); - self.dsize = Some(infos.dsize()); - self - } -} - -impl GGSWCiphertextPreparedBuilder { - #[inline] - pub fn data(mut self, data: VmpPMat) -> Self { - self.data = Some(data); - self - } - #[inline] - pub fn base2k(mut self, base2k: Base2K) -> Self { - self.base2k = Some(base2k); - self - } - #[inline] - pub fn k(mut self, k: TorusPrecision) -> Self { - self.k = Some(k); - self - } - - #[inline] - pub fn dsize(mut self, dsize: Dsize) -> Self { - self.dsize = Some(dsize); - self - } - - pub fn build(self) -> Result, BuildError> { - let data: VmpPMat = self.data.ok_or(BuildError::MissingData)?; - let base2k: Base2K = self.base2k.ok_or(BuildError::MissingBase2K)?; - let k: TorusPrecision = self.k.ok_or(BuildError::MissingK)?; - let dsize: Dsize = self.dsize.ok_or(BuildError::MissingDigits)?; - - if base2k == 0_u32 { - return Err(BuildError::ZeroBase2K); - } - - if dsize == 0_u32 { - return Err(BuildError::ZeroBase2K); - } - - if k == 0_u32 { - return Err(BuildError::ZeroTorusPrecision); - } - - if data.n() == 0 { - return Err(BuildError::ZeroDegree); - } - - if data.cols() == 0 { - return Err(BuildError::ZeroCols); - } - - if data.size() == 0 { - return Err(BuildError::ZeroLimbs); - } - - Ok(GGSWCiphertextPrepared { - data, - base2k, - k, - dsize, - }) - } -} - -impl GGSWCiphertextPrepared, B> { - pub fn alloc(module: &Module, infos: &A) -> Self - where - A: GGSWInfos, - Module: VmpPMatAlloc, - { - Self::alloc_with( - module, - infos.base2k(), - infos.k(), - infos.dnum(), - infos.dsize(), - infos.rank(), - ) - } - - pub fn alloc_with(module: &Module, base2k: Base2K, k: TorusPrecision, dnum: Dnum, dsize: Dsize, rank: Rank) -> Self - where - Module: VmpPMatAlloc, - { +pub trait GGSWPreparedAlloc +where + Self: GetRingDegree + VmpPMatAlloc + VmpPMatBytesOf, +{ + fn alloc_ggsw_prepared( + &self, + base2k: Base2K, + k: TorusPrecision, + dnum: Dnum, + dsize: Dsize, + rank: Rank, + ) -> GGSWPrepared, B> { let size: usize = k.0.div_ceil(base2k.0) as usize; debug_assert!( size as u32 > dsize.0, @@ -202,8 +75,8 @@ impl GGSWCiphertextPrepared, B> { dsize.0, ); - Self { - data: module.vmp_pmat_alloc( + GGSWPrepared { + data: self.vmp_pmat_alloc( dnum.into(), (rank + 1).into(), (rank + 1).into(), @@ -215,13 +88,12 @@ impl GGSWCiphertextPrepared, B> { } } - pub fn alloc_bytes(module: &Module, infos: &A) -> usize + fn alloc_ggsw_prepared_from_infos(&self, infos: &A) -> GGSWPrepared, B> where A: GGSWInfos, - Module: VmpPMatAllocBytes, { - Self::alloc_bytes_with( - module, + assert_eq!(self.ring_degree(), infos.n()); + self.alloc_ggsw_prepared( infos.base2k(), infos.k(), infos.dnum(), @@ -230,10 +102,7 @@ impl GGSWCiphertextPrepared, B> { ) } - pub fn alloc_bytes_with(module: &Module, base2k: Base2K, k: TorusPrecision, dnum: Dnum, dsize: Dsize, rank: Rank) -> usize - where - Module: VmpPMatAllocBytes, - { + fn bytes_of_ggsw_prepared(&self, base2k: Base2K, k: TorusPrecision, dnum: Dnum, dsize: Dsize, rank: Rank) -> usize { let size: usize = k.0.div_ceil(base2k.0) as usize; debug_assert!( size as u32 > dsize.0, @@ -248,65 +117,144 @@ impl GGSWCiphertextPrepared, B> { dsize.0, ); - module.vmp_pmat_alloc_bytes(dnum.into(), (rank + 1).into(), (rank + 1).into(), size) + self.bytes_of_vmp_pmat(dnum.into(), (rank + 1).into(), (rank + 1).into(), size) + } + + fn bytes_of_ggsw_prepared_from_infos(&self, infos: &A) -> usize + where + A: GGSWInfos, + { + assert_eq!(self.ring_degree(), infos.n()); + self.bytes_of_ggsw_prepared( + infos.base2k(), + infos.k(), + infos.dnum(), + infos.dsize(), + infos.rank(), + ) } } -impl GGSWCiphertextPrepared { +impl GGSWPreparedAlloc for Module where Self: GetRingDegree + VmpPMatAlloc + VmpPMatBytesOf {} + +impl GGSWPrepared, B> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GGSWInfos, + M: GGSWPreparedAlloc, + { + module.alloc_ggsw_prepared_from_infos(infos) + } + + pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, dnum: Dnum, dsize: Dsize, rank: Rank) -> Self + where + M: GGSWPreparedAlloc, + { + module.alloc_ggsw_prepared(base2k, k, dnum, dsize, rank) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GGSWInfos, + M: GGSWPreparedAlloc, + { + module.bytes_of_ggsw_prepared_from_infos(infos) + } + + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, dnum: Dnum, dsize: Dsize, rank: Rank) -> usize + where + M: GGSWPreparedAlloc, + { + module.bytes_of_ggsw_prepared(base2k, k, dnum, dsize, rank) + } +} + +impl GGSWPrepared { pub fn data(&self) -> &VmpPMat { &self.data } } -impl PrepareScratchSpace for GGSWCiphertextPrepared, B> +pub trait GGSWPrepare where - Module: VmpPrepareTmpBytes, + Self: GetRingDegree + VmpPrepareTmpBytes + VmpPrepare, { - fn prepare_scratch_space(module: &Module, infos: &A) -> usize { - module.vmp_prepare_tmp_bytes( + fn ggsw_prepare_tmp_bytes(&self, infos: &A) -> usize + where + A: GGSWInfos, + { + assert_eq!(self.ring_degree(), infos.n()); + self.vmp_prepare_tmp_bytes( infos.dnum().into(), (infos.rank() + 1).into(), (infos.rank() + 1).into(), infos.size(), ) } -} - -impl Prepare> for GGSWCiphertextPrepared -where - Module: VmpPrepare, -{ - fn prepare(&mut self, module: &Module, other: &GGSWCiphertext, scratch: &mut Scratch) { - module.vmp_prepare(&mut self.data, &other.data, scratch); - self.k = other.k; - self.base2k = other.base2k; - self.dsize = other.dsize; + fn ggsw_prepare(&self, res: &mut R, other: &O, scratch: &mut Scratch) + where + R: GGSWPreparedToMut, + O: GGSWToRef, + { + let mut res: GGSWPrepared<&mut [u8], B> = res.to_mut(); + let other: GGSW<&[u8]> = other.to_ref(); + assert_eq!(res.n(), self.ring_degree()); + assert_eq!(other.n(), self.ring_degree()); + assert_eq!(res.k, other.k); + assert_eq!(res.base2k, other.base2k); + assert_eq!(res.dsize, other.dsize); + self.vmp_prepare(&mut res.data, &other.data, scratch); } } -impl PrepareAlloc, B>> for GGSWCiphertext -where - Module: VmpPMatAlloc + VmpPrepare, -{ - fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> GGSWCiphertextPrepared, B> { - let mut ggsw_prepared: GGSWCiphertextPrepared, B> = GGSWCiphertextPrepared::alloc(module, self); - ggsw_prepared.prepare(module, self, scratch); - ggsw_prepared +impl GGSWPrepare for Module where Self: GetRingDegree + VmpPrepareTmpBytes + VmpPrepare {} + +impl GGSWPrepared, B> { + pub fn prepare_tmp_bytes(&self, module: &M, infos: &A) -> usize + where + A: GGSWInfos, + M: GGSWPrepare, + { + module.ggsw_prepare_tmp_bytes(infos) } } -pub trait GGSWCiphertextPreparedToRef { - fn to_ref(&self) -> GGSWCiphertextPrepared<&[u8], B>; -} - -impl GGSWCiphertextPreparedToRef for GGSWCiphertextPrepared { - fn to_ref(&self) -> GGSWCiphertextPrepared<&[u8], B> { - GGSWCiphertextPrepared::builder() - .base2k(self.base2k()) - .dsize(self.dsize()) - .k(self.k()) - .data(self.data.to_ref()) - .build() - .unwrap() +impl GGSWPrepared { + pub fn prepare(&mut self, module: &M, other: &O, scratch: &mut Scratch) + where + O: GGSWToRef, + M: GGSWPrepare, + { + module.ggsw_prepare(self, other, scratch); + } +} + +pub trait GGSWPreparedToMut { + fn to_mut(&mut self) -> GGSWPrepared<&mut [u8], B>; +} + +impl GGSWPreparedToMut for GGSWPrepared { + fn to_mut(&mut self) -> GGSWPrepared<&mut [u8], B> { + GGSWPrepared { + base2k: self.base2k, + k: self.k, + dsize: self.dsize, + data: self.data.to_mut(), + } + } +} + +pub trait GGSWPreparedToRef { + fn to_ref(&self) -> GGSWPrepared<&[u8], B>; +} + +impl GGSWPreparedToRef for GGSWPrepared { + fn to_ref(&self) -> GGSWPrepared<&[u8], B> { + GGSWPrepared { + base2k: self.base2k, + k: self.k, + dsize: self.dsize, + data: self.data.to_ref(), + } } } diff --git a/poulpy-core/src/layouts/prepared/glwe_pk.rs b/poulpy-core/src/layouts/prepared/glwe_pk.rs index 6834f58..bca1826 100644 --- a/poulpy-core/src/layouts/prepared/glwe_pk.rs +++ b/poulpy-core/src/layouts/prepared/glwe_pk.rs @@ -1,14 +1,12 @@ use poulpy_hal::{ - api::{VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply}, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VecZnxDft, ZnxInfos}, - oep::VecZnxDftAllocBytesImpl, + api::{VecZnxDftAlloc, VecZnxDftApply, VecZnxDftBytesOf}, + layouts::{Backend, Data, DataMut, DataRef, Module, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxInfos}, }; use crate::{ dist::Distribution, layouts::{ - Base2K, BuildError, Degree, GLWEInfos, GLWEPublicKey, LWEInfos, Rank, TorusPrecision, - prepared::{Prepare, PrepareAlloc, PrepareScratchSpace}, + Base2K, GLWEInfos, GLWEPublicKey, GLWEPublicKeyToRef, GetDist, GetRingDegree, LWEInfos, Rank, RingDegree, TorusPrecision, }, }; @@ -20,6 +18,16 @@ pub struct GLWEPublicKeyPrepared { pub(crate) dist: Distribution, } +pub(crate) trait SetDist { + fn set_dist(&mut self, dist: Distribution); +} + +impl SetDist for GLWEPublicKeyPrepared { + fn set_dist(&mut self, dist: Distribution) { + self.dist = dist + } +} + impl LWEInfos for GLWEPublicKeyPrepared { fn base2k(&self) -> Base2K { self.base2k @@ -33,8 +41,8 @@ impl LWEInfos for GLWEPublicKeyPrepared { self.data.size() } - fn n(&self) -> Degree { - Degree(self.data.n() as u32) + fn n(&self) -> RingDegree { + RingDegree(self.data.n() as u32) } } @@ -44,164 +52,138 @@ impl GLWEInfos for GLWEPublicKeyPrepared { } } -pub struct GLWEPublicKeyPreparedBuilder { - data: Option>, - base2k: Option, - k: Option, -} - -impl GLWEPublicKeyPrepared { - #[inline] - pub fn builder() -> GLWEPublicKeyPreparedBuilder { - GLWEPublicKeyPreparedBuilder { - data: None, - base2k: None, - k: None, - } - } -} - -impl GLWEPublicKeyPreparedBuilder, B> { - #[inline] - pub fn layout(mut self, layout: &A) -> Self - where - A: GLWEInfos, - B: VecZnxDftAllocBytesImpl, - { - self.data = Some(VecZnxDft::alloc( - layout.n().into(), - (layout.rank() + 1).into(), - layout.size(), - )); - self.base2k = Some(layout.base2k()); - self.k = Some(layout.k()); - self - } -} - -impl GLWEPublicKeyPreparedBuilder { - #[inline] - pub fn data(mut self, data: VecZnxDft) -> Self { - self.data = Some(data); - self - } - #[inline] - pub fn base2k(mut self, base2k: Base2K) -> Self { - self.base2k = Some(base2k); - self - } - #[inline] - pub fn k(mut self, k: TorusPrecision) -> Self { - self.k = Some(k); - self - } - - pub fn build(self) -> Result, BuildError> { - let data: VecZnxDft = self.data.ok_or(BuildError::MissingData)?; - let base2k: Base2K = self.base2k.ok_or(BuildError::MissingBase2K)?; - let k: TorusPrecision = self.k.ok_or(BuildError::MissingK)?; - - if base2k == 0_u32 { - return Err(BuildError::ZeroBase2K); - } - - if k == 0_u32 { - return Err(BuildError::ZeroTorusPrecision); - } - - if data.n() == 0 { - return Err(BuildError::ZeroDegree); - } - - if data.cols() == 0 { - return Err(BuildError::ZeroCols); - } - - if data.size() == 0 { - return Err(BuildError::ZeroLimbs); - } - - Ok(GLWEPublicKeyPrepared { - data, +pub trait GLWEPublicKeyPreparedAlloc +where + Self: GetRingDegree + VecZnxDftAlloc + VecZnxDftBytesOf, +{ + fn alloc_glwe_public_key_prepared(&self, base2k: Base2K, k: TorusPrecision, rank: Rank) -> GLWEPublicKeyPrepared, B> { + GLWEPublicKeyPrepared { + data: self.vec_znx_dft_alloc((rank + 1).into(), k.0.div_ceil(base2k.0) as usize), base2k, k, dist: Distribution::NONE, - }) + } + } + + fn alloc_glwe_public_key_prepared_from_infos(&self, infos: &A) -> GLWEPublicKeyPrepared, B> + where + A: GLWEInfos, + { + self.alloc_glwe_public_key_prepared(infos.base2k(), infos.k(), infos.rank()) + } + + fn bytes_of_glwe_public_key_prepared(&self, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize { + self.bytes_of_vec_znx_dft((rank + 1).into(), k.0.div_ceil(base2k.0) as usize) + } + + fn bytes_of_glwe_public_key_prepared_from_infos(&self, infos: &A) -> usize + where + A: GLWEInfos, + { + self.bytes_of_glwe_public_key_prepared(infos.base2k(), infos.k(), infos.rank()) } } +impl GLWEPublicKeyPreparedAlloc for Module where Self: VecZnxDftAlloc + VecZnxDftBytesOf {} + impl GLWEPublicKeyPrepared, B> { - pub fn alloc(module: &Module, infos: &A) -> Self + pub fn alloc_from_infos(module: &M, infos: &A) -> Self where A: GLWEInfos, - Module: VecZnxDftAlloc, + M: GLWEPublicKeyPreparedAlloc, { - debug_assert_eq!(module.n(), infos.n().0 as usize, "module.n() != infos.n()"); - Self::alloc_with(module, infos.base2k(), infos.k(), infos.rank()) + module.alloc_glwe_public_key_prepared_from_infos(infos) } - pub fn alloc_with(module: &Module, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self + pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self where - Module: VecZnxDftAlloc, + M: GLWEPublicKeyPreparedAlloc, { - Self { - data: module.vec_znx_dft_alloc((rank + 1).into(), k.0.div_ceil(base2k.0) as usize), - base2k, - k, - dist: Distribution::NONE, - } + module.alloc_glwe_public_key_prepared(base2k, k, rank) } - pub fn alloc_bytes(module: &Module, infos: &A) -> usize + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize where A: GLWEInfos, - Module: VecZnxDftAllocBytes, + M: GLWEPublicKeyPreparedAlloc, { - debug_assert_eq!(module.n(), infos.n().0 as usize, "module.n() != infos.n()"); - Self::alloc_bytes_with(module, infos.base2k(), infos.k(), infos.rank()) + module.bytes_of_glwe_public_key_prepared_from_infos(infos) } - pub fn alloc_bytes_with(module: &Module, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize where - Module: VecZnxDftAllocBytes, + M: GLWEPublicKeyPreparedAlloc, { - module.vec_znx_dft_alloc_bytes((rank + 1).into(), k.0.div_ceil(base2k.0) as usize) + module.bytes_of_glwe_public_key_prepared(base2k, k, rank) } } -impl PrepareAlloc, B>> for GLWEPublicKey +pub trait GLWEPublicKeyPrepare where - Module: VecZnxDftAlloc + VecZnxDftApply, + Self: GetRingDegree + VecZnxDftApply, { - fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> GLWEPublicKeyPrepared, B> { - let mut pk_prepared: GLWEPublicKeyPrepared, B> = GLWEPublicKeyPrepared::alloc(module, self); - pk_prepared.prepare(module, self, scratch); - pk_prepared - } -} - -impl PrepareScratchSpace for GLWEPublicKeyPrepared, B> { - fn prepare_scratch_space(_module: &Module, _infos: &A) -> usize { - 0 - } -} - -impl Prepare> for GLWEPublicKeyPrepared -where - Module: VecZnxDftApply, -{ - fn prepare(&mut self, module: &Module, other: &GLWEPublicKey, _scratch: &mut Scratch) { - #[cfg(debug_assertions)] + fn prepare_glwe_public_key(&self, res: &mut R, other: &O) + where + R: GLWEPublicKeyPreparedToMut + SetDist, + O: GLWEPublicKeyToRef + GetDist, + { { - assert_eq!(self.n(), other.n()); - assert_eq!(self.size(), other.size()); + let mut res: GLWEPublicKeyPrepared<&mut [u8], B> = res.to_mut(); + let other: GLWEPublicKey<&[u8]> = other.to_ref(); + + assert_eq!(res.n(), self.ring_degree()); + assert_eq!(other.n(), self.ring_degree()); + assert_eq!(res.size(), other.size()); + assert_eq!(res.k(), other.k()); + assert_eq!(res.base2k(), other.base2k()); + + for i in 0..(res.rank() + 1).into() { + self.vec_znx_dft_apply(1, 0, &mut res.data, i, &other.data, i); + } } - (0..(self.rank() + 1).into()).for_each(|i| { - module.vec_znx_dft_apply(1, 0, &mut self.data, i, &other.data, i); - }); - self.k = other.k(); - self.base2k = other.base2k(); - self.dist = other.dist; + res.set_dist(other.get_dist()); + } +} + +impl GLWEPublicKeyPrepare for Module where Self: GetRingDegree + VecZnxDftApply {} + +impl GLWEPublicKeyPrepared { + pub fn prepare(&mut self, module: &M, other: &O) + where + O: GLWEPublicKeyToRef + GetDist, + M: GLWEPublicKeyPrepare, + { + module.prepare_glwe_public_key(self, other); + } +} + +pub trait GLWEPublicKeyPreparedToMut { + fn to_mut(&mut self) -> GLWEPublicKeyPrepared<&mut [u8], B>; +} + +impl GLWEPublicKeyPreparedToMut for GLWEPublicKeyPrepared { + fn to_mut(&mut self) -> GLWEPublicKeyPrepared<&mut [u8], B> { + GLWEPublicKeyPrepared { + dist: self.dist, + k: self.k, + base2k: self.base2k, + data: self.data.to_mut(), + } + } +} + +pub trait GLWEPublicKeyPreparedToRef { + fn to_ref(&self) -> GLWEPublicKeyPrepared<&[u8], B>; +} + +impl GLWEPublicKeyPreparedToRef for GLWEPublicKeyPrepared { + fn to_ref(&self) -> GLWEPublicKeyPrepared<&[u8], B> { + GLWEPublicKeyPrepared { + data: self.data.to_ref(), + dist: self.dist, + k: self.k, + base2k: self.base2k, + } } } diff --git a/poulpy-core/src/layouts/prepared/glwe_sk.rs b/poulpy-core/src/layouts/prepared/glwe_sk.rs index d3f638b..ac7e11f 100644 --- a/poulpy-core/src/layouts/prepared/glwe_sk.rs +++ b/poulpy-core/src/layouts/prepared/glwe_sk.rs @@ -1,13 +1,13 @@ use poulpy_hal::{ - api::{SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare}, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, SvpPPol, ZnxInfos}, + api::{SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare}, + layouts::{Backend, Data, DataMut, DataRef, Module, SvpPPol, SvpPPolToMut, SvpPPolToRef, ZnxInfos}, }; use crate::{ dist::Distribution, layouts::{ - Base2K, Degree, GLWEInfos, GLWESecret, LWEInfos, Rank, TorusPrecision, - prepared::{Prepare, PrepareAlloc, PrepareScratchSpace}, + Base2K, GLWEInfos, GLWESecret, GLWESecretToRef, GetDist, GetRingDegree, LWEInfos, Rank, RingDegree, TorusPrecision, + prepared::SetDist, }, }; @@ -16,6 +16,12 @@ pub struct GLWESecretPrepared { pub(crate) dist: Distribution, } +impl SetDist for GLWESecretPrepared { + fn set_dist(&mut self, dist: Distribution) { + self.dist = dist + } +} + impl LWEInfos for GLWESecretPrepared { fn base2k(&self) -> Base2K { Base2K(0) @@ -25,8 +31,8 @@ impl LWEInfos for GLWESecretPrepared { TorusPrecision(0) } - fn n(&self) -> Degree { - Degree(self.data.n() as u32) + fn n(&self) -> RingDegree { + RingDegree(self.data.n() as u32) } fn size(&self) -> usize { @@ -38,46 +44,74 @@ impl GLWEInfos for GLWESecretPrepared { Rank(self.data.cols() as u32) } } -impl GLWESecretPrepared, B> { - pub fn alloc(module: &Module, infos: &A) -> Self - where - A: GLWEInfos, - Module: SvpPPolAlloc, - { - assert_eq!(module.n() as u32, infos.n()); - Self::alloc_with(module, infos.rank()) - } - pub fn alloc_with(module: &Module, rank: Rank) -> Self - where - Module: SvpPPolAlloc, - { - Self { - data: module.svp_ppol_alloc(rank.into()), +pub trait GLWESecretPreparedAlloc +where + Self: GetRingDegree + SvpPPolBytesOf + SvpPPolAlloc, +{ + fn alloc_glwe_secret_prepared(&self, rank: Rank) -> GLWESecretPrepared, B> { + GLWESecretPrepared { + data: self.svp_ppol_alloc(rank.into()), dist: Distribution::NONE, } } - - pub fn alloc_bytes(module: &Module, infos: &A) -> usize + fn alloc_glwe_secret_prepared_from_infos(&self, infos: &A) -> GLWESecretPrepared, B> where A: GLWEInfos, - Module: SvpPPolAllocBytes, { - assert_eq!(module.n() as u32, infos.n()); - Self::alloc_bytes_with(module, infos.rank()) + assert_eq!(self.ring_degree(), infos.n()); + self.alloc_glwe_secret_prepared(infos.rank()) } - pub fn alloc_bytes_with(module: &Module, rank: Rank) -> usize + fn bytes_of_glwe_secret(&self, rank: Rank) -> usize { + self.bytes_of_svp_ppol(rank.into()) + } + fn bytes_of_glwe_secret_from_infos(&self, infos: &A) -> usize where - Module: SvpPPolAllocBytes, + A: GLWEInfos, { - module.svp_ppol_alloc_bytes(rank.into()) + assert_eq!(self.ring_degree(), infos.n()); + self.bytes_of_glwe_secret(infos.rank()) + } +} + +impl GLWESecretPreparedAlloc for Module where Self: GetRingDegree + SvpPPolBytesOf + SvpPPolAlloc {} + +impl GLWESecretPrepared, B> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GLWEInfos, + M: GLWESecretPreparedAlloc, + { + module.alloc_glwe_secret_prepared_from_infos(infos) + } + + pub fn alloc(module: &M, rank: Rank) -> Self + where + M: GLWESecretPreparedAlloc, + { + module.alloc_glwe_secret_prepared(rank) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GLWEInfos, + M: GLWESecretPreparedAlloc, + { + module.bytes_of_glwe_secret_from_infos(infos) + } + + pub fn bytes_of(module: &M, rank: Rank) -> usize + where + M: GLWESecretPreparedAlloc, + { + module.bytes_of_glwe_secret(rank) } } impl GLWESecretPrepared { - pub fn n(&self) -> Degree { - Degree(self.data.n() as u32) + pub fn n(&self) -> RingDegree { + RingDegree(self.data.n() as u32) } pub fn rank(&self) -> Rank { @@ -85,31 +119,62 @@ impl GLWESecretPrepared { } } -impl PrepareScratchSpace for GLWESecretPrepared, B> { - fn prepare_scratch_space(_module: &Module, _infos: &A) -> usize { - 0 +pub trait GLWESecretPrepare +where + Self: SvpPrepare, +{ + fn prepare_glwe_secret(&self, res: &mut R, other: &O) + where + R: GLWESecretPreparedToMut + SetDist, + O: GLWESecretToRef + GetDist, + { + { + let mut res: GLWESecretPrepared<&mut [u8], _> = res.to_mut(); + let other: GLWESecret<&[u8]> = other.to_ref(); + + for i in 0..res.rank().into() { + self.svp_prepare(&mut res.data, i, &other.data, i); + } + } + + res.set_dist(other.get_dist()); } } -impl PrepareAlloc, B>> for GLWESecret -where - Module: SvpPrepare + SvpPPolAlloc, -{ - fn prepare_alloc(&self, module: &Module, _scratch: &mut Scratch) -> GLWESecretPrepared, B> { - let mut sk_dft: GLWESecretPrepared, B> = GLWESecretPrepared::alloc(module, self); - sk_dft.prepare(module, self, _scratch); - sk_dft +impl GLWESecretPrepare for Module where Self: SvpPrepare {} + +impl GLWESecretPrepared { + pub fn prepare(&mut self, module: &M, other: &O) + where + M: GLWESecretPrepare, + O: GLWESecretToRef + GetDist, + { + module.prepare_glwe_secret(self, other); } } -impl Prepare> for GLWESecretPrepared -where - Module: SvpPrepare, -{ - fn prepare(&mut self, module: &Module, other: &GLWESecret, _scratch: &mut Scratch) { - (0..self.rank().into()).for_each(|i| { - module.svp_prepare(&mut self.data, i, &other.data, i); - }); - self.dist = other.dist +pub trait GLWESecretPreparedToRef { + fn to_ref(&self) -> GLWESecretPrepared<&[u8], B>; +} + +impl GLWESecretPreparedToRef for GLWESecretPrepared { + fn to_ref(&self) -> GLWESecretPrepared<&[u8], B> { + GLWESecretPrepared { + data: self.data.to_ref(), + dist: self.dist, + } + } +} + +pub trait GLWESecretPreparedToMut { + fn to_mut(&mut self) -> GLWESecretPrepared<&mut [u8], B>; +} + +impl GLWESecretPreparedToMut for GLWESecretPrepared { + fn to_mut(&mut self) -> GLWESecretPrepared<&mut [u8], B> { + GLWESecretPrepared { + dist: self.dist, + data: self.data.to_mut(), + } } } diff --git a/poulpy-core/src/layouts/prepared/glwe_to_lwe_ksk.rs b/poulpy-core/src/layouts/prepared/glwe_to_lwe_ksk.rs index f241c6d..55a7bf9 100644 --- a/poulpy-core/src/layouts/prepared/glwe_to_lwe_ksk.rs +++ b/poulpy-core/src/layouts/prepared/glwe_to_lwe_ksk.rs @@ -1,15 +1,15 @@ -use poulpy_hal::{ - api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare}, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}, -}; +use poulpy_hal::layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}; use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEInfos, GLWEInfos, GLWEToLWEKey, LWEInfos, Rank, TorusPrecision, - prepared::{GGLWESwitchingKeyPrepared, Prepare, PrepareAlloc, PrepareScratchSpace}, + Base2K, Dnum, Dsize, GGLWEInfos, GLWEInfos, GLWEToLWESwitchingKeyToRef, LWEInfos, Rank, RingDegree, TorusPrecision, + prepared::{ + GLWESwitchingKeyPrepare, GLWESwitchingKeyPrepared, GLWESwitchingKeyPreparedAlloc, GLWESwitchingKeyPreparedToMut, + GLWESwitchingKeyPreparedToRef, + }, }; #[derive(PartialEq, Eq)] -pub struct GLWEToLWESwitchingKeyPrepared(pub(crate) GGLWESwitchingKeyPrepared); +pub struct GLWEToLWESwitchingKeyPrepared(pub(crate) GLWESwitchingKeyPrepared); impl LWEInfos for GLWEToLWESwitchingKeyPrepared { fn base2k(&self) -> Base2K { @@ -20,7 +20,7 @@ impl LWEInfos for GLWEToLWESwitchingKeyPrepared { self.0.k() } - fn n(&self) -> Degree { + fn n(&self) -> RingDegree { self.0.n() } @@ -53,91 +53,156 @@ impl GGLWEInfos for GLWEToLWESwitchingKeyPrepared { } } +pub trait GLWEToLWESwitchingKeyPreparedAlloc +where + Self: GLWESwitchingKeyPreparedAlloc, +{ + fn alloc_glwe_to_lwe_switching_key_prepared( + &self, + base2k: Base2K, + k: TorusPrecision, + rank_in: Rank, + dnum: Dnum, + ) -> GLWEToLWESwitchingKeyPrepared, B> { + GLWEToLWESwitchingKeyPrepared(self.alloc_glwe_switching_key_prepared(base2k, k, rank_in, Rank(1), dnum, Dsize(1))) + } + fn alloc_glwe_to_lwe_switching_key_prepared_from_infos(&self, infos: &A) -> GLWEToLWESwitchingKeyPrepared, B> + where + A: GGLWEInfos, + { + debug_assert_eq!( + infos.rank_out().0, + 1, + "rank_out > 1 is not supported for GLWEToLWESwitchingKeyPrepared" + ); + debug_assert_eq!( + infos.dsize().0, + 1, + "dsize > 1 is not supported for GLWEToLWESwitchingKeyPrepared" + ); + self.alloc_glwe_to_lwe_switching_key_prepared(infos.base2k(), infos.k(), infos.rank_in(), infos.dnum()) + } + + fn bytes_of_glwe_to_lwe_switching_key_prepared(&self, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> usize { + self.bytes_of_glwe_switching_key_prepared(base2k, k, rank_in, Rank(1), dnum, Dsize(1)) + } + + fn bytes_of_glwe_to_lwe_switching_key_prepared_from_infos(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + debug_assert_eq!( + infos.rank_out().0, + 1, + "rank_out > 1 is not supported for GLWEToLWESwitchingKeyPrepared" + ); + debug_assert_eq!( + infos.dsize().0, + 1, + "dsize > 1 is not supported for GLWEToLWESwitchingKeyPrepared" + ); + self.bytes_of_glwe_to_lwe_switching_key_prepared(infos.base2k(), infos.k(), infos.rank_in(), infos.dnum()) + } +} + +impl GLWEToLWESwitchingKeyPreparedAlloc for Module where Self: GLWESwitchingKeyPreparedAlloc {} + impl GLWEToLWESwitchingKeyPrepared, B> { - pub fn alloc(module: &Module, infos: &A) -> Self + pub fn alloc_from_infos(module: &M, infos: &A) -> Self where A: GGLWEInfos, - Module: VmpPMatAlloc, + M: GLWEToLWESwitchingKeyPreparedAlloc, { - debug_assert_eq!( - infos.rank_out().0, - 1, - "rank_out > 1 is not supported for GLWEToLWESwitchingKeyPrepared" - ); - debug_assert_eq!( - infos.dsize().0, - 1, - "dsize > 1 is not supported for GLWEToLWESwitchingKeyPrepared" - ); - Self(GGLWESwitchingKeyPrepared::alloc(module, infos)) + module.alloc_glwe_to_lwe_switching_key_prepared_from_infos(infos) } - pub fn alloc_with(module: &Module, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> Self + pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> Self where - Module: VmpPMatAlloc, + M: GLWEToLWESwitchingKeyPreparedAlloc, { - Self(GGLWESwitchingKeyPrepared::alloc_with( - module, - base2k, - k, - rank_in, - Rank(1), - dnum, - Dsize(1), - )) + module.alloc_glwe_to_lwe_switching_key_prepared(base2k, k, rank_in, dnum) } - pub fn alloc_bytes(module: &Module, infos: &A) -> usize + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize where A: GGLWEInfos, - Module: VmpPMatAllocBytes, + M: GLWEToLWESwitchingKeyPreparedAlloc, { - debug_assert_eq!( - infos.rank_out().0, - 1, - "rank_out > 1 is not supported for GLWEToLWESwitchingKeyPrepared" - ); - debug_assert_eq!( - infos.dsize().0, - 1, - "dsize > 1 is not supported for GLWEToLWESwitchingKeyPrepared" - ); - GGLWESwitchingKeyPrepared::alloc_bytes(module, infos) + module.bytes_of_glwe_to_lwe_switching_key_prepared_from_infos(infos) } - pub fn alloc_bytes_with(module: &Module, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> usize + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> usize where - Module: VmpPMatAllocBytes, + M: GLWEToLWESwitchingKeyPreparedAlloc, { - GGLWESwitchingKeyPrepared::alloc_bytes_with(module, base2k, k, rank_in, Rank(1), dnum, Dsize(1)) + module.bytes_of_glwe_to_lwe_switching_key_prepared(base2k, k, rank_in, dnum) } } -impl PrepareScratchSpace for GLWEToLWESwitchingKeyPrepared, B> +pub trait GLWEToLWESwitchingKeyPrepare where - GGLWESwitchingKeyPrepared, B>: PrepareScratchSpace, + Self: GLWESwitchingKeyPrepare, { - fn prepare_scratch_space(module: &Module, infos: &A) -> usize { - GGLWESwitchingKeyPrepared::prepare_scratch_space(module, infos) + fn prepare_glwe_to_lwe_switching_key_tmp_bytes(&self, infos: &A) + where + A: GGLWEInfos, + { + self.prepare_glwe_switching_key_tmp_bytes(infos); + } + + fn prepare_glwe_to_lwe_switching_key(&self, res: &mut R, other: &O, scratch: &mut Scratch) + where + R: GLWEToLWESwitchingKeyPreparedToMut, + O: GLWEToLWESwitchingKeyToRef, + { + self.prepare_glwe_switching(&mut res.to_mut().0, &other.to_ref().0, scratch); } } -impl PrepareAlloc, B>> for GLWEToLWEKey -where - Module: VmpPrepare + VmpPMatAlloc, -{ - fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> GLWEToLWESwitchingKeyPrepared, B> { - let mut ksk_prepared: GLWEToLWESwitchingKeyPrepared, B> = GLWEToLWESwitchingKeyPrepared::alloc(module, self); - ksk_prepared.prepare(module, self, scratch); - ksk_prepared +impl GLWEToLWESwitchingKeyPrepare for Module where Self: GLWESwitchingKeyPrepare {} + +impl GLWEToLWESwitchingKeyPrepared, B> { + pub fn prepare_tmp_bytes(&self, module: &M, infos: &A) + where + A: GGLWEInfos, + M: GLWEToLWESwitchingKeyPrepare, + { + module.prepare_glwe_to_lwe_switching_key_tmp_bytes(infos); } } -impl Prepare> for GLWEToLWESwitchingKeyPrepared -where - Module: VmpPrepare, -{ - fn prepare(&mut self, module: &Module, other: &GLWEToLWEKey, scratch: &mut Scratch) { - self.0.prepare(module, &other.0, scratch); +impl GLWEToLWESwitchingKeyPrepared { + fn prepare(&mut self, module: &M, other: &O, scratch: &mut Scratch) + where + O: GLWEToLWESwitchingKeyToRef, + M: GLWEToLWESwitchingKeyPrepare, + { + module.prepare_glwe_to_lwe_switching_key(self, other, scratch); + } +} + +pub trait GLWEToLWESwitchingKeyPreparedToRef { + fn to_ref(&self) -> GLWEToLWESwitchingKeyPrepared<&[u8], B>; +} + +impl GLWEToLWESwitchingKeyPreparedToRef for GLWEToLWESwitchingKeyPrepared +where + GLWESwitchingKeyPrepared: GLWESwitchingKeyPreparedToRef, +{ + fn to_ref(&self) -> GLWEToLWESwitchingKeyPrepared<&[u8], B> { + GLWEToLWESwitchingKeyPrepared(self.0.to_ref()) + } +} + +pub trait GLWEToLWESwitchingKeyPreparedToMut { + fn to_mut(&mut self) -> GLWEToLWESwitchingKeyPrepared<&mut [u8], B>; +} + +impl GLWEToLWESwitchingKeyPreparedToMut for GLWEToLWESwitchingKeyPrepared +where + GLWESwitchingKeyPrepared: GLWESwitchingKeyPreparedToMut, +{ + fn to_mut(&mut self) -> GLWEToLWESwitchingKeyPrepared<&mut [u8], B> { + GLWEToLWESwitchingKeyPrepared(self.0.to_mut()) } } diff --git a/poulpy-core/src/layouts/prepared/lwe_ksk.rs b/poulpy-core/src/layouts/prepared/lwe_ksk.rs index 5f0cf14..a857bf9 100644 --- a/poulpy-core/src/layouts/prepared/lwe_ksk.rs +++ b/poulpy-core/src/layouts/prepared/lwe_ksk.rs @@ -1,15 +1,15 @@ -use poulpy_hal::{ - api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare}, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}, -}; +use poulpy_hal::layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}; use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEInfos, GLWEInfos, LWEInfos, LWESwitchingKey, Rank, TorusPrecision, - prepared::{GGLWESwitchingKeyPrepared, Prepare, PrepareAlloc, PrepareScratchSpace}, + Base2K, Dnum, Dsize, GGLWEInfos, GLWEInfos, LWEInfos, LWESwitchingKeyToRef, Rank, RingDegree, TorusPrecision, + prepared::{ + GLWESwitchingKeyPrepare, GLWESwitchingKeyPrepared, GLWESwitchingKeyPreparedAlloc, GLWESwitchingKeyPreparedToMut, + GLWESwitchingKeyPreparedToRef, + }, }; #[derive(PartialEq, Eq)] -pub struct LWESwitchingKeyPrepared(pub(crate) GGLWESwitchingKeyPrepared); +pub struct LWESwitchingKeyPrepared(pub(crate) GLWESwitchingKeyPrepared); impl LWEInfos for LWESwitchingKeyPrepared { fn base2k(&self) -> Base2K { @@ -20,7 +20,7 @@ impl LWEInfos for LWESwitchingKeyPrepared { self.0.k() } - fn n(&self) -> Degree { + fn n(&self) -> RingDegree { self.0.n() } @@ -52,101 +52,165 @@ impl GGLWEInfos for LWESwitchingKeyPrepared { } } +pub trait LWESwitchingKeyPreparedAlloc +where + Self: GLWESwitchingKeyPreparedAlloc, +{ + fn alloc_lwe_switching_key_prepared( + &self, + base2k: Base2K, + k: TorusPrecision, + dnum: Dnum, + ) -> LWESwitchingKeyPrepared, B> { + LWESwitchingKeyPrepared(self.alloc_glwe_switching_key_prepared(base2k, k, Rank(1), Rank(1), dnum, Dsize(1))) + } + + fn alloc_lwe_switching_key_prepared_from_infos(&self, infos: &A) -> LWESwitchingKeyPrepared, B> + where + A: GGLWEInfos, + { + debug_assert_eq!( + infos.dsize().0, + 1, + "dsize > 1 is not supported for LWESwitchingKey" + ); + debug_assert_eq!( + infos.rank_in().0, + 1, + "rank_in > 1 is not supported for LWESwitchingKey" + ); + debug_assert_eq!( + infos.rank_out().0, + 1, + "rank_out > 1 is not supported for LWESwitchingKey" + ); + self.alloc_lwe_switching_key_prepared(infos.base2k(), infos.k(), infos.dnum()) + } + + fn bytes_of_lwe_switching_key_prepared(&self, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> usize { + self.bytes_of_glwe_switching_key_prepared(base2k, k, Rank(1), Rank(1), dnum, Dsize(1)) + } + + fn bytes_of_lwe_switching_key_prepared_from_infos(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + debug_assert_eq!( + infos.dsize().0, + 1, + "dsize > 1 is not supported for LWESwitchingKey" + ); + debug_assert_eq!( + infos.rank_in().0, + 1, + "rank_in > 1 is not supported for LWESwitchingKey" + ); + debug_assert_eq!( + infos.rank_out().0, + 1, + "rank_out > 1 is not supported for LWESwitchingKey" + ); + self.bytes_of_lwe_switching_key_prepared(infos.base2k(), infos.k(), infos.dnum()) + } +} + +impl LWESwitchingKeyPreparedAlloc for Module where Self: GLWESwitchingKeyPreparedAlloc {} + impl LWESwitchingKeyPrepared, B> { - pub fn alloc(module: &Module, infos: &A) -> Self + pub fn alloc_from_infos(module: &M, infos: &A) -> Self where A: GGLWEInfos, - Module: VmpPMatAlloc, + M: LWESwitchingKeyPreparedAlloc, { - debug_assert_eq!( - infos.dsize().0, - 1, - "dsize > 1 is not supported for LWESwitchingKey" - ); - debug_assert_eq!( - infos.rank_in().0, - 1, - "rank_in > 1 is not supported for LWESwitchingKey" - ); - debug_assert_eq!( - infos.rank_out().0, - 1, - "rank_out > 1 is not supported for LWESwitchingKey" - ); - Self(GGLWESwitchingKeyPrepared::alloc(module, infos)) + module.alloc_lwe_switching_key_prepared_from_infos(infos) } - pub fn alloc_with(module: &Module, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> Self + pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> Self where - Module: VmpPMatAlloc, + M: LWESwitchingKeyPreparedAlloc, { - Self(GGLWESwitchingKeyPrepared::alloc_with( - module, - base2k, - k, - Rank(1), - Rank(1), - dnum, - Dsize(1), - )) + module.alloc_lwe_switching_key_prepared(base2k, k, dnum) } - pub fn alloc_bytes(module: &Module, infos: &A) -> usize + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize where A: GGLWEInfos, - Module: VmpPMatAllocBytes, + M: LWESwitchingKeyPreparedAlloc, { - debug_assert_eq!( - infos.dsize().0, - 1, - "dsize > 1 is not supported for LWESwitchingKey" - ); - debug_assert_eq!( - infos.rank_in().0, - 1, - "rank_in > 1 is not supported for LWESwitchingKey" - ); - debug_assert_eq!( - infos.rank_out().0, - 1, - "rank_out > 1 is not supported for LWESwitchingKey" - ); - GGLWESwitchingKeyPrepared::alloc_bytes(module, infos) + module.bytes_of_lwe_switching_key_prepared_from_infos(infos) } - pub fn alloc_bytes_with(module: &Module, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> usize + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> usize where - Module: VmpPMatAllocBytes, + M: LWESwitchingKeyPreparedAlloc, { - GGLWESwitchingKeyPrepared::alloc_bytes_with(module, base2k, k, Rank(1), Rank(1), dnum, Dsize(1)) + module.bytes_of_lwe_switching_key_prepared(base2k, k, dnum) } } -impl PrepareScratchSpace for LWESwitchingKeyPrepared, B> +pub trait LWESwitchingKeyPrepare where - GGLWESwitchingKeyPrepared, B>: PrepareScratchSpace, + Self: GLWESwitchingKeyPrepare, { - fn prepare_scratch_space(module: &Module, infos: &A) -> usize { - GGLWESwitchingKeyPrepared::prepare_scratch_space(module, infos) + fn prepare_lwe_switching_key_tmp_bytes(&self, infos: &A) + where + A: GGLWEInfos, + { + self.prepare_glwe_switching_key_tmp_bytes(infos); + } + fn prepare_lwe_switching_key(&self, res: &mut R, other: &O, scratch: &mut Scratch) + where + R: LWESwitchingKeyPreparedToMut, + O: LWESwitchingKeyToRef, + { + self.prepare_glwe_switching(&mut res.to_mut().0, &other.to_ref().0, scratch); } } -impl PrepareAlloc, B>> for LWESwitchingKey -where - Module: VmpPrepare + VmpPMatAlloc, -{ - fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> LWESwitchingKeyPrepared, B> { - let mut ksk_prepared: LWESwitchingKeyPrepared, B> = LWESwitchingKeyPrepared::alloc(module, self); - ksk_prepared.prepare(module, self, scratch); - ksk_prepared +impl LWESwitchingKeyPrepare for Module where Self: GLWESwitchingKeyPrepare {} + +impl LWESwitchingKeyPrepared, B> { + pub fn prepare_tmp_bytes(&self, module: &M, infos: &A) + where + A: GGLWEInfos, + M: LWESwitchingKeyPrepare, + { + module.prepare_lwe_switching_key_tmp_bytes(infos); } } -impl Prepare> for LWESwitchingKeyPrepared -where - Module: VmpPrepare, -{ - fn prepare(&mut self, module: &Module, other: &LWESwitchingKey, scratch: &mut Scratch) { - self.0.prepare(module, &other.0, scratch); +impl LWESwitchingKeyPrepared { + fn prepare(&mut self, module: &M, other: &O, scratch: &mut Scratch) + where + O: LWESwitchingKeyToRef, + M: LWESwitchingKeyPrepare, + { + module.prepare_lwe_switching_key(self, other, scratch); + } +} + +pub trait LWESwitchingKeyPreparedToRef { + fn to_ref(&self) -> LWESwitchingKeyPrepared<&[u8], B>; +} + +impl LWESwitchingKeyPreparedToRef for LWESwitchingKeyPrepared +where + GLWESwitchingKeyPrepared: GLWESwitchingKeyPreparedToRef, +{ + fn to_ref(&self) -> LWESwitchingKeyPrepared<&[u8], B> { + LWESwitchingKeyPrepared(self.0.to_ref()) + } +} + +pub trait LWESwitchingKeyPreparedToMut { + fn to_mut(&mut self) -> LWESwitchingKeyPrepared<&mut [u8], B>; +} + +impl LWESwitchingKeyPreparedToMut for LWESwitchingKeyPrepared +where + GLWESwitchingKeyPrepared: GLWESwitchingKeyPreparedToMut, +{ + fn to_mut(&mut self) -> LWESwitchingKeyPrepared<&mut [u8], B> { + LWESwitchingKeyPrepared(self.0.to_mut()) } } diff --git a/poulpy-core/src/layouts/prepared/lwe_to_glwe_ksk.rs b/poulpy-core/src/layouts/prepared/lwe_to_glwe_ksk.rs index 7c2023a..5df692e 100644 --- a/poulpy-core/src/layouts/prepared/lwe_to_glwe_ksk.rs +++ b/poulpy-core/src/layouts/prepared/lwe_to_glwe_ksk.rs @@ -1,16 +1,16 @@ -use poulpy_hal::{ - api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare}, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}, -}; +use poulpy_hal::layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}; use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEInfos, GLWEInfos, LWEInfos, LWEToGLWESwitchingKey, Rank, TorusPrecision, - prepared::{GGLWESwitchingKeyPrepared, Prepare, PrepareAlloc, PrepareScratchSpace}, + Base2K, Dnum, Dsize, GGLWEInfos, GLWEInfos, LWEInfos, LWEToGLWESwitchingKeyToRef, Rank, RingDegree, TorusPrecision, + prepared::{ + GLWESwitchingKeyPrepare, GLWESwitchingKeyPrepared, GLWESwitchingKeyPreparedAlloc, GLWESwitchingKeyPreparedToMut, + GLWESwitchingKeyPreparedToRef, + }, }; -/// A special [GLWESwitchingKey] required to for the conversion from [LWECiphertext] to [GLWECiphertext]. +/// A special [GLWESwitchingKey] required to for the conversion from [LWE] to [GLWE]. #[derive(PartialEq, Eq)] -pub struct LWEToGLWESwitchingKeyPrepared(pub(crate) GGLWESwitchingKeyPrepared); +pub struct LWEToGLWESwitchingKeyPrepared(pub(crate) GLWESwitchingKeyPrepared); impl LWEInfos for LWEToGLWESwitchingKeyPrepared { fn base2k(&self) -> Base2K { @@ -21,7 +21,7 @@ impl LWEInfos for LWEToGLWESwitchingKeyPrepared { self.0.k() } - fn n(&self) -> Degree { + fn n(&self) -> RingDegree { self.0.n() } @@ -54,91 +54,162 @@ impl GGLWEInfos for LWEToGLWESwitchingKeyPrepared { } } +pub trait LWEToGLWESwitchingKeyPreparedAlloc +where + Self: GLWESwitchingKeyPreparedAlloc, +{ + fn alloc_lwe_to_glwe_switching_key_prepared( + &self, + base2k: Base2K, + k: TorusPrecision, + rank_out: Rank, + dnum: Dnum, + ) -> LWEToGLWESwitchingKeyPrepared, B> { + LWEToGLWESwitchingKeyPrepared(self.alloc_glwe_switching_key_prepared(base2k, k, Rank(1), rank_out, dnum, Dsize(1))) + } + fn alloc_lwe_to_glwe_switching_key_prepared_from_infos(&self, infos: &A) -> LWEToGLWESwitchingKeyPrepared, B> + where + A: GGLWEInfos, + { + debug_assert_eq!( + infos.rank_in().0, + 1, + "rank_in > 1 is not supported for LWEToGLWESwitchingKey" + ); + debug_assert_eq!( + infos.dsize().0, + 1, + "dsize > 1 is not supported for LWEToGLWESwitchingKey" + ); + self.alloc_lwe_to_glwe_switching_key_prepared(infos.base2k(), infos.k(), infos.rank_out(), infos.dnum()) + } + + fn bytes_of_lwe_to_glwe_switching_key_prepared( + &self, + base2k: Base2K, + k: TorusPrecision, + rank_out: Rank, + dnum: Dnum, + ) -> usize { + self.bytes_of_glwe_switching_key_prepared(base2k, k, Rank(1), rank_out, dnum, Dsize(1)) + } + + fn bytes_of_lwe_to_glwe_switching_key_prepared_from_infos(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + debug_assert_eq!( + infos.rank_in().0, + 1, + "rank_in > 1 is not supported for LWEToGLWESwitchingKey" + ); + debug_assert_eq!( + infos.dsize().0, + 1, + "dsize > 1 is not supported for LWEToGLWESwitchingKey" + ); + self.bytes_of_lwe_to_glwe_switching_key_prepared(infos.base2k(), infos.k(), infos.rank_out(), infos.dnum()) + } +} + +impl LWEToGLWESwitchingKeyPreparedAlloc for Module where Self: GLWESwitchingKeyPreparedAlloc {} + impl LWEToGLWESwitchingKeyPrepared, B> { - pub fn alloc(module: &Module, infos: &A) -> Self + pub fn alloc_from_infos(module: &M, infos: &A) -> Self where A: GGLWEInfos, - Module: VmpPMatAlloc, + M: LWEToGLWESwitchingKeyPreparedAlloc, { - debug_assert_eq!( - infos.rank_in().0, - 1, - "rank_in > 1 is not supported for LWEToGLWESwitchingKey" - ); - debug_assert_eq!( - infos.dsize().0, - 1, - "dsize > 1 is not supported for LWEToGLWESwitchingKey" - ); - Self(GGLWESwitchingKeyPrepared::alloc(module, infos)) + module.alloc_lwe_to_glwe_switching_key_prepared_from_infos(infos) } - pub fn alloc_with(module: &Module, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self + pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self where - Module: VmpPMatAlloc, + M: LWEToGLWESwitchingKeyPreparedAlloc, { - Self(GGLWESwitchingKeyPrepared::alloc_with( - module, - base2k, - k, - Rank(1), - rank_out, - dnum, - Dsize(1), - )) + module.alloc_lwe_to_glwe_switching_key_prepared(base2k, k, rank_out, dnum) } - pub fn alloc_bytes(module: &Module, infos: &A) -> usize + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize where A: GGLWEInfos, - Module: VmpPMatAllocBytes, + M: LWEToGLWESwitchingKeyPreparedAlloc, { - debug_assert_eq!( - infos.rank_in().0, - 1, - "rank_in > 1 is not supported for LWEToGLWESwitchingKey" - ); - debug_assert_eq!( - infos.dsize().0, - 1, - "dsize > 1 is not supported for LWEToGLWESwitchingKey" - ); - GGLWESwitchingKeyPrepared::alloc_bytes(module, infos) + module.bytes_of_lwe_to_glwe_switching_key_prepared_from_infos(infos) } - pub fn alloc_bytes_with(module: &Module, base2k: Base2K, k: TorusPrecision, dnum: Dnum, rank_out: Rank) -> usize + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> usize where - Module: VmpPMatAllocBytes, + M: LWEToGLWESwitchingKeyPreparedAlloc, { - GGLWESwitchingKeyPrepared::alloc_bytes_with(module, base2k, k, Rank(1), rank_out, dnum, Dsize(1)) + module.bytes_of_lwe_to_glwe_switching_key_prepared(base2k, k, rank_out, dnum) } } -impl PrepareScratchSpace for LWEToGLWESwitchingKeyPrepared, B> +pub trait LWEToGLWESwitchingKeyPrepare where - GGLWESwitchingKeyPrepared, B>: PrepareScratchSpace, + Self: GLWESwitchingKeyPrepare, { - fn prepare_scratch_space(module: &Module, infos: &A) -> usize { - GGLWESwitchingKeyPrepared::prepare_scratch_space(module, infos) + fn prepare_lwe_to_glwe_switching_key_tmp_bytes(&self, infos: &A) + where + A: GGLWEInfos, + { + self.prepare_glwe_switching_key_tmp_bytes(infos); + } + + fn prepare_lwe_to_glwe_switching_key(&self, res: &mut R, other: &O, scratch: &mut Scratch) + where + R: LWEToGLWESwitchingKeyPreparedToMut, + O: LWEToGLWESwitchingKeyToRef, + { + self.prepare_glwe_switching(&mut res.to_mut().0, &other.to_ref().0, scratch); } } -impl PrepareAlloc, B>> for LWEToGLWESwitchingKey -where - Module: VmpPrepare + VmpPMatAlloc, -{ - fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> LWEToGLWESwitchingKeyPrepared, B> { - let mut ksk_prepared: LWEToGLWESwitchingKeyPrepared, B> = LWEToGLWESwitchingKeyPrepared::alloc(module, self); - ksk_prepared.prepare(module, self, scratch); - ksk_prepared +impl LWEToGLWESwitchingKeyPrepare for Module where Self: GLWESwitchingKeyPrepare {} + +impl LWEToGLWESwitchingKeyPrepared, B> { + pub fn prepare_tmp_bytes(&self, module: &M, infos: &A) + where + A: GGLWEInfos, + M: LWEToGLWESwitchingKeyPrepare, + { + module.prepare_lwe_to_glwe_switching_key_tmp_bytes(infos); } } -impl Prepare> for LWEToGLWESwitchingKeyPrepared -where - Module: VmpPrepare, -{ - fn prepare(&mut self, module: &Module, other: &LWEToGLWESwitchingKey, scratch: &mut Scratch) { - self.0.prepare(module, &other.0, scratch); +impl LWEToGLWESwitchingKeyPrepared { + fn prepare(&mut self, module: &M, other: &O, scratch: &mut Scratch) + where + O: LWEToGLWESwitchingKeyToRef, + M: LWEToGLWESwitchingKeyPrepare, + { + module.prepare_lwe_to_glwe_switching_key(self, other, scratch); + } +} + +pub trait LWEToGLWESwitchingKeyPreparedToRef { + fn to_ref(&self) -> LWEToGLWESwitchingKeyPrepared<&[u8], B>; +} + +impl LWEToGLWESwitchingKeyPreparedToRef for LWEToGLWESwitchingKeyPrepared +where + GLWESwitchingKeyPrepared: GLWESwitchingKeyPreparedToRef, +{ + fn to_ref(&self) -> LWEToGLWESwitchingKeyPrepared<&[u8], B> { + LWEToGLWESwitchingKeyPrepared(self.0.to_ref()) + } +} + +pub trait LWEToGLWESwitchingKeyPreparedToMut { + fn to_mut(&mut self) -> LWEToGLWESwitchingKeyPrepared<&mut [u8], B>; +} + +impl LWEToGLWESwitchingKeyPreparedToMut for LWEToGLWESwitchingKeyPrepared +where + GLWESwitchingKeyPrepared: GLWESwitchingKeyPreparedToMut, +{ + fn to_mut(&mut self) -> LWEToGLWESwitchingKeyPrepared<&mut [u8], B> { + LWEToGLWESwitchingKeyPrepared(self.0.to_mut()) } } diff --git a/poulpy-core/src/layouts/prepared/mod.rs b/poulpy-core/src/layouts/prepared/mod.rs index eb47848..296144a 100644 --- a/poulpy-core/src/layouts/prepared/mod.rs +++ b/poulpy-core/src/layouts/prepared/mod.rs @@ -19,16 +19,3 @@ pub use glwe_sk::*; pub use glwe_to_lwe_ksk::*; pub use lwe_ksk::*; pub use lwe_to_glwe_ksk::*; -use poulpy_hal::layouts::{Backend, Module, Scratch}; - -pub trait PrepareScratchSpace { - fn prepare_scratch_space(module: &Module, infos: &T) -> usize; -} - -pub trait PrepareAlloc { - fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> T; -} - -pub trait Prepare { - fn prepare(&mut self, module: &Module, other: &T, scratch: &mut Scratch); -} diff --git a/poulpy-core/src/lib.rs b/poulpy-core/src/lib.rs index 70035af..15e6c76 100644 --- a/poulpy-core/src/lib.rs +++ b/poulpy-core/src/lib.rs @@ -14,6 +14,7 @@ mod utils; pub use operations::*; pub mod layouts; +pub use conversion::*; pub use dist::*; pub use external_product::*; pub use glwe_packing::*; @@ -22,4 +23,4 @@ pub use encryption::SIGMA; pub use scratch::*; -pub mod tests; +// pub mod tests; diff --git a/poulpy-core/src/noise/gglwe_ct.rs b/poulpy-core/src/noise/gglwe_ct.rs index 0712b7f..516a210 100644 --- a/poulpy-core/src/noise/gglwe_ct.rs +++ b/poulpy-core/src/noise/gglwe_ct.rs @@ -1,16 +1,16 @@ use poulpy_hal::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, - VecZnxNormalizeTmpBytes, VecZnxSubScalarInplace, + VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalizeTmpBytes, + VecZnxSubScalarInplace, }, layouts::{Backend, DataRef, Module, ScalarZnx, ScratchOwned, ZnxZero}, - oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl}, + oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl}, }; -use crate::layouts::{GGLWECiphertext, GGLWEInfos, GLWECiphertext, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared}; +use crate::layouts::{GGLWE, GGLWEInfos, GLWE, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared}; -impl GGLWECiphertext { +impl GGLWE { pub fn assert_noise( &self, module: &Module, @@ -20,8 +20,8 @@ impl GGLWECiphertext { ) where DataSk: DataRef, DataWant: DataRef, - Module: VecZnxDftAllocBytes - + VecZnxBigAllocBytes + Module: VecZnxDftBytesOf + + VecZnxBigBytesOf + VecZnxDftApply + SvpApplyDftToDftInplace + VecZnxIdftApplyConsume @@ -30,13 +30,13 @@ impl GGLWECiphertext { + VecZnxBigNormalize + VecZnxNormalizeTmpBytes + VecZnxSubScalarInplace, - B: Backend + TakeVecZnxDftImpl + TakeVecZnxBigImpl + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, + B: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, { let dsize: usize = self.dsize().into(); let base2k: usize = self.base2k().into(); - let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, self)); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(self); + let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWE::decrypt_tmp_bytes(module, self)); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(module, self); (0..self.rank_in().into()).for_each(|col_i| { (0..self.dnum().into()).for_each(|row_i| { diff --git a/poulpy-core/src/noise/ggsw_ct.rs b/poulpy-core/src/noise/ggsw_ct.rs index 03bb0c0..92f806b 100644 --- a/poulpy-core/src/noise/ggsw_ct.rs +++ b/poulpy-core/src/noise/ggsw_ct.rs @@ -1,19 +1,17 @@ use poulpy_hal::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, VecZnxAddScalarInplace, VecZnxBigAddInplace, - VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, - VecZnxNormalizeTmpBytes, VecZnxSubInplace, + VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxDftAlloc, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalizeTmpBytes, + VecZnxSubInplace, }, layouts::{Backend, DataRef, Module, ScalarZnx, ScratchOwned, VecZnxBig, VecZnxDft, ZnxZero}, - oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl}, + oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl}, }; -use crate::layouts::{ - GGSWCiphertext, GGSWInfos, GLWECiphertext, GLWEInfos, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared, -}; +use crate::layouts::{GGSW, GGSWInfos, GLWE, GLWEInfos, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared}; -impl GGSWCiphertext { +impl GGSW { pub fn assert_noise( &self, module: &Module, @@ -23,8 +21,8 @@ impl GGSWCiphertext { ) where DataSk: DataRef, DataScalar: DataRef, - Module: VecZnxDftAllocBytes - + VecZnxBigAllocBytes + Module: VecZnxDftBytesOf + + VecZnxBigBytesOf + VecZnxDftApply + SvpApplyDftToDftInplace + VecZnxIdftApplyConsume @@ -38,19 +36,19 @@ impl GGSWCiphertext { + VecZnxIdftApplyTmpA + VecZnxAddScalarInplace + VecZnxSubInplace, - B: Backend + TakeVecZnxDftImpl + TakeVecZnxBigImpl + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, + B: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, F: Fn(usize) -> f64, { let base2k: usize = self.base2k().into(); let dsize: usize = self.dsize().into(); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(self); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(self); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(module, self); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(module, self); let mut pt_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(1, self.size()); let mut pt_big: VecZnxBig, B> = module.vec_znx_big_alloc(1, self.size()); let mut scratch: ScratchOwned = - ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, self) | module.vec_znx_normalize_tmp_bytes()); + ScratchOwned::alloc(GLWE::decrypt_tmp_bytes(module, self) | module.vec_znx_normalize_tmp_bytes()); (0..(self.rank() + 1).into()).for_each(|col_j| { (0..self.dnum().into()).for_each(|row_i| { @@ -87,7 +85,7 @@ impl GGSWCiphertext { } } -impl GGSWCiphertext { +impl GGSW { pub fn print_noise( &self, module: &Module, @@ -96,8 +94,8 @@ impl GGSWCiphertext { ) where DataSk: DataRef, DataScalar: DataRef, - Module: VecZnxDftAllocBytes - + VecZnxBigAllocBytes + Module: VecZnxDftBytesOf + + VecZnxBigBytesOf + VecZnxDftApply + SvpApplyDftToDftInplace + VecZnxIdftApplyConsume @@ -111,18 +109,18 @@ impl GGSWCiphertext { + VecZnxIdftApplyTmpA + VecZnxAddScalarInplace + VecZnxSubInplace, - B: Backend + TakeVecZnxDftImpl + TakeVecZnxBigImpl + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, + B: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, { let base2k: usize = self.base2k().into(); let dsize: usize = self.dsize().into(); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(self); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(self); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(module, self); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(module, self); let mut pt_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(1, self.size()); let mut pt_big: VecZnxBig, B> = module.vec_znx_big_alloc(1, self.size()); let mut scratch: ScratchOwned = - ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, self) | module.vec_znx_normalize_tmp_bytes()); + ScratchOwned::alloc(GLWE::decrypt_tmp_bytes(module, self) | module.vec_znx_normalize_tmp_bytes()); (0..(self.rank() + 1).into()).for_each(|col_j| { (0..self.dnum().into()).for_each(|row_i| { diff --git a/poulpy-core/src/noise/glwe_ct.rs b/poulpy-core/src/noise/glwe_ct.rs index f7af2a1..40b86d9 100644 --- a/poulpy-core/src/noise/glwe_ct.rs +++ b/poulpy-core/src/noise/glwe_ct.rs @@ -1,16 +1,16 @@ use poulpy_hal::{ api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddInplace, - VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, - VecZnxIdftApplyConsume, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSubInplace, + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, + VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalizeInplace, + VecZnxNormalizeTmpBytes, VecZnxSubInplace, }, layouts::{Backend, DataRef, Module, Scratch, ScratchOwned}, - oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl}, + oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl}, }; -use crate::layouts::{GLWECiphertext, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared}; +use crate::layouts::{GLWE, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared}; -impl GLWECiphertext { +impl GLWE { pub fn noise( &self, module: &Module, @@ -30,9 +30,9 @@ impl GLWECiphertext { + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxBigNormalize, - Scratch: TakeVecZnxDft + TakeVecZnxBig, + Scratch:, { - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(self); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(module, self); self.decrypt(module, &mut pt_have, sk_prepared, scratch); module.vec_znx_sub_inplace(&mut pt_have.data, 0, &pt_want.data, 0); module.vec_znx_normalize_inplace(self.base2k().into(), &mut pt_have.data, 0, scratch); @@ -48,8 +48,8 @@ impl GLWECiphertext { ) where DataSk: DataRef, DataPt: DataRef, - Module: VecZnxDftAllocBytes - + VecZnxBigAllocBytes + Module: VecZnxDftBytesOf + + VecZnxBigBytesOf + VecZnxDftApply + SvpApplyDftToDftInplace + VecZnxIdftApplyConsume @@ -59,9 +59,9 @@ impl GLWECiphertext { + VecZnxNormalizeTmpBytes + VecZnxSubInplace + VecZnxNormalizeInplace, - B: Backend + TakeVecZnxDftImpl + TakeVecZnxBigImpl + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, + B: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, { - let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, self)); + let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWE::decrypt_tmp_bytes(module, self)); let noise_have: f64 = self.noise(module, sk_prepared, pt_want, scratch.borrow()); assert!(noise_have <= max_noise, "{noise_have} {max_noise}"); } diff --git a/poulpy-core/src/operations/glwe.rs b/poulpy-core/src/operations/glwe.rs index b0ee4f6..b8b32ce 100644 --- a/poulpy-core/src/operations/glwe.rs +++ b/poulpy-core/src/operations/glwe.rs @@ -7,24 +7,22 @@ use poulpy_hal::{ layouts::{Backend, DataMut, Module, Scratch, VecZnx, ZnxZero}, }; -use crate::layouts::{ - GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEInfos, GLWELayoutSet, GLWEPlaintext, LWEInfos, TorusPrecision, -}; +use crate::layouts::{GLWE, GLWEInfos, GLWEPlaintext, GLWEToMut, GLWEToRef, LWEInfos, SetGLWEInfos, TorusPrecision}; impl GLWEOperations for GLWEPlaintext where D: DataMut, - GLWEPlaintext: GLWECiphertextToMut + GLWEInfos, + GLWEPlaintext: GLWEToMut + GLWEInfos, { } -impl GLWEOperations for GLWECiphertext where GLWECiphertext: GLWECiphertextToMut + GLWEInfos {} +impl GLWEOperations for GLWE where GLWE: GLWEToMut + GLWEInfos {} -pub trait GLWEOperations: GLWECiphertextToMut + GLWEInfos + GLWELayoutSet + Sized { +pub trait GLWEOperations: GLWEToMut + GLWEInfos + SetGLWEInfos + Sized { fn add(&mut self, module: &Module, a: &A, b: &B) where - A: GLWECiphertextToRef + GLWEInfos, - B: GLWECiphertextToRef + GLWEInfos, + A: GLWEToRef + GLWEInfos, + B: GLWEToRef + GLWEInfos, Module: VecZnxAdd + VecZnxCopy, { #[cfg(debug_assertions)] @@ -39,9 +37,9 @@ pub trait GLWEOperations: GLWECiphertextToMut + GLWEInfos + GLWELayoutSet + Size let max_col: usize = (a.rank().max(b.rank() + 1)).into(); let self_col: usize = (self.rank() + 1).into(); - let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref(); - let b_ref: &GLWECiphertext<&[u8]> = &b.to_ref(); + let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); + let a_ref: &GLWE<&[u8]> = &a.to_ref(); + let b_ref: &GLWE<&[u8]> = &b.to_ref(); (0..min_col).for_each(|i| { module.vec_znx_add(&mut self_mut.data, i, &a_ref.data, i, &b_ref.data, i); @@ -64,13 +62,13 @@ pub trait GLWEOperations: GLWECiphertextToMut + GLWEInfos + GLWELayoutSet + Size }); }); - self.set_basek(a.base2k()); + self.set_base2k(a.base2k()); self.set_k(set_k_binary(self, a, b)); } fn add_inplace(&mut self, module: &Module, a: &A) where - A: GLWECiphertextToRef + GLWEInfos, + A: GLWEToRef + GLWEInfos, Module: VecZnxAddInplace, { #[cfg(debug_assertions)] @@ -80,8 +78,8 @@ pub trait GLWEOperations: GLWECiphertextToMut + GLWEInfos + GLWELayoutSet + Size assert!(self.rank() >= a.rank()) } - let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref(); + let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); + let a_ref: &GLWE<&[u8]> = &a.to_ref(); (0..(a.rank() + 1).into()).for_each(|i| { module.vec_znx_add_inplace(&mut self_mut.data, i, &a_ref.data, i); @@ -92,8 +90,8 @@ pub trait GLWEOperations: GLWECiphertextToMut + GLWEInfos + GLWELayoutSet + Size fn sub(&mut self, module: &Module, a: &A, b: &B) where - A: GLWECiphertextToRef + GLWEInfos, - B: GLWECiphertextToRef + GLWEInfos, + A: GLWEToRef + GLWEInfos, + B: GLWEToRef + GLWEInfos, Module: VecZnxSub + VecZnxCopy + VecZnxNegateInplace, { #[cfg(debug_assertions)] @@ -108,9 +106,9 @@ pub trait GLWEOperations: GLWECiphertextToMut + GLWEInfos + GLWELayoutSet + Size let max_col: usize = (a.rank().max(b.rank() + 1)).into(); let self_col: usize = (self.rank() + 1).into(); - let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref(); - let b_ref: &GLWECiphertext<&[u8]> = &b.to_ref(); + let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); + let a_ref: &GLWE<&[u8]> = &a.to_ref(); + let b_ref: &GLWE<&[u8]> = &b.to_ref(); (0..min_col).for_each(|i| { module.vec_znx_sub(&mut self_mut.data, i, &a_ref.data, i, &b_ref.data, i); @@ -134,13 +132,13 @@ pub trait GLWEOperations: GLWECiphertextToMut + GLWEInfos + GLWELayoutSet + Size }); }); - self.set_basek(a.base2k()); + self.set_base2k(a.base2k()); self.set_k(set_k_binary(self, a, b)); } fn sub_inplace_ab(&mut self, module: &Module, a: &A) where - A: GLWECiphertextToRef + GLWEInfos, + A: GLWEToRef + GLWEInfos, Module: VecZnxSubInplace, { #[cfg(debug_assertions)] @@ -150,8 +148,8 @@ pub trait GLWEOperations: GLWECiphertextToMut + GLWEInfos + GLWELayoutSet + Size assert!(self.rank() >= a.rank()) } - let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref(); + let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); + let a_ref: &GLWE<&[u8]> = &a.to_ref(); (0..(a.rank() + 1).into()).for_each(|i| { module.vec_znx_sub_inplace(&mut self_mut.data, i, &a_ref.data, i); @@ -162,7 +160,7 @@ pub trait GLWEOperations: GLWECiphertextToMut + GLWEInfos + GLWELayoutSet + Size fn sub_inplace_ba(&mut self, module: &Module, a: &A) where - A: GLWECiphertextToRef + GLWEInfos, + A: GLWEToRef + GLWEInfos, Module: VecZnxSubNegateInplace, { #[cfg(debug_assertions)] @@ -172,8 +170,8 @@ pub trait GLWEOperations: GLWECiphertextToMut + GLWEInfos + GLWELayoutSet + Size assert!(self.rank() >= a.rank()) } - let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref(); + let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); + let a_ref: &GLWE<&[u8]> = &a.to_ref(); (0..(a.rank() + 1).into()).for_each(|i| { module.vec_znx_sub_negate_inplace(&mut self_mut.data, i, &a_ref.data, i); @@ -184,7 +182,7 @@ pub trait GLWEOperations: GLWECiphertextToMut + GLWEInfos + GLWELayoutSet + Size fn rotate(&mut self, module: &Module, k: i64, a: &A) where - A: GLWECiphertextToRef + GLWEInfos, + A: GLWEToRef + GLWEInfos, Module: VecZnxRotate, { #[cfg(debug_assertions)] @@ -193,14 +191,14 @@ pub trait GLWEOperations: GLWECiphertextToMut + GLWEInfos + GLWELayoutSet + Size assert_eq!(self.rank(), a.rank()) } - let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref(); + let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); + let a_ref: &GLWE<&[u8]> = &a.to_ref(); (0..(a.rank() + 1).into()).for_each(|i| { module.vec_znx_rotate(k, &mut self_mut.data, i, &a_ref.data, i); }); - self.set_basek(a.base2k()); + self.set_base2k(a.base2k()); self.set_k(set_k_unary(self, a)) } @@ -208,7 +206,7 @@ pub trait GLWEOperations: GLWECiphertextToMut + GLWEInfos + GLWELayoutSet + Size where Module: VecZnxRotateInplace, { - let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); + let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); (0..(self_mut.rank() + 1).into()).for_each(|i| { module.vec_znx_rotate_inplace(k, &mut self_mut.data, i, scratch); @@ -217,7 +215,7 @@ pub trait GLWEOperations: GLWECiphertextToMut + GLWEInfos + GLWELayoutSet + Size fn mul_xp_minus_one(&mut self, module: &Module, k: i64, a: &A) where - A: GLWECiphertextToRef + GLWEInfos, + A: GLWEToRef + GLWEInfos, Module: VecZnxMulXpMinusOne, { #[cfg(debug_assertions)] @@ -226,14 +224,14 @@ pub trait GLWEOperations: GLWECiphertextToMut + GLWEInfos + GLWELayoutSet + Size assert_eq!(self.rank(), a.rank()) } - let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref(); + let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); + let a_ref: &GLWE<&[u8]> = &a.to_ref(); (0..(a.rank() + 1).into()).for_each(|i| { module.vec_znx_mul_xp_minus_one(k, &mut self_mut.data, i, &a_ref.data, i); }); - self.set_basek(a.base2k()); + self.set_base2k(a.base2k()); self.set_k(set_k_unary(self, a)) } @@ -241,17 +239,17 @@ pub trait GLWEOperations: GLWECiphertextToMut + GLWEInfos + GLWELayoutSet + Size where Module: VecZnxMulXpMinusOneInplace, { - let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); + let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); (0..(self_mut.rank() + 1).into()).for_each(|i| { module.vec_znx_mul_xp_minus_one_inplace(k, &mut self_mut.data, i, scratch); }); } - fn copy(&mut self, module: &Module, a: &A) + fn copy(&mut self, module: &M, a: &A) where - A: GLWECiphertextToRef + GLWEInfos, - Module: VecZnxCopy, + A: GLWEToRef + GLWEInfos, + M: VecZnxCopy, { #[cfg(debug_assertions)] { @@ -259,15 +257,15 @@ pub trait GLWEOperations: GLWECiphertextToMut + GLWEInfos + GLWELayoutSet + Size assert_eq!(self.rank(), a.rank()); } - let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref(); + let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); + let a_ref: &GLWE<&[u8]> = &a.to_ref(); (0..(self_mut.rank() + 1).into()).for_each(|i| { module.vec_znx_copy(&mut self_mut.data, i, &a_ref.data, i); }); self.set_k(a.k().min(self.max_k())); - self.set_basek(a.base2k()); + self.set_base2k(a.base2k()); } fn rsh(&mut self, module: &Module, k: usize, scratch: &mut Scratch) @@ -282,7 +280,7 @@ pub trait GLWEOperations: GLWECiphertextToMut + GLWEInfos + GLWELayoutSet + Size fn normalize(&mut self, module: &Module, a: &A, scratch: &mut Scratch) where - A: GLWECiphertextToRef + GLWEInfos, + A: GLWEToRef + GLWEInfos, Module: VecZnxNormalize, { #[cfg(debug_assertions)] @@ -291,8 +289,8 @@ pub trait GLWEOperations: GLWECiphertextToMut + GLWEInfos + GLWELayoutSet + Size assert_eq!(self.rank(), a.rank()); } - let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref(); + let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); + let a_ref: &GLWE<&[u8]> = &a.to_ref(); (0..(self_mut.rank() + 1).into()).for_each(|i| { module.vec_znx_normalize( @@ -305,7 +303,7 @@ pub trait GLWEOperations: GLWECiphertextToMut + GLWEInfos + GLWELayoutSet + Size scratch, ); }); - self.set_basek(a.base2k()); + self.set_base2k(a.base2k()); self.set_k(a.k().min(self.k())); } @@ -313,16 +311,16 @@ pub trait GLWEOperations: GLWECiphertextToMut + GLWEInfos + GLWELayoutSet + Size where Module: VecZnxNormalizeInplace, { - let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); + let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); (0..(self_mut.rank() + 1).into()).for_each(|i| { module.vec_znx_normalize_inplace(self_mut.base2k().into(), &mut self_mut.data, i, scratch); }); } } -impl GLWECiphertext> { - pub fn rsh_scratch_space(n: usize) -> usize { - VecZnx::rsh_scratch_space(n) +impl GLWE> { + pub fn rsh_tmp_bytes(n: usize) -> usize { + VecZnx::rsh_tmp_bytes(n) } } diff --git a/poulpy-core/src/scratch.rs b/poulpy-core/src/scratch.rs index 1a5a6ce..d1e95f9 100644 --- a/poulpy-core/src/scratch.rs +++ b/poulpy-core/src/scratch.rs @@ -1,365 +1,230 @@ use poulpy_hal::{ - api::{TakeMatZnx, TakeScalarZnx, TakeSvpPPol, TakeVecZnx, TakeVecZnxDft, TakeVmpPMat}, + api::{ModuleN, ScratchAvailable, ScratchTakeBasic, SvpPPolBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf}, layouts::{Backend, Scratch}, }; use crate::{ dist::Distribution, layouts::{ - Degree, GGLWEAutomorphismKey, GGLWECiphertext, GGLWEInfos, GGLWESwitchingKey, GGLWETensorKey, GGSWCiphertext, GGSWInfos, - GLWECiphertext, GLWEInfos, GLWEPlaintext, GLWEPublicKey, GLWESecret, Rank, + AutomorphismKey, GGLWE, GGLWEInfos, GGSW, GGSWInfos, GLWE, GLWEInfos, GLWEPlaintext, GLWEPublicKey, GLWESecret, + GLWESwitchingKey, Rank, TensorKey, prepared::{ - GGLWEAutomorphismKeyPrepared, GGLWECiphertextPrepared, GGLWESwitchingKeyPrepared, GGLWETensorKeyPrepared, - GGSWCiphertextPrepared, GLWEPublicKeyPrepared, GLWESecretPrepared, + AutomorphismKeyPrepared, GGLWEPrepared, GGSWPrepared, GLWEPublicKeyPrepared, GLWESecretPrepared, + GLWESwitchingKeyPrepared, TensorKeyPrepared, }, }, }; -pub trait TakeGLWECt { - fn take_glwe_ct(&mut self, infos: &A) -> (GLWECiphertext<&mut [u8]>, &mut Self) - where - A: GLWEInfos; -} - -pub trait TakeGLWECtSlice { - fn take_glwe_ct_slice(&mut self, size: usize, infos: &A) -> (Vec>, &mut Self) - where - A: GLWEInfos; -} - -pub trait TakeGLWEPt { - fn take_glwe_pt(&mut self, infos: &A) -> (GLWEPlaintext<&mut [u8]>, &mut Self) - where - A: GLWEInfos; -} - -pub trait TakeGGLWE { - fn take_gglwe(&mut self, infos: &A) -> (GGLWECiphertext<&mut [u8]>, &mut Self) - where - A: GGLWEInfos; -} - -pub trait TakeGGLWEPrepared { - fn take_gglwe_prepared(&mut self, infos: &A) -> (GGLWECiphertextPrepared<&mut [u8], B>, &mut Self) - where - A: GGLWEInfos; -} - -pub trait TakeGGSW { - fn take_ggsw(&mut self, infos: &A) -> (GGSWCiphertext<&mut [u8]>, &mut Self) - where - A: GGSWInfos; -} - -pub trait TakeGGSWPrepared { - fn take_ggsw_prepared(&mut self, infos: &A) -> (GGSWCiphertextPrepared<&mut [u8], B>, &mut Self) - where - A: GGSWInfos; -} - -pub trait TakeGGSWPreparedSlice { - fn take_ggsw_prepared_slice(&mut self, size: usize, infos: &A) -> (Vec>, &mut Self) - where - A: GGSWInfos; -} - -pub trait TakeGLWESecret { - fn take_glwe_secret(&mut self, n: Degree, rank: Rank) -> (GLWESecret<&mut [u8]>, &mut Self); -} - -pub trait TakeGLWESecretPrepared { - fn take_glwe_secret_prepared(&mut self, n: Degree, rank: Rank) -> (GLWESecretPrepared<&mut [u8], B>, &mut Self); -} - -pub trait TakeGLWEPk { - fn take_glwe_pk(&mut self, infos: &A) -> (GLWEPublicKey<&mut [u8]>, &mut Self) - where - A: GLWEInfos; -} - -pub trait TakeGLWEPkPrepared { - fn take_glwe_pk_prepared(&mut self, infos: &A) -> (GLWEPublicKeyPrepared<&mut [u8], B>, &mut Self) - where - A: GLWEInfos; -} - -pub trait TakeGLWESwitchingKey { - fn take_glwe_switching_key(&mut self, infos: &A) -> (GGLWESwitchingKey<&mut [u8]>, &mut Self) - where - A: GGLWEInfos; -} - -pub trait TakeGGLWESwitchingKeyPrepared { - fn take_gglwe_switching_key_prepared(&mut self, infos: &A) -> (GGLWESwitchingKeyPrepared<&mut [u8], B>, &mut Self) - where - A: GGLWEInfos; -} - -pub trait TakeTensorKey { - fn take_tensor_key(&mut self, infos: &A) -> (GGLWETensorKey<&mut [u8]>, &mut Self) - where - A: GGLWEInfos; -} - -pub trait TakeGGLWETensorKeyPrepared { - fn take_gglwe_tensor_key_prepared(&mut self, infos: &A) -> (GGLWETensorKeyPrepared<&mut [u8], B>, &mut Self) - where - A: GGLWEInfos; -} - -pub trait TakeGGLWEAutomorphismKey { - fn take_gglwe_automorphism_key(&mut self, infos: &A) -> (GGLWEAutomorphismKey<&mut [u8]>, &mut Self) - where - A: GGLWEInfos; -} - -pub trait TakeGGLWEAutomorphismKeyPrepared { - fn take_gglwe_automorphism_key_prepared(&mut self, infos: &A) -> (GGLWEAutomorphismKeyPrepared<&mut [u8], B>, &mut Self) - where - A: GGLWEInfos; -} - -impl TakeGLWECt for Scratch +pub trait ScratchTakeCore where - Scratch: TakeVecZnx, + Self: ScratchTakeBasic + ScratchAvailable, { - fn take_glwe_ct(&mut self, infos: &A) -> (GLWECiphertext<&mut [u8]>, &mut Self) + fn take_glwe_ct(&mut self, module: &M, infos: &A) -> (GLWE<&mut [u8]>, &mut Self) where A: GLWEInfos, + M: ModuleN, { - let (data, scratch) = self.take_vec_znx(infos.n().into(), (infos.rank() + 1).into(), infos.size()); + assert_eq!(module.n() as u32, infos.n()); + let (data, scratch) = self.take_vec_znx(module, (infos.rank() + 1).into(), infos.size()); ( - GLWECiphertext::builder() - .base2k(infos.base2k()) - .k(infos.k()) - .data(data) - .build() - .unwrap(), + GLWE { + k: infos.k(), + base2k: infos.base2k(), + data, + }, scratch, ) } -} -impl TakeGLWECtSlice for Scratch -where - Scratch: TakeVecZnx, -{ - fn take_glwe_ct_slice(&mut self, size: usize, infos: &A) -> (Vec>, &mut Self) + fn take_glwe_ct_slice(&mut self, module: &M, size: usize, infos: &A) -> (Vec>, &mut Self) where A: GLWEInfos, + M: ModuleN, { - let mut scratch: &mut Scratch = self; - let mut cts: Vec> = Vec::with_capacity(size); + let mut scratch: &mut Self = self; + let mut cts: Vec> = Vec::with_capacity(size); for _ in 0..size { - let (ct, new_scratch) = scratch.take_glwe_ct(infos); + let (ct, new_scratch) = scratch.take_glwe_ct(module, infos); scratch = new_scratch; cts.push(ct); } (cts, scratch) } -} -impl TakeGLWEPt for Scratch -where - Scratch: TakeVecZnx, -{ - fn take_glwe_pt(&mut self, infos: &A) -> (GLWEPlaintext<&mut [u8]>, &mut Self) + fn take_glwe_pt(&mut self, module: &M, infos: &A) -> (GLWEPlaintext<&mut [u8]>, &mut Self) where A: GLWEInfos, + M: ModuleN, { - let (data, scratch) = self.take_vec_znx(infos.n().into(), 1, infos.size()); + assert_eq!(module.n() as u32, infos.n()); + let (data, scratch) = self.take_vec_znx(module, 1, infos.size()); ( - GLWEPlaintext::builder() - .base2k(infos.base2k()) - .k(infos.k()) - .data(data) - .build() - .unwrap(), + GLWEPlaintext { + k: infos.k(), + base2k: infos.base2k(), + data, + }, scratch, ) } -} -impl TakeGGLWE for Scratch -where - Scratch: TakeMatZnx, -{ - fn take_gglwe(&mut self, infos: &A) -> (GGLWECiphertext<&mut [u8]>, &mut Self) + fn take_gglwe(&mut self, module: &M, infos: &A) -> (GGLWE<&mut [u8]>, &mut Self) where A: GGLWEInfos, + M: ModuleN, { + assert_eq!(module.n() as u32, infos.n()); let (data, scratch) = self.take_mat_znx( - infos.n().into(), + module, infos.dnum().0.div_ceil(infos.dsize().0) as usize, infos.rank_in().into(), (infos.rank_out() + 1).into(), infos.size(), ); ( - GGLWECiphertext::builder() - .base2k(infos.base2k()) - .k(infos.k()) - .dsize(infos.dsize()) - .data(data) - .build() - .unwrap(), + GGLWE { + k: infos.k(), + base2k: infos.base2k(), + dsize: infos.dsize(), + data, + }, scratch, ) } -} -impl TakeGGLWEPrepared for Scratch -where - Scratch: TakeVmpPMat, -{ - fn take_gglwe_prepared(&mut self, infos: &A) -> (GGLWECiphertextPrepared<&mut [u8], B>, &mut Self) + fn take_gglwe_prepared(&mut self, module: &M, infos: &A) -> (GGLWEPrepared<&mut [u8], B>, &mut Self) where A: GGLWEInfos, + M: ModuleN + VmpPMatBytesOf, { + assert_eq!(module.n() as u32, infos.n()); let (data, scratch) = self.take_vmp_pmat( - infos.n().into(), + module, infos.dnum().into(), infos.rank_in().into(), (infos.rank_out() + 1).into(), infos.size(), ); ( - GGLWECiphertextPrepared::builder() - .base2k(infos.base2k()) - .dsize(infos.dsize()) - .k(infos.k()) - .data(data) - .build() - .unwrap(), + GGLWEPrepared { + k: infos.k(), + base2k: infos.base2k(), + dsize: infos.dsize(), + data, + }, scratch, ) } -} -impl TakeGGSW for Scratch -where - Scratch: TakeMatZnx, -{ - fn take_ggsw(&mut self, infos: &A) -> (GGSWCiphertext<&mut [u8]>, &mut Self) + fn take_ggsw(&mut self, module: &M, infos: &A) -> (GGSW<&mut [u8]>, &mut Self) where A: GGSWInfos, + M: ModuleN, { + assert_eq!(module.n() as u32, infos.n()); let (data, scratch) = self.take_mat_znx( - infos.n().into(), + module, infos.dnum().into(), (infos.rank() + 1).into(), (infos.rank() + 1).into(), infos.size(), ); ( - GGSWCiphertext::builder() - .base2k(infos.base2k()) - .dsize(infos.dsize()) - .k(infos.k()) - .data(data) - .build() - .unwrap(), + GGSW { + k: infos.k(), + base2k: infos.base2k(), + dsize: infos.dsize(), + data, + }, scratch, ) } -} -impl TakeGGSWPrepared for Scratch -where - Scratch: TakeVmpPMat, -{ - fn take_ggsw_prepared(&mut self, infos: &A) -> (GGSWCiphertextPrepared<&mut [u8], B>, &mut Self) + fn take_ggsw_prepared(&mut self, module: &M, infos: &A) -> (GGSWPrepared<&mut [u8], B>, &mut Self) where A: GGSWInfos, + M: ModuleN + VmpPMatBytesOf, { + assert_eq!(module.n() as u32, infos.n()); let (data, scratch) = self.take_vmp_pmat( - infos.n().into(), + module, infos.dnum().into(), (infos.rank() + 1).into(), (infos.rank() + 1).into(), infos.size(), ); ( - GGSWCiphertextPrepared::builder() - .base2k(infos.base2k()) - .dsize(infos.dsize()) - .k(infos.k()) - .data(data) - .build() - .unwrap(), + GGSWPrepared { + k: infos.k(), + base2k: infos.base2k(), + dsize: infos.dsize(), + data, + }, scratch, ) } -} -impl TakeGGSWPreparedSlice for Scratch -where - Scratch: TakeGGSWPrepared, -{ - fn take_ggsw_prepared_slice(&mut self, size: usize, infos: &A) -> (Vec>, &mut Self) + fn take_ggsw_prepared_slice( + &mut self, + module: &M, + size: usize, + infos: &A, + ) -> (Vec>, &mut Self) where A: GGSWInfos, + M: ModuleN + VmpPMatBytesOf, { - let mut scratch: &mut Scratch = self; - let mut cts: Vec> = Vec::with_capacity(size); + let mut scratch: &mut Self = self; + let mut cts: Vec> = Vec::with_capacity(size); for _ in 0..size { - let (ct, new_scratch) = scratch.take_ggsw_prepared(infos); + let (ct, new_scratch) = scratch.take_ggsw_prepared(module, infos); scratch = new_scratch; cts.push(ct) } (cts, scratch) } -} -impl TakeGLWEPk for Scratch -where - Scratch: TakeVecZnx, -{ - fn take_glwe_pk(&mut self, infos: &A) -> (GLWEPublicKey<&mut [u8]>, &mut Self) + fn take_glwe_pk(&mut self, module: &M, infos: &A) -> (GLWEPublicKey<&mut [u8]>, &mut Self) where A: GLWEInfos, + M: ModuleN, { - let (data, scratch) = self.take_vec_znx(infos.n().into(), (infos.rank() + 1).into(), infos.size()); + assert_eq!(module.n() as u32, infos.n()); + let (data, scratch) = self.take_vec_znx(module, (infos.rank() + 1).into(), infos.size()); ( - GLWEPublicKey::builder() - .base2k(infos.base2k()) - .k(infos.k()) - .base2k(infos.base2k()) - .data(data) - .build() - .unwrap(), + GLWEPublicKey { + k: infos.k(), + dist: Distribution::NONE, + base2k: infos.base2k(), + data, + }, scratch, ) } -} -impl TakeGLWEPkPrepared for Scratch -where - Scratch: TakeVecZnxDft, -{ - fn take_glwe_pk_prepared(&mut self, infos: &A) -> (GLWEPublicKeyPrepared<&mut [u8], B>, &mut Self) + fn take_glwe_pk_prepared(&mut self, module: &M, infos: &A) -> (GLWEPublicKeyPrepared<&mut [u8], B>, &mut Self) where A: GLWEInfos, + M: ModuleN + VecZnxDftBytesOf, { - let (data, scratch) = self.take_vec_znx_dft(infos.n().into(), (infos.rank() + 1).into(), infos.size()); + assert_eq!(module.n() as u32, infos.n()); + let (data, scratch) = self.take_vec_znx_dft(module, (infos.rank() + 1).into(), infos.size()); ( - GLWEPublicKeyPrepared::builder() - .base2k(infos.base2k()) - .k(infos.k()) - .data(data) - .build() - .unwrap(), + GLWEPublicKeyPrepared { + k: infos.k(), + dist: Distribution::NONE, + base2k: infos.base2k(), + data, + }, scratch, ) } -} -impl TakeGLWESecret for Scratch -where - Scratch: TakeScalarZnx, -{ - fn take_glwe_secret(&mut self, n: Degree, rank: Rank) -> (GLWESecret<&mut [u8]>, &mut Self) { - let (data, scratch) = self.take_scalar_znx(n.into(), rank.into()); + fn take_glwe_secret(&mut self, module: &M, rank: Rank) -> (GLWESecret<&mut [u8]>, &mut Self) + where + M: ModuleN, + { + let (data, scratch) = self.take_scalar_znx(module, rank.into()); ( GLWESecret { data, @@ -368,14 +233,12 @@ where scratch, ) } -} -impl TakeGLWESecretPrepared for Scratch -where - Scratch: TakeSvpPPol, -{ - fn take_glwe_secret_prepared(&mut self, n: Degree, rank: Rank) -> (GLWESecretPrepared<&mut [u8], B>, &mut Self) { - let (data, scratch) = self.take_svp_ppol(n.into(), rank.into()); + fn take_glwe_secret_prepared(&mut self, module: &M, rank: Rank) -> (GLWESecretPrepared<&mut [u8], B>, &mut Self) + where + M: ModuleN + SvpPPolBytesOf, + { + let (data, scratch) = self.take_svp_ppol(module, rank.into()); ( GLWESecretPrepared { data, @@ -384,19 +247,16 @@ where scratch, ) } -} -impl TakeGLWESwitchingKey for Scratch -where - Scratch: TakeMatZnx, -{ - fn take_glwe_switching_key(&mut self, infos: &A) -> (GGLWESwitchingKey<&mut [u8]>, &mut Self) + fn take_glwe_switching_key(&mut self, module: &M, infos: &A) -> (GLWESwitchingKey<&mut [u8]>, &mut Self) where A: GGLWEInfos, + M: ModuleN, { - let (data, scratch) = self.take_gglwe(infos); + assert_eq!(module.n() as u32, infos.n()); + let (data, scratch) = self.take_gglwe(module, infos); ( - GGLWESwitchingKey { + GLWESwitchingKey { key: data, sk_in_n: 0, sk_out_n: 0, @@ -404,19 +264,20 @@ where scratch, ) } -} -impl TakeGGLWESwitchingKeyPrepared for Scratch -where - Scratch: TakeGGLWEPrepared, -{ - fn take_gglwe_switching_key_prepared(&mut self, infos: &A) -> (GGLWESwitchingKeyPrepared<&mut [u8], B>, &mut Self) + fn take_gglwe_switching_key_prepared( + &mut self, + module: &M, + infos: &A, + ) -> (GLWESwitchingKeyPrepared<&mut [u8], B>, &mut Self) where A: GGLWEInfos, + M: ModuleN + VmpPMatBytesOf, { - let (data, scratch) = self.take_gglwe_prepared(infos); + assert_eq!(module.n() as u32, infos.n()); + let (data, scratch) = self.take_gglwe_prepared(module, infos); ( - GGLWESwitchingKeyPrepared { + GLWESwitchingKeyPrepared { key: data, sk_in_n: 0, sk_out_n: 0, @@ -424,101 +285,95 @@ where scratch, ) } -} -impl TakeGGLWEAutomorphismKey for Scratch -where - Scratch: TakeMatZnx, -{ - fn take_gglwe_automorphism_key(&mut self, infos: &A) -> (GGLWEAutomorphismKey<&mut [u8]>, &mut Self) + fn take_gglwe_automorphism_key(&mut self, module: &M, infos: &A) -> (AutomorphismKey<&mut [u8]>, &mut Self) where A: GGLWEInfos, + M: ModuleN, { - let (data, scratch) = self.take_glwe_switching_key(infos); - (GGLWEAutomorphismKey { key: data, p: 0 }, scratch) + assert_eq!(module.n() as u32, infos.n()); + let (data, scratch) = self.take_glwe_switching_key(module, infos); + (AutomorphismKey { key: data, p: 0 }, scratch) } -} -impl TakeGGLWEAutomorphismKeyPrepared for Scratch -where - Scratch: TakeGGLWESwitchingKeyPrepared, -{ - fn take_gglwe_automorphism_key_prepared(&mut self, infos: &A) -> (GGLWEAutomorphismKeyPrepared<&mut [u8], B>, &mut Self) + fn take_gglwe_automorphism_key_prepared( + &mut self, + module: &M, + infos: &A, + ) -> (AutomorphismKeyPrepared<&mut [u8], B>, &mut Self) where A: GGLWEInfos, + M: ModuleN + VmpPMatBytesOf, { - let (data, scratch) = self.take_gglwe_switching_key_prepared(infos); - (GGLWEAutomorphismKeyPrepared { key: data, p: 0 }, scratch) + assert_eq!(module.n() as u32, infos.n()); + let (data, scratch) = self.take_gglwe_switching_key_prepared(module, infos); + (AutomorphismKeyPrepared { key: data, p: 0 }, scratch) } -} -impl TakeTensorKey for Scratch -where - Scratch: TakeMatZnx, -{ - fn take_tensor_key(&mut self, infos: &A) -> (GGLWETensorKey<&mut [u8]>, &mut Self) + fn take_tensor_key(&mut self, module: &M, infos: &A) -> (TensorKey<&mut [u8]>, &mut Self) where A: GGLWEInfos, + M: ModuleN, { + assert_eq!(module.n() as u32, infos.n()); assert_eq!( infos.rank_in(), infos.rank_out(), "rank_in != rank_out is not supported for GGLWETensorKey" ); - let mut keys: Vec> = Vec::new(); + let mut keys: Vec> = Vec::new(); let pairs: usize = (((infos.rank_out().0 + 1) * infos.rank_out().0) >> 1).max(1) as usize; - let mut scratch: &mut Scratch = self; + let mut scratch: &mut Self = self; - let mut ksk_infos: crate::layouts::GGLWECiphertextLayout = infos.layout(); + let mut ksk_infos: crate::layouts::GGLWELayout = infos.gglwe_layout(); ksk_infos.rank_in = Rank(1); if pairs != 0 { - let (gglwe, s) = scratch.take_glwe_switching_key(&ksk_infos); + let (gglwe, s) = scratch.take_glwe_switching_key(module, &ksk_infos); scratch = s; keys.push(gglwe); } for _ in 1..pairs { - let (gglwe, s) = scratch.take_glwe_switching_key(&ksk_infos); + let (gglwe, s) = scratch.take_glwe_switching_key(module, &ksk_infos); scratch = s; keys.push(gglwe); } - (GGLWETensorKey { keys }, scratch) + (TensorKey { keys }, scratch) } -} -impl TakeGGLWETensorKeyPrepared for Scratch -where - Scratch: TakeVmpPMat, -{ - fn take_gglwe_tensor_key_prepared(&mut self, infos: &A) -> (GGLWETensorKeyPrepared<&mut [u8], B>, &mut Self) + fn take_gglwe_tensor_key_prepared(&mut self, module: &M, infos: &A) -> (TensorKeyPrepared<&mut [u8], B>, &mut Self) where A: GGLWEInfos, + M: ModuleN + VmpPMatBytesOf, { + assert_eq!(module.n() as u32, infos.n()); assert_eq!( infos.rank_in(), infos.rank_out(), "rank_in != rank_out is not supported for GGLWETensorKeyPrepared" ); - let mut keys: Vec> = Vec::new(); + let mut keys: Vec> = Vec::new(); let pairs: usize = (((infos.rank_out().0 + 1) * infos.rank_out().0) >> 1).max(1) as usize; - let mut scratch: &mut Scratch = self; + let mut scratch: &mut Self = self; - let mut ksk_infos: crate::layouts::GGLWECiphertextLayout = infos.layout(); + let mut ksk_infos: crate::layouts::GGLWELayout = infos.gglwe_layout(); ksk_infos.rank_in = Rank(1); if pairs != 0 { - let (gglwe, s) = scratch.take_gglwe_switching_key_prepared(&ksk_infos); + let (gglwe, s) = scratch.take_gglwe_switching_key_prepared(module, &ksk_infos); scratch = s; keys.push(gglwe); } for _ in 1..pairs { - let (gglwe, s) = scratch.take_gglwe_switching_key_prepared(&ksk_infos); + let (gglwe, s) = scratch.take_gglwe_switching_key_prepared(module, &ksk_infos); scratch = s; keys.push(gglwe); } - (GGLWETensorKeyPrepared { keys }, scratch) + (TensorKeyPrepared { keys }, scratch) } } + +impl ScratchTakeCore for Scratch where Self: ScratchTakeBasic + ScratchAvailable {} diff --git a/poulpy-core/src/tests/serialization.rs b/poulpy-core/src/tests/serialization.rs index 8fe477c..3c4e564 100644 --- a/poulpy-core/src/tests/serialization.rs +++ b/poulpy-core/src/tests/serialization.rs @@ -1,12 +1,12 @@ use poulpy_hal::test_suite::serialization::test_reader_writer_interface; use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEAutomorphismKey, GGLWECiphertext, GGLWESwitchingKey, GGLWETensorKey, GGSWCiphertext, - GLWECiphertext, GLWEToLWEKey, LWECiphertext, LWESwitchingKey, LWEToGLWESwitchingKey, Rank, TorusPrecision, + AutomorphismKey, Base2K, Degree, Dnum, Dsize, GGLWE, GGSW, GLWE, GLWESwitchingKey, GLWEToLWESwitchingKey, LWE, + LWESwitchingKey, LWEToGLWESwitchingKey, Rank, TensorKey, TorusPrecision, compressed::{ - GGLWEAutomorphismKeyCompressed, GGLWECiphertextCompressed, GGLWESwitchingKeyCompressed, GGLWETensorKeyCompressed, - GGSWCiphertextCompressed, GLWECiphertextCompressed, GLWEToLWESwitchingKeyCompressed, LWECiphertextCompressed, - LWESwitchingKeyCompressed, LWEToGLWESwitchingKeyCompressed, + AutomorphismKeyCompressed, GGLWECompressed, GGSWCompressed, GLWECompressed, GLWESwitchingKeyCompressed, + GLWEToLWESwitchingKeyCompressed, LWECompressed, LWESwitchingKeyCompressed, LWEToGLWESwitchingKeyCompressed, + TensorKeyCompressed, }, }; @@ -20,95 +20,93 @@ const DSIZE: Dsize = Dsize(1); #[test] fn glwe_serialization() { - let original: GLWECiphertext> = GLWECiphertext::alloc_with(N_GLWE, BASE2K, K, RANK); + let original: GLWE> = GLWE::alloc(N_GLWE, BASE2K, K, RANK); poulpy_hal::test_suite::serialization::test_reader_writer_interface(original); } #[test] fn glwe_compressed_serialization() { - let original: GLWECiphertextCompressed> = GLWECiphertextCompressed::alloc_with(N_GLWE, BASE2K, K, RANK); + let original: GLWECompressed> = GLWECompressed::alloc(N_GLWE, BASE2K, K, RANK); test_reader_writer_interface(original); } #[test] fn lwe_serialization() { - let original: LWECiphertext> = LWECiphertext::alloc_with(N_LWE, BASE2K, K); + let original: LWE> = LWE::alloc(N_LWE, BASE2K, K); test_reader_writer_interface(original); } #[test] fn lwe_compressed_serialization() { - let original: LWECiphertextCompressed> = LWECiphertextCompressed::alloc_with(BASE2K, K); + let original: LWECompressed> = LWECompressed::alloc(BASE2K, K); test_reader_writer_interface(original); } #[test] fn test_gglwe_serialization() { - let original: GGLWECiphertext> = GGLWECiphertext::alloc_with(N_GLWE, BASE2K, K, RANK, RANK, DNUM, DSIZE); + let original: GGLWE> = GGLWE::alloc(N_GLWE, BASE2K, K, RANK, RANK, DNUM, DSIZE); test_reader_writer_interface(original); } #[test] fn test_gglwe_compressed_serialization() { - let original: GGLWECiphertextCompressed> = - GGLWECiphertextCompressed::alloc_with(N_GLWE, BASE2K, K, RANK, RANK, DNUM, DSIZE); + let original: GGLWECompressed> = GGLWECompressed::alloc(N_GLWE, BASE2K, K, RANK, RANK, DNUM, DSIZE); test_reader_writer_interface(original); } #[test] fn test_glwe_switching_key_serialization() { - let original: GGLWESwitchingKey> = GGLWESwitchingKey::alloc_with(N_GLWE, BASE2K, K, RANK, RANK, DNUM, DSIZE); + let original: GLWESwitchingKey> = GLWESwitchingKey::alloc(N_GLWE, BASE2K, K, RANK, RANK, DNUM, DSIZE); test_reader_writer_interface(original); } #[test] fn test_glwe_switching_key_compressed_serialization() { - let original: GGLWESwitchingKeyCompressed> = - GGLWESwitchingKeyCompressed::alloc_with(N_GLWE, BASE2K, K, RANK, RANK, DNUM, DSIZE); + let original: GLWESwitchingKeyCompressed> = + GLWESwitchingKeyCompressed::alloc(N_GLWE, BASE2K, K, RANK, RANK, DNUM, DSIZE); test_reader_writer_interface(original); } #[test] fn test_automorphism_key_serialization() { - let original: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc_with(N_GLWE, BASE2K, K, RANK, DNUM, DSIZE); + let original: AutomorphismKey> = AutomorphismKey::alloc_with(N_GLWE, BASE2K, K, RANK, DNUM, DSIZE); test_reader_writer_interface(original); } #[test] fn test_automorphism_key_compressed_serialization() { - let original: GGLWEAutomorphismKeyCompressed> = - GGLWEAutomorphismKeyCompressed::alloc_with(N_GLWE, BASE2K, K, RANK, DNUM, DSIZE); + let original: AutomorphismKeyCompressed> = AutomorphismKeyCompressed::alloc(N_GLWE, BASE2K, K, RANK, DNUM, DSIZE); test_reader_writer_interface(original); } #[test] fn test_tensor_key_serialization() { - let original: GGLWETensorKey> = GGLWETensorKey::alloc_with(N_GLWE, BASE2K, K, RANK, DNUM, DSIZE); + let original: TensorKey> = TensorKey::alloc(N_GLWE, BASE2K, K, RANK, DNUM, DSIZE); test_reader_writer_interface(original); } #[test] fn test_tensor_key_compressed_serialization() { - let original: GGLWETensorKeyCompressed> = GGLWETensorKeyCompressed::alloc_with(N_GLWE, BASE2K, K, RANK, DNUM, DSIZE); + let original: TensorKeyCompressed> = TensorKeyCompressed::alloc(N_GLWE, BASE2K, K, RANK, DNUM, DSIZE); test_reader_writer_interface(original); } #[test] fn glwe_to_lwe_switching_key_serialization() { - let original: GLWEToLWEKey> = GLWEToLWEKey::alloc_with(N_GLWE, BASE2K, K, RANK, DNUM); + let original: GLWEToLWESwitchingKey> = GLWEToLWESwitchingKey::alloc(N_GLWE, BASE2K, K, RANK, DNUM); test_reader_writer_interface(original); } #[test] fn glwe_to_lwe_switching_key_compressed_serialization() { let original: GLWEToLWESwitchingKeyCompressed> = - GLWEToLWESwitchingKeyCompressed::alloc_with(N_GLWE, BASE2K, K, RANK, DNUM); + GLWEToLWESwitchingKeyCompressed::alloc(N_GLWE, BASE2K, K, RANK, DNUM); test_reader_writer_interface(original); } #[test] fn lwe_to_glwe_switching_key_serialization() { - let original: LWEToGLWESwitchingKey> = LWEToGLWESwitchingKey::alloc_with(N_GLWE, BASE2K, K, RANK, DNUM); + let original: LWEToGLWESwitchingKey> = LWEToGLWESwitchingKey::alloc(N_GLWE, BASE2K, K, RANK, DNUM); test_reader_writer_interface(original); } @@ -121,24 +119,24 @@ fn lwe_to_glwe_switching_key_compressed_serialization() { #[test] fn lwe_switching_key_serialization() { - let original: LWESwitchingKey> = LWESwitchingKey::alloc_with(N_GLWE, BASE2K, K, DNUM); + let original: LWESwitchingKey> = LWESwitchingKey::alloc(N_GLWE, BASE2K, K, DNUM); test_reader_writer_interface(original); } #[test] fn lwe_switching_key_compressed_serialization() { - let original: LWESwitchingKeyCompressed> = LWESwitchingKeyCompressed::alloc_with(N_GLWE, BASE2K, K, DNUM); + let original: LWESwitchingKeyCompressed> = LWESwitchingKeyCompressed::alloc(N_GLWE, BASE2K, K, DNUM); test_reader_writer_interface(original); } #[test] fn ggsw_serialization() { - let original: GGSWCiphertext> = GGSWCiphertext::alloc_with(N_GLWE, BASE2K, K, RANK, DNUM, DSIZE); + let original: GGSW> = GGSW::alloc(N_GLWE, BASE2K, K, RANK, DNUM, DSIZE); test_reader_writer_interface(original); } #[test] fn ggsw_compressed_serialization() { - let original: GGSWCiphertextCompressed> = GGSWCiphertextCompressed::alloc_with(N_GLWE, BASE2K, K, RANK, DNUM, DSIZE); + let original: GGSWCompressed> = GGSWCompressed::alloc(N_GLWE, BASE2K, K, RANK, DNUM, DSIZE); test_reader_writer_interface(original); } diff --git a/poulpy-core/src/tests/test_suite/automorphism/gglwe_atk.rs b/poulpy-core/src/tests/test_suite/automorphism/gglwe_atk.rs index 1dd3e58..37b5403 100644 --- a/poulpy-core/src/tests/test_suite/automorphism/gglwe_atk.rs +++ b/poulpy-core/src/tests/test_suite/automorphism/gglwe_atk.rs @@ -1,9 +1,9 @@ use poulpy_hal::{ api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, - VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, + VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxCopy, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSubScalarInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, }, @@ -18,8 +18,8 @@ use poulpy_hal::{ use crate::{ encryption::SIGMA, layouts::{ - GGLWEAutomorphismKey, GGLWEAutomorphismKeyLayout, GGLWEInfos, GLWEPlaintext, GLWESecret, - prepared::{GGLWEAutomorphismKeyPrepared, GLWESecretPrepared, Prepare, PrepareAlloc}, + AutomorphismKey, AutomorphismKeyLayout, GGLWEInfos, GLWEPlaintext, GLWESecret, + prepared::{AutomorphismKeyPrepared, GLWESecretPrepared, Prepare, PrepareAlloc}, }, noise::log2_std_noise_gglwe_product, }; @@ -27,7 +27,7 @@ use crate::{ #[allow(clippy::too_many_arguments)] pub fn test_gglwe_automorphism_key_automorphism(module: &Module) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft @@ -38,8 +38,8 @@ where + VecZnxBigNormalize + VecZnxAutomorphism + VecZnxAutomorphismInplace - + SvpPPolAllocBytes - + VecZnxDftAllocBytes + + SvpPPolBytesOf + + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes + VmpPMatAlloc + VmpPrepare @@ -84,7 +84,7 @@ where let dnum_out: usize = k_out / (base2k * di); let dnum_apply: usize = k_in.div_ceil(base2k * di); - let auto_key_in_infos: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + let auto_key_in_infos: AutomorphismKeyLayout = AutomorphismKeyLayout { n: n.into(), base2k: base2k.into(), k: k_in.into(), @@ -93,7 +93,7 @@ where rank: rank.into(), }; - let auto_key_out_infos: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + let auto_key_out_infos: AutomorphismKeyLayout = AutomorphismKeyLayout { n: n.into(), base2k: base2k.into(), k: k_out.into(), @@ -102,7 +102,7 @@ where rank: rank.into(), }; - let auto_key_apply_infos: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + let auto_key_apply_infos: AutomorphismKeyLayout = AutomorphismKeyLayout { n: n.into(), base2k: base2k.into(), k: k_apply.into(), @@ -111,18 +111,18 @@ where rank: rank.into(), }; - let mut auto_key_in: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&auto_key_in_infos); - let mut auto_key_out: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&auto_key_out_infos); - let mut auto_key_apply: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&auto_key_apply_infos); + let mut auto_key_in: AutomorphismKey> = AutomorphismKey::alloc_from_infos(&auto_key_in_infos); + let mut auto_key_out: AutomorphismKey> = AutomorphismKey::alloc_from_infos(&auto_key_out_infos); + let mut auto_key_apply: AutomorphismKey> = AutomorphismKey::alloc_from_infos(&auto_key_apply_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, &auto_key_in_infos) - | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, &auto_key_apply_infos) - | GGLWEAutomorphismKey::automorphism_scratch_space( + AutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key_in_infos) + | AutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key_apply_infos) + | AutomorphismKey::automorphism_tmp_bytes( module, &auto_key_out_infos, &auto_key_in_infos, @@ -130,7 +130,7 @@ where ), ); - let mut sk: GLWESecret> = GLWESecret::alloc(&auto_key_in); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&auto_key_in); sk.fill_ternary_prob(0.5, &mut source_xs); // gglwe_{s1}(s0) = s0 -> s1 @@ -153,8 +153,8 @@ where scratch.borrow(), ); - let mut auto_key_apply_prepared: GGLWEAutomorphismKeyPrepared, B> = - GGLWEAutomorphismKeyPrepared::alloc(module, &auto_key_apply_infos); + let mut auto_key_apply_prepared: AutomorphismKeyPrepared, B> = + AutomorphismKeyPrepared::alloc_from_infos(module, &auto_key_apply_infos); auto_key_apply_prepared.prepare(module, &auto_key_apply, scratch.borrow()); @@ -166,9 +166,9 @@ where scratch.borrow(), ); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&auto_key_out_infos); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&auto_key_out_infos); - let mut sk_auto: GLWESecret> = GLWESecret::alloc(&auto_key_out_infos); + let mut sk_auto: GLWESecret> = GLWESecret::alloc_from_infos(&auto_key_out_infos); sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk for i in 0..rank { module.vec_znx_automorphism( @@ -224,7 +224,7 @@ where #[allow(clippy::too_many_arguments)] pub fn test_gglwe_automorphism_key_automorphism_inplace(module: &Module) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -238,9 +238,9 @@ where + VecZnxNormalize + VecZnxSub + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc - + VecZnxDftAllocBytes + + VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDftTmpBytes @@ -255,8 +255,8 @@ where + VecZnxAddScalarInplace + VecZnxAutomorphism + VecZnxAutomorphismInplace - + VecZnxDftAllocBytes - + VecZnxBigAllocBytes + + VecZnxDftBytesOf + + VecZnxBigBytesOf + VecZnxDftApply + SvpApplyDftToDftInplace + VecZnxIdftApplyConsume @@ -293,7 +293,7 @@ where let dnum_in: usize = k_in / (base2k * di); let dnum_apply: usize = k_in.div_ceil(base2k * di); - let auto_key_layout: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + let auto_key_layout: AutomorphismKeyLayout = AutomorphismKeyLayout { n: n.into(), base2k: base2k.into(), k: k_in.into(), @@ -302,7 +302,7 @@ where rank: rank.into(), }; - let auto_key_apply_layout: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + let auto_key_apply_layout: AutomorphismKeyLayout = AutomorphismKeyLayout { n: n.into(), base2k: base2k.into(), k: k_apply.into(), @@ -311,20 +311,20 @@ where rank: rank.into(), }; - let mut auto_key: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&auto_key_layout); - let mut auto_key_apply: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&auto_key_apply_layout); + let mut auto_key: AutomorphismKey> = AutomorphismKey::alloc_from_infos(&auto_key_layout); + let mut auto_key_apply: AutomorphismKey> = AutomorphismKey::alloc_from_infos(&auto_key_apply_layout); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, &auto_key) - | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, &auto_key_apply) - | GGLWEAutomorphismKey::automorphism_inplace_scratch_space(module, &auto_key, &auto_key_apply), + AutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key) + | AutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key_apply) + | AutomorphismKey::automorphism_inplace_tmp_bytes(module, &auto_key, &auto_key_apply), ); - let mut sk: GLWESecret> = GLWESecret::alloc(&auto_key); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&auto_key); sk.fill_ternary_prob(0.5, &mut source_xs); // gglwe_{s1}(s0) = s0 -> s1 @@ -347,17 +347,17 @@ where scratch.borrow(), ); - let mut auto_key_apply_prepared: GGLWEAutomorphismKeyPrepared, B> = - GGLWEAutomorphismKeyPrepared::alloc(module, &auto_key_apply_layout); + let mut auto_key_apply_prepared: AutomorphismKeyPrepared, B> = + AutomorphismKeyPrepared::alloc_from_infos(module, &auto_key_apply_layout); auto_key_apply_prepared.prepare(module, &auto_key_apply, scratch.borrow()); // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) auto_key.automorphism_inplace(module, &auto_key_apply_prepared, scratch.borrow()); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&auto_key); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&auto_key); - let mut sk_auto: GLWESecret> = GLWESecret::alloc(&auto_key); + let mut sk_auto: GLWESecret> = GLWESecret::alloc_from_infos(&auto_key); sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk for i in 0..rank { diff --git a/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs index 2fd7151..7c2a427 100644 --- a/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs +++ b/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs @@ -1,9 +1,9 @@ use poulpy_hal::{ api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, - VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, + VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigBytesOf, VecZnxBigNormalize, + VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftCopy, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, @@ -19,16 +19,16 @@ use poulpy_hal::{ use crate::{ encryption::SIGMA, layouts::{ - GGLWEAutomorphismKey, GGLWETensorKey, GGLWETensorKeyLayout, GGSWCiphertext, GGSWCiphertextLayout, GLWESecret, - prepared::{GGLWEAutomorphismKeyPrepared, GGLWETensorKeyPrepared, GLWESecretPrepared, Prepare, PrepareAlloc}, + AutomorphismKey, GGSW, GGSWCiphertextLayout, GLWESecret, TensorKey, TensorKeyLayout, + prepared::{AutomorphismKeyPrepared, GLWESecretPrepared, Prepare, PrepareAlloc, TensorKeyPrepared}, }, noise::noise_ggsw_keyswitch, }; pub fn test_ggsw_automorphism(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigAllocBytes + Module: VecZnxDftBytesOf + + VecZnxBigBytesOf + VecZnxDftApply + SvpApplyDftToDftInplace + VecZnxIdftApplyConsume @@ -41,7 +41,7 @@ where + VecZnxBigNormalizeTmpBytes + VecZnxIdftApplyTmpA + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc + VecZnxAddScalarInplace + VecZnxCopy @@ -110,7 +110,7 @@ where rank: rank.into(), }; - let tensor_key_layout: GGLWETensorKeyLayout = GGLWETensorKeyLayout { + let tensor_key_layout: TensorKeyLayout = TensorKeyLayout { n: n.into(), base2k: base2k.into(), k: k_tsk.into(), @@ -119,7 +119,7 @@ where rank: rank.into(), }; - let auto_key_layout: GGLWETensorKeyLayout = GGLWETensorKeyLayout { + let auto_key_layout: TensorKeyLayout = TensorKeyLayout { n: n.into(), base2k: base2k.into(), k: k_ksk.into(), @@ -128,10 +128,10 @@ where rank: rank.into(), }; - let mut ct_in: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_in_layout); - let mut ct_out: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_out_layout); - let mut tensor_key: GGLWETensorKey> = GGLWETensorKey::alloc(&tensor_key_layout); - let mut auto_key: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&auto_key_layout); + let mut ct_in: GGSW> = GGSW::alloc_from_infos(&ggsw_in_layout); + let mut ct_out: GGSW> = GGSW::alloc_from_infos(&ggsw_out_layout); + let mut tensor_key: TensorKey> = TensorKey::alloc_from_infos(&tensor_key_layout); + let mut auto_key: AutomorphismKey> = AutomorphismKey::alloc_from_infos(&auto_key_layout); let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); let mut source_xs: Source = Source::new([0u8; 32]); @@ -139,15 +139,15 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, &ct_in) - | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, &auto_key) - | GGLWETensorKey::encrypt_sk_scratch_space(module, &tensor_key) - | GGSWCiphertext::automorphism_scratch_space(module, &ct_out, &ct_in, &auto_key, &tensor_key), + GGSW::encrypt_sk_tmp_bytes(module, &ct_in) + | AutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key) + | TensorKey::encrypt_sk_tmp_bytes(module, &tensor_key) + | GGSW::automorphism_tmp_bytes(module, &ct_out, &ct_in, &auto_key, &tensor_key), ); let var_xs: f64 = 0.5; - let mut sk: GLWESecret> = GLWESecret::alloc(&ct_out); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&ct_out); sk.fill_ternary_prob(var_xs, &mut source_xs); let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); @@ -178,11 +178,11 @@ where scratch.borrow(), ); - let mut auto_key_prepared: GGLWEAutomorphismKeyPrepared, B> = - GGLWEAutomorphismKeyPrepared::alloc(module, &auto_key_layout); + let mut auto_key_prepared: AutomorphismKeyPrepared, B> = + AutomorphismKeyPrepared::alloc_from_infos(module, &auto_key_layout); auto_key_prepared.prepare(module, &auto_key, scratch.borrow()); - let mut tsk_prepared: GGLWETensorKeyPrepared, B> = GGLWETensorKeyPrepared::alloc(module, &tensor_key_layout); + let mut tsk_prepared: TensorKeyPrepared, B> = TensorKeyPrepared::alloc_from_infos(module, &tensor_key_layout); tsk_prepared.prepare(module, &tensor_key, scratch.borrow()); ct_out.automorphism( @@ -219,8 +219,8 @@ where #[allow(clippy::too_many_arguments)] pub fn test_ggsw_automorphism_inplace(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigAllocBytes + Module: VecZnxDftBytesOf + + VecZnxBigBytesOf + VecZnxDftApply + SvpApplyDftToDftInplace + VecZnxIdftApplyConsume @@ -233,7 +233,7 @@ where + VecZnxBigNormalizeTmpBytes + VecZnxIdftApplyTmpA + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc + VecZnxAddScalarInplace + VecZnxCopy @@ -291,7 +291,7 @@ where rank: rank.into(), }; - let tensor_key_layout: GGLWETensorKeyLayout = GGLWETensorKeyLayout { + let tensor_key_layout: TensorKeyLayout = TensorKeyLayout { n: n.into(), base2k: base2k.into(), k: k_tsk.into(), @@ -300,7 +300,7 @@ where rank: rank.into(), }; - let auto_key_layout: GGLWETensorKeyLayout = GGLWETensorKeyLayout { + let auto_key_layout: TensorKeyLayout = TensorKeyLayout { n: n.into(), base2k: base2k.into(), k: k_ksk.into(), @@ -309,9 +309,9 @@ where rank: rank.into(), }; - let mut ct: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_out_layout); - let mut tensor_key: GGLWETensorKey> = GGLWETensorKey::alloc(&tensor_key_layout); - let mut auto_key: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&auto_key_layout); + let mut ct: GGSW> = GGSW::alloc_from_infos(&ggsw_out_layout); + let mut tensor_key: TensorKey> = TensorKey::alloc_from_infos(&tensor_key_layout); + let mut auto_key: AutomorphismKey> = AutomorphismKey::alloc_from_infos(&auto_key_layout); let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); let mut source_xs: Source = Source::new([0u8; 32]); @@ -319,15 +319,15 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, &ct) - | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, &auto_key) - | GGLWETensorKey::encrypt_sk_scratch_space(module, &tensor_key) - | GGSWCiphertext::automorphism_inplace_scratch_space(module, &ct, &auto_key, &tensor_key), + GGSW::encrypt_sk_tmp_bytes(module, &ct) + | AutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key) + | TensorKey::encrypt_sk_tmp_bytes(module, &tensor_key) + | GGSW::automorphism_inplace_tmp_bytes(module, &ct, &auto_key, &tensor_key), ); let var_xs: f64 = 0.5; - let mut sk: GLWESecret> = GLWESecret::alloc(&ct); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&ct); sk.fill_ternary_prob(var_xs, &mut source_xs); let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); @@ -358,11 +358,11 @@ where scratch.borrow(), ); - let mut auto_key_prepared: GGLWEAutomorphismKeyPrepared, B> = - GGLWEAutomorphismKeyPrepared::alloc(module, &auto_key_layout); + let mut auto_key_prepared: AutomorphismKeyPrepared, B> = + AutomorphismKeyPrepared::alloc_from_infos(module, &auto_key_layout); auto_key_prepared.prepare(module, &auto_key, scratch.borrow()); - let mut tsk_prepared: GGLWETensorKeyPrepared, B> = GGLWETensorKeyPrepared::alloc(module, &tensor_key_layout); + let mut tsk_prepared: TensorKeyPrepared, B> = TensorKeyPrepared::alloc_from_infos(module, &tensor_key_layout); tsk_prepared.prepare(module, &tensor_key, scratch.borrow()); ct.automorphism_inplace(module, &auto_key_prepared, &tsk_prepared, scratch.borrow()); diff --git a/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs b/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs index 5828c48..02afcb8 100644 --- a/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs @@ -1,9 +1,9 @@ use poulpy_hal::{ api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, - VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, + VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, }, @@ -18,15 +18,15 @@ use poulpy_hal::{ use crate::{ encryption::SIGMA, layouts::{ - GGLWEAutomorphismKey, GGLWEAutomorphismKeyLayout, GLWECiphertext, GLWECiphertextLayout, GLWEPlaintext, GLWESecret, - prepared::{GGLWEAutomorphismKeyPrepared, GLWESecretPrepared, Prepare, PrepareAlloc}, + AutomorphismKey, AutomorphismKeyLayout, GLWE, GLWELayout, GLWEPlaintext, GLWESecret, + prepared::{AutomorphismKeyPrepared, GLWESecretPrepared, Prepare, PrepareAlloc}, }, noise::log2_std_noise_gglwe_product, }; pub fn test_glwe_automorphism(module: &Module) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -40,9 +40,9 @@ where + VecZnxNormalize + VecZnxSub + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VmpApplyDftToDftTmpBytes @@ -77,21 +77,21 @@ where let n: usize = module.n(); let dnum: usize = k_in.div_ceil(base2k * dsize); - let ct_in_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let ct_in_infos: GLWELayout = GLWELayout { n: n.into(), base2k: base2k.into(), k: k_in.into(), rank: rank.into(), }; - let ct_out_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let ct_out_infos: GLWELayout = GLWELayout { n: n.into(), base2k: base2k.into(), k: k_out.into(), rank: rank.into(), }; - let autokey_infos: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + let autokey_infos: AutomorphismKeyLayout = AutomorphismKeyLayout { n: n.into(), base2k: base2k.into(), k: k_out.into(), @@ -100,10 +100,10 @@ where dsize: di.into(), }; - let mut autokey: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&autokey_infos); - let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&ct_in_infos); - let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&ct_out_infos); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&ct_out_infos); + let mut autokey: AutomorphismKey> = AutomorphismKey::alloc_from_infos(&autokey_infos); + let mut ct_in: GLWE> = GLWE::alloc_from_infos(&ct_in_infos); + let mut ct_out: GLWE> = GLWE::alloc_from_infos(&ct_out_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&ct_out_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -112,13 +112,13 @@ where module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, &autokey) - | GLWECiphertext::decrypt_scratch_space(module, &ct_out) - | GLWECiphertext::encrypt_sk_scratch_space(module, &ct_in) - | GLWECiphertext::automorphism_scratch_space(module, &ct_out, &ct_in, &autokey), + AutomorphismKey::encrypt_sk_tmp_bytes(module, &autokey) + | GLWE::decrypt_tmp_bytes(module, &ct_out) + | GLWE::encrypt_sk_tmp_bytes(module, &ct_in) + | GLWE::automorphism_tmp_bytes(module, &ct_out, &ct_in, &autokey), ); - let mut sk: GLWESecret> = GLWESecret::alloc(&ct_out); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&ct_out); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); @@ -140,8 +140,8 @@ where scratch.borrow(), ); - let mut autokey_prepared: GGLWEAutomorphismKeyPrepared, B> = - GGLWEAutomorphismKeyPrepared::alloc(module, &autokey_infos); + let mut autokey_prepared: AutomorphismKeyPrepared, B> = + AutomorphismKeyPrepared::alloc_from_infos(module, &autokey_infos); autokey_prepared.prepare(module, &autokey, scratch.borrow()); ct_out.automorphism(module, &ct_in, &autokey_prepared, scratch.borrow()); @@ -169,7 +169,7 @@ where #[allow(clippy::too_many_arguments)] pub fn test_glwe_automorphism_inplace(module: &Module) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -183,9 +183,9 @@ where + VecZnxNormalize + VecZnxSub + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VmpApplyDftToDftTmpBytes @@ -219,14 +219,14 @@ where let n: usize = module.n(); let dnum: usize = k_out.div_ceil(base2k * dsize); - let ct_out_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let ct_out_infos: GLWELayout = GLWELayout { n: n.into(), base2k: base2k.into(), k: k_out.into(), rank: rank.into(), }; - let autokey_infos: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + let autokey_infos: AutomorphismKeyLayout = AutomorphismKeyLayout { n: n.into(), base2k: base2k.into(), k: k_ksk.into(), @@ -235,9 +235,9 @@ where dsize: di.into(), }; - let mut autokey: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&autokey_infos); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&ct_out_infos); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&ct_out_infos); + let mut autokey: AutomorphismKey> = AutomorphismKey::alloc_from_infos(&autokey_infos); + let mut ct: GLWE> = GLWE::alloc_from_infos(&ct_out_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&ct_out_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -246,13 +246,13 @@ where module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, &autokey) - | GLWECiphertext::decrypt_scratch_space(module, &ct) - | GLWECiphertext::encrypt_sk_scratch_space(module, &ct) - | GLWECiphertext::automorphism_inplace_scratch_space(module, &ct, &autokey), + AutomorphismKey::encrypt_sk_tmp_bytes(module, &autokey) + | GLWE::decrypt_tmp_bytes(module, &ct) + | GLWE::encrypt_sk_tmp_bytes(module, &ct) + | GLWE::automorphism_inplace_tmp_bytes(module, &ct, &autokey), ); - let mut sk: GLWESecret> = GLWESecret::alloc(&ct); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&ct); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); @@ -274,8 +274,8 @@ where scratch.borrow(), ); - let mut autokey_prepared: GGLWEAutomorphismKeyPrepared, B> = - GGLWEAutomorphismKeyPrepared::alloc(module, &autokey); + let mut autokey_prepared: AutomorphismKeyPrepared, B> = + AutomorphismKeyPrepared::alloc_from_infos(module, &autokey); autokey_prepared.prepare(module, &autokey, scratch.borrow()); ct.automorphism_inplace(module, &autokey_prepared, scratch.borrow()); diff --git a/poulpy-core/src/tests/test_suite/conversion.rs b/poulpy-core/src/tests/test_suite/conversion.rs index c2c81dc..972e0c6 100644 --- a/poulpy-core/src/tests/test_suite/conversion.rs +++ b/poulpy-core/src/tests/test_suite/conversion.rs @@ -1,9 +1,9 @@ use poulpy_hal::{ api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigAddInplace, - VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, - VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, + VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftApply, + VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, ZnAddNormal, ZnFillUniform, ZnNormalizeInplace, }, @@ -16,15 +16,14 @@ use poulpy_hal::{ }; use crate::layouts::{ - Base2K, Degree, Dnum, GLWECiphertext, GLWECiphertextLayout, GLWEPlaintext, GLWESecret, GLWEToLWEKey, GLWEToLWEKeyLayout, - LWECiphertext, LWECiphertextLayout, LWEPlaintext, LWESecret, LWEToGLWESwitchingKey, LWEToGLWESwitchingKeyLayout, Rank, - TorusPrecision, + Base2K, Degree, Dnum, GLWE, GLWELayout, GLWEPlaintext, GLWESecret, GLWEToLWEKeyLayout, GLWEToLWESwitchingKey, LWE, + LWECiphertextLayout, LWEPlaintext, LWESecret, LWEToGLWESwitchingKey, LWEToGLWESwitchingKeyLayout, Rank, TorusPrecision, prepared::{GLWESecretPrepared, GLWEToLWESwitchingKeyPrepared, LWEToGLWESwitchingKeyPrepared, PrepareAlloc}, }; pub fn test_lwe_to_glwe(module: &Module) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -37,9 +36,9 @@ where + VecZnxNormalize + VecZnxSub + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxNormalizeTmpBytes @@ -83,7 +82,7 @@ where rank_out: rank, }; - let glwe_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_infos: GLWELayout = GLWELayout { n: n_glwe, base2k: Base2K(17), k: TorusPrecision(34), @@ -97,12 +96,12 @@ where }; let mut scratch: ScratchOwned = ScratchOwned::alloc( - LWEToGLWESwitchingKey::encrypt_sk_scratch_space(module, &lwe_to_glwe_infos) - | GLWECiphertext::from_lwe_scratch_space(module, &glwe_infos, &lwe_infos, &lwe_to_glwe_infos) - | GLWECiphertext::decrypt_scratch_space(module, &glwe_infos), + LWEToGLWESwitchingKey::encrypt_sk_tmp_bytes(module, &lwe_to_glwe_infos) + | GLWE::from_lwe_tmp_bytes(module, &glwe_infos, &lwe_infos, &lwe_to_glwe_infos) + | GLWE::decrypt_tmp_bytes(module, &glwe_infos), ); - let mut sk_glwe: GLWESecret> = GLWESecret::alloc(&glwe_infos); + let mut sk_glwe: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_infos); sk_glwe.fill_ternary_prob(0.5, &mut source_xs); let sk_glwe_prepared: GLWESecretPrepared, B> = sk_glwe.prepare_alloc(module, scratch.borrow()); @@ -112,13 +111,13 @@ where let data: i64 = 17; - let mut lwe_pt: LWEPlaintext> = LWEPlaintext::alloc(&lwe_infos); + let mut lwe_pt: LWEPlaintext> = LWEPlaintext::alloc_from_infos(&lwe_infos); lwe_pt.encode_i64(data, k_lwe_pt); - let mut lwe_ct: LWECiphertext> = LWECiphertext::alloc(&lwe_infos); + let mut lwe_ct: LWE> = LWE::alloc_from_infos(&lwe_infos); lwe_ct.encrypt_sk(module, &lwe_pt, &sk_lwe, &mut source_xa, &mut source_xe); - let mut ksk: LWEToGLWESwitchingKey> = LWEToGLWESwitchingKey::alloc(&lwe_to_glwe_infos); + let mut ksk: LWEToGLWESwitchingKey> = LWEToGLWESwitchingKey::alloc_from_infos(&lwe_to_glwe_infos); ksk.encrypt_sk( module, @@ -129,13 +128,13 @@ where scratch.borrow(), ); - let mut glwe_ct: GLWECiphertext> = GLWECiphertext::alloc(&glwe_infos); + let mut glwe_ct: GLWE> = GLWE::alloc_from_infos(&glwe_infos); let ksk_prepared: LWEToGLWESwitchingKeyPrepared, B> = ksk.prepare_alloc(module, scratch.borrow()); glwe_ct.from_lwe(module, &lwe_ct, &ksk_prepared, scratch.borrow()); - let mut glwe_pt: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_infos); + let mut glwe_pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_infos); glwe_ct.decrypt(module, &mut glwe_pt, &sk_glwe_prepared, scratch.borrow()); assert_eq!(glwe_pt.data.at(0, 0)[0], lwe_pt.data.at(0, 0)[0]); @@ -143,7 +142,7 @@ where pub fn test_glwe_to_lwe(module: &Module) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -156,9 +155,9 @@ where + VecZnxNormalize + VecZnxSub + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxNormalizeTmpBytes @@ -196,7 +195,7 @@ where rank_in: rank, }; - let glwe_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_infos: GLWELayout = GLWELayout { n: n_glwe, base2k: Base2K(17), k: TorusPrecision(34), @@ -214,12 +213,12 @@ where let mut source_xe: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWEToLWEKey::encrypt_sk_scratch_space(module, &glwe_to_lwe_infos) - | LWECiphertext::from_glwe_scratch_space(module, &lwe_infos, &glwe_infos, &glwe_to_lwe_infos) - | GLWECiphertext::decrypt_scratch_space(module, &glwe_infos), + GLWEToLWESwitchingKey::encrypt_sk_tmp_bytes(module, &glwe_to_lwe_infos) + | LWE::from_glwe_tmp_bytes(module, &lwe_infos, &glwe_infos, &glwe_to_lwe_infos) + | GLWE::decrypt_tmp_bytes(module, &glwe_infos), ); - let mut sk_glwe: GLWESecret> = GLWESecret::alloc(&glwe_infos); + let mut sk_glwe: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_infos); sk_glwe.fill_ternary_prob(0.5, &mut source_xs); let sk_glwe_prepared: GLWESecretPrepared, B> = sk_glwe.prepare_alloc(module, scratch.borrow()); @@ -228,10 +227,10 @@ where sk_lwe.fill_ternary_prob(0.5, &mut source_xs); let data: i64 = 17; - let mut glwe_pt: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_infos); + let mut glwe_pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_infos); glwe_pt.encode_coeff_i64(data, k_lwe_pt, 0); - let mut glwe_ct = GLWECiphertext::alloc(&glwe_infos); + let mut glwe_ct = GLWE::alloc_from_infos(&glwe_infos); glwe_ct.encrypt_sk( module, &glwe_pt, @@ -241,7 +240,7 @@ where scratch.borrow(), ); - let mut ksk: GLWEToLWEKey> = GLWEToLWEKey::alloc(&glwe_to_lwe_infos); + let mut ksk: GLWEToLWESwitchingKey> = GLWEToLWESwitchingKey::alloc(&glwe_to_lwe_infos); ksk.encrypt_sk( module, @@ -252,13 +251,13 @@ where scratch.borrow(), ); - let mut lwe_ct: LWECiphertext> = LWECiphertext::alloc(&lwe_infos); + let mut lwe_ct: LWE> = LWE::alloc_from_infos(&lwe_infos); let ksk_prepared: GLWEToLWESwitchingKeyPrepared, B> = ksk.prepare_alloc(module, scratch.borrow()); lwe_ct.from_glwe(module, &glwe_ct, &ksk_prepared, scratch.borrow()); - let mut lwe_pt: LWEPlaintext> = LWEPlaintext::alloc(&lwe_infos); + let mut lwe_pt: LWEPlaintext> = LWEPlaintext::alloc_from_infos(&lwe_infos); lwe_ct.decrypt(module, &mut lwe_pt, &sk_lwe); assert_eq!(glwe_pt.data.at(0, 0)[0], lwe_pt.data.at(0, 0)[0]); diff --git a/poulpy-core/src/tests/test_suite/encryption/gglwe_atk.rs b/poulpy-core/src/tests/test_suite/encryption/gglwe_atk.rs index e9b7c93..bd164dc 100644 --- a/poulpy-core/src/tests/test_suite/encryption/gglwe_atk.rs +++ b/poulpy-core/src/tests/test_suite/encryption/gglwe_atk.rs @@ -1,9 +1,9 @@ use poulpy_hal::{ api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, - VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, + VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxCopy, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSubScalarInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, }, @@ -18,15 +18,15 @@ use poulpy_hal::{ use crate::{ encryption::SIGMA, layouts::{ - GGLWEAutomorphismKey, GGLWEAutomorphismKeyLayout, GLWEInfos, GLWESecret, - compressed::{Decompress, GGLWEAutomorphismKeyCompressed}, + AutomorphismKey, AutomorphismKeyLayout, GLWEInfos, GLWESecret, + compressed::{AutomorphismKeyCompressed, Decompress}, prepared::{GLWESecretPrepared, PrepareAlloc}, }, }; pub fn test_gglwe_automorphisk_key_encrypt_sk(module: &Module) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -40,7 +40,7 @@ where + VecZnxNormalize + VecZnxSub + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes @@ -51,7 +51,7 @@ where + VecZnxSwitchRing + VecZnxAddScalarInplace + VecZnxAutomorphismInplace - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxBigAddInplace + VecZnxSubScalarInplace + VecZnxCopy @@ -75,7 +75,7 @@ where let n: usize = module.n(); let dnum: usize = (k_ksk - di * base2k) / (di * base2k); - let atk_infos: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + let atk_infos: AutomorphismKeyLayout = AutomorphismKeyLayout { n: n.into(), base2k: base2k.into(), k: k_ksk.into(), @@ -84,17 +84,17 @@ where rank: rank.into(), }; - let mut atk: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&atk_infos); + let mut atk: AutomorphismKey> = AutomorphismKey::alloc_from_infos(&atk_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWEAutomorphismKey::encrypt_sk_scratch_space( + let mut scratch: ScratchOwned = ScratchOwned::alloc(AutomorphismKey::encrypt_sk_tmp_bytes( module, &atk_infos, )); - let mut sk: GLWESecret> = GLWESecret::alloc(&atk_infos); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&atk_infos); sk.fill_ternary_prob(0.5, &mut source_xs); let p = -5; @@ -129,7 +129,7 @@ where pub fn test_gglwe_automorphisk_key_compressed_encrypt_sk(module: &Module) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -143,7 +143,7 @@ where + VecZnxNormalize + VecZnxSub + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes @@ -154,7 +154,7 @@ where + VecZnxSwitchRing + VecZnxAddScalarInplace + VecZnxAutomorphismInplace - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxBigAddInplace + VecZnxSubScalarInplace + VecZnxCopy @@ -178,7 +178,7 @@ where let n: usize = module.n(); let dnum: usize = (k_ksk - di * base2k) / (di * base2k); - let atk_infos: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + let atk_infos: AutomorphismKeyLayout = AutomorphismKeyLayout { n: n.into(), base2k: base2k.into(), k: k_ksk.into(), @@ -187,16 +187,16 @@ where rank: rank.into(), }; - let mut atk_compressed: GGLWEAutomorphismKeyCompressed> = GGLWEAutomorphismKeyCompressed::alloc(&atk_infos); + let mut atk_compressed: AutomorphismKeyCompressed> = AutomorphismKeyCompressed::alloc_from_infos(&atk_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWEAutomorphismKey::encrypt_sk_scratch_space( + let mut scratch: ScratchOwned = ScratchOwned::alloc(AutomorphismKey::encrypt_sk_tmp_bytes( module, &atk_infos, )); - let mut sk: GLWESecret> = GLWESecret::alloc(&atk_infos); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&atk_infos); sk.fill_ternary_prob(0.5, &mut source_xs); let p = -5; @@ -217,7 +217,7 @@ where }); let sk_out_prepared = sk_out.prepare_alloc(module, scratch.borrow()); - let mut atk: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&atk_infos); + let mut atk: AutomorphismKey> = AutomorphismKey::alloc_from_infos(&atk_infos); atk.decompress(module, &atk_compressed); atk.key diff --git a/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs b/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs index 60bb7e2..5e09781 100644 --- a/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs @@ -1,8 +1,8 @@ use poulpy_hal::{ api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, + VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxCopy, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSubScalarInplace, VecZnxSwitchRing, VmpPMatAlloc, VmpPrepare, }, @@ -17,15 +17,15 @@ use poulpy_hal::{ use crate::{ encryption::SIGMA, layouts::{ - GGLWECiphertextLayout, GGLWESwitchingKey, GLWESecret, - compressed::{Decompress, GGLWESwitchingKeyCompressed}, + GGLWECiphertextLayout, GLWESecret, GLWESwitchingKey, + compressed::{Decompress, GLWESwitchingKeyCompressed}, prepared::{GLWESecretPrepared, PrepareAlloc}, }, }; pub fn test_gglwe_switching_key_encrypt_sk(module: &Module) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -39,12 +39,12 @@ where + VecZnxNormalize + VecZnxSub + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc + VecZnxBigAddSmallInplace + VecZnxSwitchRing + VecZnxAddScalarInplace - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxBigAddInplace + VecZnxSubScalarInplace + VecZnxCopy @@ -81,21 +81,21 @@ where rank_out: rank_out.into(), }; - let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_infos); + let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&gglwe_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKey::encrypt_sk_scratch_space( + let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWESwitchingKey::encrypt_sk_tmp_bytes( module, &gglwe_infos, )); - let mut sk_in: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_in.into()); + let mut sk_in: GLWESecret> = GLWESecret::alloc(n.into(), rank_in.into()); sk_in.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_out: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_out.into()); + let mut sk_out: GLWESecret> = GLWESecret::alloc(n.into(), rank_out.into()); sk_out.fill_ternary_prob(0.5, &mut source_xs); let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); @@ -117,7 +117,7 @@ where pub fn test_gglwe_switching_key_compressed_encrypt_sk(module: &Module) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -131,12 +131,12 @@ where + VecZnxNormalize + VecZnxSub + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc + VecZnxBigAddSmallInplace + VecZnxSwitchRing + VecZnxAddScalarInplace - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxBigAddInplace + VecZnxSubScalarInplace + VecZnxCopy @@ -173,20 +173,21 @@ where rank_out: rank_out.into(), }; - let mut ksk_compressed: GGLWESwitchingKeyCompressed> = GGLWESwitchingKeyCompressed::alloc(&gglwe_infos); + let mut ksk_compressed: GLWESwitchingKeyCompressed> = + GLWESwitchingKeyCompressed::alloc_from_infos(&gglwe_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKeyCompressed::encrypt_sk_scratch_space( + let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWESwitchingKeyCompressed::encrypt_sk_tmp_bytes( module, &gglwe_infos, )); - let mut sk_in: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_in.into()); + let mut sk_in: GLWESecret> = GLWESecret::alloc(n.into(), rank_in.into()); sk_in.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_out: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_out.into()); + let mut sk_out: GLWESecret> = GLWESecret::alloc(n.into(), rank_out.into()); sk_out.fill_ternary_prob(0.5, &mut source_xs); let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); @@ -201,7 +202,7 @@ where scratch.borrow(), ); - let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_infos); + let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&gglwe_infos); ksk.decompress(module, &ksk_compressed); ksk.key diff --git a/poulpy-core/src/tests/test_suite/encryption/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/encryption/ggsw_ct.rs index 9eacf38..7620709 100644 --- a/poulpy-core/src/tests/test_suite/encryption/ggsw_ct.rs +++ b/poulpy-core/src/tests/test_suite/encryption/ggsw_ct.rs @@ -1,10 +1,10 @@ use poulpy_hal::{ api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAlloc, VecZnxDftAllocBytes, - VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VmpPMatAlloc, VmpPrepare, + VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAlloc, VecZnxDftApply, + VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, + VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VmpPMatAlloc, VmpPrepare, }, layouts::{Backend, Module, ScalarZnx, ScratchOwned}, oep::{ @@ -17,53 +17,16 @@ use poulpy_hal::{ use crate::{ encryption::SIGMA, layouts::{ - GGSWCiphertext, GGSWCiphertextLayout, GLWESecret, - compressed::{Decompress, GGSWCiphertextCompressed}, + GGSW, GGSWCiphertextLayout, GLWESecret, + compressed::{Decompress, GGSWCompressed}, prepared::{GLWESecretPrepared, PrepareAlloc}, }, }; -pub fn test_ggsw_encrypt_sk(module: &Module) +pub fn test_ggsw_encrypt_sk(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAddSmallInplace - + VecZnxAddScalarInplace - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxCopy - + VmpPMatAlloc - + VmpPrepare - + VecZnxBigAlloc - + VecZnxDftAlloc - + VecZnxBigNormalizeTmpBytes - + VecZnxIdftApplyTmpA, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl - + VecZnxDftAllocBytesImpl - + VecZnxBigAllocBytesImpl - + TakeSvpPPolImpl, + ScratchOwned: ScratchOwnedAlloc, + Module: SvpPrepare, { let base2k: usize = 12; let k: usize = 54; @@ -82,7 +45,7 @@ where rank: rank.into(), }; - let mut ct: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_infos); + let mut ct: GGSW> = GGSW::alloc_from_infos(&ggsw_infos); let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); @@ -92,12 +55,9 @@ where pt_scalar.fill_ternary_hw(0, n, &mut source_xs); - let mut scratch: ScratchOwned = ScratchOwned::alloc(GGSWCiphertext::encrypt_sk_scratch_space( - module, - &ggsw_infos, - )); + let mut scratch: ScratchOwned = ScratchOwned::alloc(GGSW::encrypt_sk_tmp_bytes(module, &ggsw_infos)); - let mut sk: GLWESecret> = GLWESecret::alloc(&ggsw_infos); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&ggsw_infos); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); @@ -119,7 +79,7 @@ where pub fn test_ggsw_compressed_encrypt_sk(module: &Module) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -133,11 +93,11 @@ where + VecZnxNormalize + VecZnxSub + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc + VecZnxBigAddSmallInplace + VecZnxAddScalarInplace - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxBigAddInplace + VecZnxCopy + VmpPMatAlloc @@ -175,7 +135,7 @@ where rank: rank.into(), }; - let mut ct_compressed: GGSWCiphertextCompressed> = GGSWCiphertextCompressed::alloc(&ggsw_infos); + let mut ct_compressed: GGSWCompressed> = GGSWCompressed::alloc_from_infos(&ggsw_infos); let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); @@ -184,12 +144,12 @@ where pt_scalar.fill_ternary_hw(0, n, &mut source_xs); - let mut scratch: ScratchOwned = ScratchOwned::alloc(GGSWCiphertextCompressed::encrypt_sk_scratch_space( + let mut scratch: ScratchOwned = ScratchOwned::alloc(GGSWCompressed::encrypt_sk_tmp_bytes( module, &ggsw_infos, )); - let mut sk: GLWESecret> = GLWESecret::alloc(&ggsw_infos); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&ggsw_infos); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); @@ -206,7 +166,7 @@ where let noise_f = |_col_i: usize| -(k as f64) + SIGMA.log2() + 0.5; - let mut ct: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_infos); + let mut ct: GGSW> = GGSW::alloc_from_infos(&ggsw_infos); ct.decompress(module, &ct_compressed); ct.assert_noise(module, &sk_prepared, &pt_scalar, noise_f); diff --git a/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs b/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs index a1169f6..ed2c1e3 100644 --- a/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs @@ -1,10 +1,9 @@ use poulpy_hal::{ api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxBigAddInplace, VecZnxBigAddNormal, VecZnxBigAddSmallInplace, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxCopy, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, - VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, - VecZnxSubInplace, + VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxCopy, VecZnxDftAlloc, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, + VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, }, layouts::{Backend, Module, ScratchOwned}, oep::{ @@ -17,8 +16,8 @@ use poulpy_hal::{ use crate::{ encryption::SIGMA, layouts::{ - GLWECiphertext, GLWECiphertextLayout, GLWEPlaintext, GLWEPlaintextLayout, GLWEPublicKey, GLWESecret, LWEInfos, - compressed::{Decompress, GLWECiphertextCompressed}, + GLWE, GLWELayout, GLWEPlaintext, GLWEPlaintextLayout, GLWEPublicKey, GLWESecret, LWEInfos, + compressed::{Decompress, GLWECompressed}, prepared::{GLWEPublicKeyPrepared, GLWESecretPrepared, PrepareAlloc}, }, operations::GLWEOperations, @@ -26,8 +25,8 @@ use crate::{ pub fn test_glwe_encrypt_sk(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigAllocBytes + Module: VecZnxDftBytesOf + + VecZnxBigBytesOf + VecZnxDftApply + SvpApplyDftToDftInplace + VecZnxIdftApplyConsume @@ -36,7 +35,7 @@ where + VecZnxBigNormalize + VecZnxNormalizeTmpBytes + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc + SvpApplyDftToDft + VecZnxBigAddNormal @@ -64,7 +63,7 @@ where for rank in 1_usize..3 { let n: usize = module.n(); - let glwe_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_infos: GLWELayout = GLWELayout { n: n.into(), base2k: base2k.into(), k: k_ct.into(), @@ -77,20 +76,19 @@ where k: k_pt.into(), }; - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&glwe_infos); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&pt_infos); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&pt_infos); + let mut ct: GLWE> = GLWE::alloc_from_infos(&glwe_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&pt_infos); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&pt_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::encrypt_sk_scratch_space(module, &glwe_infos) - | GLWECiphertext::decrypt_scratch_space(module, &glwe_infos), + GLWE::encrypt_sk_tmp_bytes(module, &glwe_infos) | GLWE::decrypt_tmp_bytes(module, &glwe_infos), ); - let mut sk: GLWESecret> = GLWESecret::alloc(&glwe_infos); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_infos); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); @@ -118,8 +116,8 @@ where pub fn test_glwe_compressed_encrypt_sk(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigAllocBytes + Module: VecZnxDftBytesOf + + VecZnxBigBytesOf + VecZnxDftApply + SvpApplyDftToDftInplace + VecZnxIdftApplyConsume @@ -128,7 +126,7 @@ where + VecZnxBigNormalize + VecZnxNormalizeTmpBytes + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc + SvpApplyDftToDft + VecZnxBigAddNormal @@ -157,7 +155,7 @@ where for rank in 1_usize..3 { let n: usize = module.n(); - let glwe_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_infos: GLWELayout = GLWELayout { n: n.into(), base2k: base2k.into(), k: k_ct.into(), @@ -170,21 +168,20 @@ where k: k_pt.into(), }; - let mut ct_compressed: GLWECiphertextCompressed> = GLWECiphertextCompressed::alloc(&glwe_infos); + let mut ct_compressed: GLWECompressed> = GLWECompressed::alloc_from_infos(&glwe_infos); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&pt_infos); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&pt_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&pt_infos); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&pt_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertextCompressed::encrypt_sk_scratch_space(module, &glwe_infos) - | GLWECiphertext::decrypt_scratch_space(module, &glwe_infos), + GLWECompressed::encrypt_sk_tmp_bytes(module, &glwe_infos) | GLWE::decrypt_tmp_bytes(module, &glwe_infos), ); - let mut sk: GLWESecret> = GLWESecret::alloc(&glwe_infos); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_infos); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); @@ -201,7 +198,7 @@ where scratch.borrow(), ); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&glwe_infos); + let mut ct: GLWE> = GLWE::alloc_from_infos(&glwe_infos); ct.decompress(module, &ct_compressed); ct.decrypt(module, &mut pt_have, &sk_prepared, scratch.borrow()); @@ -221,8 +218,8 @@ where pub fn test_glwe_encrypt_zero_sk(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigAllocBytes + Module: VecZnxDftBytesOf + + VecZnxBigBytesOf + VecZnxDftApply + SvpApplyDftToDftInplace + VecZnxIdftApplyConsume @@ -231,7 +228,7 @@ where + VecZnxBigNormalize + VecZnxNormalizeTmpBytes + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc + SvpApplyDftToDft + VecZnxBigAddNormal @@ -258,29 +255,28 @@ where for rank in 1_usize..3 { let n: usize = module.n(); - let glwe_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_infos: GLWELayout = GLWELayout { n: n.into(), base2k: base2k.into(), k: k_ct.into(), rank: rank.into(), }; - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_infos); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([1u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::decrypt_scratch_space(module, &glwe_infos) - | GLWECiphertext::encrypt_sk_scratch_space(module, &glwe_infos), + GLWE::decrypt_tmp_bytes(module, &glwe_infos) | GLWE::encrypt_sk_tmp_bytes(module, &glwe_infos), ); - let mut sk: GLWESecret> = GLWESecret::alloc(&glwe_infos); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_infos); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&glwe_infos); + let mut ct: GLWE> = GLWE::alloc_from_infos(&glwe_infos); ct.encrypt_zero_sk( module, @@ -297,7 +293,7 @@ where pub fn test_glwe_encrypt_pk(module: &Module) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -311,10 +307,10 @@ where + VecZnxNormalize + VecZnxSub + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc + VecZnxBigAddSmallInplace - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxBigAddInplace + VecZnxCopy + VecZnxDftAlloc @@ -336,16 +332,16 @@ where for rank in 1_usize..3 { let n: usize = module.n(); - let glwe_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_infos: GLWELayout = GLWELayout { n: n.into(), base2k: base2k.into(), k: k_ct.into(), rank: rank.into(), }; - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&glwe_infos); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_infos); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_infos); + let mut ct: GLWE> = GLWE::alloc_from_infos(&glwe_infos); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -353,17 +349,17 @@ where let mut source_xu: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::encrypt_sk_scratch_space(module, &glwe_infos) - | GLWECiphertext::decrypt_scratch_space(module, &glwe_infos) - | GLWECiphertext::encrypt_pk_scratch_space(module, &glwe_infos), + GLWE::encrypt_sk_tmp_bytes(module, &glwe_infos) + | GLWE::decrypt_tmp_bytes(module, &glwe_infos) + | GLWE::encrypt_pk_tmp_bytes(module, &glwe_infos), ); - let mut sk: GLWESecret> = GLWESecret::alloc(&glwe_infos); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_infos); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); - let mut pk: GLWEPublicKey> = GLWEPublicKey::alloc(&glwe_infos); - pk.generate_from_sk(module, &sk_prepared, &mut source_xa, &mut source_xe); + let mut pk: GLWEPublicKey> = GLWEPublicKey::alloc_from_infos(&glwe_infos); + pk.generate(module, &sk_prepared, &mut source_xa, &mut source_xe); module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); diff --git a/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs b/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs index 8bb9730..da05ba1 100644 --- a/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs +++ b/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs @@ -1,8 +1,8 @@ use poulpy_hal::{ api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, - VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxCopy, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, + VecZnxBigAlloc, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxCopy, VecZnxDftAlloc, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSubScalarInplace, VecZnxSwitchRing, }, @@ -17,15 +17,15 @@ use poulpy_hal::{ use crate::{ encryption::SIGMA, layouts::{ - Dsize, GGLWETensorKey, GGLWETensorKeyLayout, GLWEPlaintext, GLWESecret, - compressed::{Decompress, GGLWETensorKeyCompressed}, + Dsize, GLWEPlaintext, GLWESecret, TensorKey, TensorKeyLayout, + compressed::{Decompress, TensorKeyCompressed}, prepared::{GLWESecretPrepared, PrepareAlloc}, }, }; pub fn test_gglwe_tensor_key_encrypt_sk(module: &Module) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -39,10 +39,10 @@ where + VecZnxNormalize + VecZnxSub + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc + VecZnxBigAddSmallInplace - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxBigAddInplace + VecZnxCopy + VecZnxDftAlloc @@ -71,7 +71,7 @@ where let n: usize = module.n(); let dnum: usize = k / base2k; - let tensor_key_infos = GGLWETensorKeyLayout { + let tensor_key_infos = TensorKeyLayout { n: n.into(), base2k: base2k.into(), k: k.into(), @@ -80,18 +80,18 @@ where rank: rank.into(), }; - let mut tensor_key: GGLWETensorKey> = GGLWETensorKey::alloc(&tensor_key_infos); + let mut tensor_key: TensorKey> = TensorKey::alloc_from_infos(&tensor_key_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWETensorKey::encrypt_sk_scratch_space( + let mut scratch: ScratchOwned = ScratchOwned::alloc(TensorKey::encrypt_sk_tmp_bytes( module, &tensor_key_infos, )); - let mut sk: GLWESecret> = GLWESecret::alloc(&tensor_key_infos); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&tensor_key_infos); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); @@ -103,11 +103,11 @@ where scratch.borrow(), ); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&tensor_key_infos); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&tensor_key_infos); let mut sk_ij_dft = module.vec_znx_dft_alloc(1, 1); let mut sk_ij_big = module.vec_znx_big_alloc(1, 1); - let mut sk_ij: GLWESecret> = GLWESecret::alloc_with(n.into(), 1_u32.into()); + let mut sk_ij: GLWESecret> = GLWESecret::alloc(n.into(), 1_u32.into()); let mut sk_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(rank, 1); for i in 0..rank { @@ -145,7 +145,7 @@ where pub fn test_gglwe_tensor_key_compressed_encrypt_sk(module: &Module) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -159,10 +159,10 @@ where + VecZnxNormalize + VecZnxSub + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc + VecZnxBigAddSmallInplace - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxBigAddInplace + VecZnxCopy + VecZnxDftAlloc @@ -190,7 +190,7 @@ where let n: usize = module.n(); let dnum: usize = k / base2k; - let tensor_key_infos: GGLWETensorKeyLayout = GGLWETensorKeyLayout { + let tensor_key_infos: TensorKeyLayout = TensorKeyLayout { n: n.into(), base2k: base2k.into(), k: k.into(), @@ -199,17 +199,17 @@ where rank: rank.into(), }; - let mut tensor_key_compressed: GGLWETensorKeyCompressed> = GGLWETensorKeyCompressed::alloc(&tensor_key_infos); + let mut tensor_key_compressed: TensorKeyCompressed> = TensorKeyCompressed::alloc_from_infos(&tensor_key_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWETensorKeyCompressed::encrypt_sk_scratch_space( + let mut scratch: ScratchOwned = ScratchOwned::alloc(TensorKeyCompressed::encrypt_sk_tmp_bytes( module, &tensor_key_infos, )); - let mut sk: GLWESecret> = GLWESecret::alloc(&tensor_key_infos); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&tensor_key_infos); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); @@ -217,14 +217,14 @@ where tensor_key_compressed.encrypt_sk(module, &sk, seed_xa, &mut source_xe, scratch.borrow()); - let mut tensor_key: GGLWETensorKey> = GGLWETensorKey::alloc(&tensor_key_infos); + let mut tensor_key: TensorKey> = TensorKey::alloc_from_infos(&tensor_key_infos); tensor_key.decompress(module, &tensor_key_compressed); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&tensor_key_infos); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&tensor_key_infos); let mut sk_ij_dft = module.vec_znx_dft_alloc(1, 1); let mut sk_ij_big = module.vec_znx_big_alloc(1, 1); - let mut sk_ij: GLWESecret> = GLWESecret::alloc_with(n.into(), 1_u32.into()); + let mut sk_ij: GLWESecret> = GLWESecret::alloc(n.into(), 1_u32.into()); let mut sk_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(rank, 1); for i in 0..rank { diff --git a/poulpy-core/src/tests/test_suite/external_product/gglwe_ksk.rs b/poulpy-core/src/tests/test_suite/external_product/gglwe_ksk.rs index 2cfa618..a369978 100644 --- a/poulpy-core/src/tests/test_suite/external_product/gglwe_ksk.rs +++ b/poulpy-core/src/tests/test_suite/external_product/gglwe_ksk.rs @@ -1,8 +1,8 @@ use poulpy_hal::{ api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, + VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxCopy, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotateInplace, VecZnxSub, VecZnxSubInplace, VecZnxSubScalarInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, @@ -18,8 +18,8 @@ use poulpy_hal::{ use crate::{ encryption::SIGMA, layouts::{ - GGLWESwitchingKey, GGLWESwitchingKeyLayout, GGSWCiphertext, GGSWCiphertextLayout, GLWESecret, - prepared::{GGSWCiphertextPrepared, GLWESecretPrepared, PrepareAlloc}, + GGSW, GGSWCiphertextLayout, GLWESecret, GLWESwitchingKey, GLWESwitchingKeyLayout, + prepared::{GGSWPrepared, GLWESecretPrepared, PrepareAlloc}, }, noise::noise_ggsw_product, }; @@ -27,7 +27,7 @@ use crate::{ #[allow(clippy::too_many_arguments)] pub fn test_gglwe_switching_key_external_product(module: &Module) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -41,9 +41,9 @@ where + VecZnxNormalize + VecZnxSub + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxSwitchRing @@ -81,7 +81,7 @@ where let dnum: usize = k_in.div_ceil(base2k * di); let dsize_in: usize = 1; - let gglwe_in_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + let gglwe_in_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), base2k: base2k.into(), k: k_in.into(), @@ -91,7 +91,7 @@ where rank_out: rank_out.into(), }; - let gglwe_out_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + let gglwe_out_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), base2k: base2k.into(), k: k_out.into(), @@ -110,9 +110,9 @@ where rank: rank_out.into(), }; - let mut ct_gglwe_in: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_in_infos); - let mut ct_gglwe_out: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_out_infos); - let mut ct_rgsw: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_infos); + let mut ct_gglwe_in: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&gglwe_in_infos); + let mut ct_gglwe_out: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&gglwe_out_infos); + let mut ct_rgsw: GGSW> = GGSW::alloc_from_infos(&ggsw_infos); let mut pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n, 1); @@ -121,14 +121,14 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWESwitchingKey::encrypt_sk_scratch_space(module, &gglwe_in_infos) - | GGLWESwitchingKey::external_product_scratch_space( + GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &gglwe_in_infos) + | GLWESwitchingKey::external_product_tmp_bytes( module, &gglwe_out_infos, &gglwe_in_infos, &ggsw_infos, ) - | GGSWCiphertext::encrypt_sk_scratch_space(module, &ggsw_infos), + | GGSW::encrypt_sk_tmp_bytes(module, &ggsw_infos), ); let r: usize = 1; @@ -137,10 +137,10 @@ where let var_xs: f64 = 0.5; - let mut sk_in: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_in.into()); + let mut sk_in: GLWESecret> = GLWESecret::alloc(n.into(), rank_in.into()); sk_in.fill_ternary_prob(var_xs, &mut source_xs); - let mut sk_out: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_out.into()); + let mut sk_out: GLWESecret> = GLWESecret::alloc(n.into(), rank_out.into()); sk_out.fill_ternary_prob(var_xs, &mut source_xs); let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); @@ -163,7 +163,7 @@ where scratch.borrow(), ); - let ct_rgsw_prepared: GGSWCiphertextPrepared, B> = ct_rgsw.prepare_alloc(module, scratch.borrow()); + let ct_rgsw_prepared: GGSWPrepared, B> = ct_rgsw.prepare_alloc(module, scratch.borrow()); // gglwe_(m) (x) RGSW_(X^k) = gglwe_(m * X^k) ct_gglwe_out.external_product(module, &ct_gglwe_in, &ct_rgsw_prepared, scratch.borrow()); @@ -209,7 +209,7 @@ where #[allow(clippy::too_many_arguments)] pub fn test_gglwe_switching_key_external_product_inplace(module: &Module) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -223,9 +223,9 @@ where + VecZnxNormalize + VecZnxSub + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxSwitchRing @@ -263,7 +263,7 @@ where let dsize_in: usize = 1; - let gglwe_out_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + let gglwe_out_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), base2k: base2k.into(), k: k_out.into(), @@ -282,8 +282,8 @@ where rank: rank_out.into(), }; - let mut ct_gglwe: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_out_infos); - let mut ct_rgsw: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_infos); + let mut ct_gglwe: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&gglwe_out_infos); + let mut ct_rgsw: GGSW> = GGSW::alloc_from_infos(&ggsw_infos); let mut pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n, 1); @@ -292,9 +292,9 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWESwitchingKey::encrypt_sk_scratch_space(module, &gglwe_out_infos) - | GGLWESwitchingKey::external_product_inplace_scratch_space(module, &gglwe_out_infos, &ggsw_infos) - | GGSWCiphertext::encrypt_sk_scratch_space(module, &ggsw_infos), + GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &gglwe_out_infos) + | GLWESwitchingKey::external_product_inplace_tmp_bytes(module, &gglwe_out_infos, &ggsw_infos) + | GGSW::encrypt_sk_tmp_bytes(module, &ggsw_infos), ); let r: usize = 1; @@ -303,10 +303,10 @@ where let var_xs: f64 = 0.5; - let mut sk_in: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_in.into()); + let mut sk_in: GLWESecret> = GLWESecret::alloc(n.into(), rank_in.into()); sk_in.fill_ternary_prob(var_xs, &mut source_xs); - let mut sk_out: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_out.into()); + let mut sk_out: GLWESecret> = GLWESecret::alloc(n.into(), rank_out.into()); sk_out.fill_ternary_prob(var_xs, &mut source_xs); let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); @@ -329,7 +329,7 @@ where scratch.borrow(), ); - let ct_rgsw_prepared: GGSWCiphertextPrepared, B> = ct_rgsw.prepare_alloc(module, scratch.borrow()); + let ct_rgsw_prepared: GGSWPrepared, B> = ct_rgsw.prepare_alloc(module, scratch.borrow()); // gglwe_(m) (x) RGSW_(X^k) = gglwe_(m * X^k) ct_gglwe.external_product_inplace(module, &ct_rgsw_prepared, scratch.borrow()); diff --git a/poulpy-core/src/tests/test_suite/external_product/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/external_product/ggsw_ct.rs index 464eac2..84a2f68 100644 --- a/poulpy-core/src/tests/test_suite/external_product/ggsw_ct.rs +++ b/poulpy-core/src/tests/test_suite/external_product/ggsw_ct.rs @@ -1,11 +1,11 @@ use poulpy_hal::{ api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAlloc, VecZnxDftAllocBytes, - VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxRotateInplace, VecZnxSub, VecZnxSubInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, + VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAlloc, VecZnxDftApply, + VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, + VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotateInplace, VecZnxSub, VecZnxSubInplace, VmpApplyDftToDft, + VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, }, layouts::{Backend, Module, ScalarZnx, ScalarZnxToMut, ScratchOwned, ZnxViewMut}, oep::{ @@ -18,8 +18,8 @@ use poulpy_hal::{ use crate::{ encryption::SIGMA, layouts::{ - GGSWCiphertext, GGSWCiphertextLayout, GLWESecret, - prepared::{GGSWCiphertextPrepared, GLWESecretPrepared, PrepareAlloc}, + GGSW, GGSWCiphertextLayout, GLWESecret, + prepared::{GGSWPrepared, GLWESecretPrepared, PrepareAlloc}, }, noise::noise_ggsw_product, }; @@ -27,7 +27,7 @@ use crate::{ #[allow(clippy::too_many_arguments)] pub fn test_ggsw_external_product(module: &Module) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -41,9 +41,9 @@ where + VecZnxNormalize + VecZnxSub + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxAddScalarInplace @@ -111,9 +111,9 @@ where rank: rank.into(), }; - let mut ggsw_in: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_in_infos); - let mut ggsw_out: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_out_infos); - let mut ggsw_apply: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_apply_infos); + let mut ggsw_in: GGSW> = GGSW::alloc_from_infos(&ggsw_in_infos); + let mut ggsw_out: GGSW> = GGSW::alloc_from_infos(&ggsw_out_infos); + let mut ggsw_apply: GGSW> = GGSW::alloc_from_infos(&ggsw_apply_infos); let mut pt_in: ScalarZnx> = ScalarZnx::alloc(n, 1); let mut pt_apply: ScalarZnx> = ScalarZnx::alloc(n, 1); @@ -128,12 +128,12 @@ where pt_apply.to_mut().raw_mut()[k] = 1; //X^{k} let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, &ggsw_apply_infos) - | GGSWCiphertext::encrypt_sk_scratch_space(module, &ggsw_in_infos) - | GGSWCiphertext::external_product_scratch_space(module, &ggsw_out_infos, &ggsw_in_infos, &ggsw_apply_infos), + GGSW::encrypt_sk_tmp_bytes(module, &ggsw_apply_infos) + | GGSW::encrypt_sk_tmp_bytes(module, &ggsw_in_infos) + | GGSW::external_product_tmp_bytes(module, &ggsw_out_infos, &ggsw_in_infos, &ggsw_apply_infos), ); - let mut sk: GLWESecret> = GLWESecret::alloc_with(n.into(), rank.into()); + let mut sk: GLWESecret> = GLWESecret::alloc(n.into(), rank.into()); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); @@ -155,7 +155,7 @@ where scratch.borrow(), ); - let ct_rhs_prepared: GGSWCiphertextPrepared, B> = ggsw_apply.prepare_alloc(module, scratch.borrow()); + let ct_rhs_prepared: GGSWPrepared, B> = ggsw_apply.prepare_alloc(module, scratch.borrow()); ggsw_out.external_product(module, &ggsw_in, &ct_rhs_prepared, scratch.borrow()); @@ -192,7 +192,7 @@ where #[allow(clippy::too_many_arguments)] pub fn test_ggsw_external_product_inplace(module: &Module) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -206,9 +206,9 @@ where + VecZnxNormalize + VecZnxSub + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxAddScalarInplace @@ -265,8 +265,8 @@ where rank: rank.into(), }; - let mut ggsw_out: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_out_infos); - let mut ggsw_apply: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_apply_infos); + let mut ggsw_out: GGSW> = GGSW::alloc_from_infos(&ggsw_out_infos); + let mut ggsw_apply: GGSW> = GGSW::alloc_from_infos(&ggsw_apply_infos); let mut pt_in: ScalarZnx> = ScalarZnx::alloc(n, 1); let mut pt_apply: ScalarZnx> = ScalarZnx::alloc(n, 1); @@ -282,12 +282,12 @@ where pt_apply.to_mut().raw_mut()[k] = 1; //X^{k} let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, &ggsw_apply_infos) - | GGSWCiphertext::encrypt_sk_scratch_space(module, &ggsw_out_infos) - | GGSWCiphertext::external_product_inplace_scratch_space(module, &ggsw_out_infos, &ggsw_apply_infos), + GGSW::encrypt_sk_tmp_bytes(module, &ggsw_apply_infos) + | GGSW::encrypt_sk_tmp_bytes(module, &ggsw_out_infos) + | GGSW::external_product_inplace_tmp_bytes(module, &ggsw_out_infos, &ggsw_apply_infos), ); - let mut sk: GLWESecret> = GLWESecret::alloc_with(n.into(), rank.into()); + let mut sk: GLWESecret> = GLWESecret::alloc(n.into(), rank.into()); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); @@ -309,7 +309,7 @@ where scratch.borrow(), ); - let ct_rhs_prepared: GGSWCiphertextPrepared, B> = ggsw_apply.prepare_alloc(module, scratch.borrow()); + let ct_rhs_prepared: GGSWPrepared, B> = ggsw_apply.prepare_alloc(module, scratch.borrow()); ggsw_out.external_product_inplace(module, &ct_rhs_prepared, scratch.borrow()); diff --git a/poulpy-core/src/tests/test_suite/external_product/glwe_ct.rs b/poulpy-core/src/tests/test_suite/external_product/glwe_ct.rs index 1a21a0c..60026b8 100644 --- a/poulpy-core/src/tests/test_suite/external_product/glwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/external_product/glwe_ct.rs @@ -1,8 +1,8 @@ use poulpy_hal::{ api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, + VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotateInplace, VecZnxSub, VecZnxSubInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, }, @@ -17,8 +17,8 @@ use poulpy_hal::{ use crate::{ encryption::SIGMA, layouts::{ - GGSWCiphertext, GGSWCiphertextLayout, GLWECiphertext, GLWECiphertextLayout, GLWEPlaintext, GLWESecret, - prepared::{GGSWCiphertextPrepared, GLWESecretPrepared, PrepareAlloc}, + GGSW, GGSWCiphertextLayout, GLWE, GLWELayout, GLWEPlaintext, GLWESecret, + prepared::{GGSWPrepared, GLWESecretPrepared, PrepareAlloc}, }, noise::noise_ggsw_product, }; @@ -26,7 +26,7 @@ use crate::{ #[allow(clippy::too_many_arguments)] pub fn test_glwe_external_product(module: &Module) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -39,9 +39,9 @@ where + VecZnxNormalize + VecZnxSub + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxNormalizeTmpBytes @@ -73,14 +73,14 @@ where let n: usize = module.n(); let dnum: usize = k_in.div_ceil(base2k * dsize); - let glwe_in_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_in_infos: GLWELayout = GLWELayout { n: n.into(), base2k: base2k.into(), k: k_in.into(), rank: rank.into(), }; - let glwe_out_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_out_infos: GLWELayout = GLWELayout { n: n.into(), base2k: base2k.into(), k: k_out.into(), @@ -96,11 +96,11 @@ where rank: rank.into(), }; - let mut ggsw_apply: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_apply_infos); - let mut glwe_in: GLWECiphertext> = GLWECiphertext::alloc(&glwe_in_infos); - let mut glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&glwe_out_infos); + let mut ggsw_apply: GGSW> = GGSW::alloc_from_infos(&ggsw_apply_infos); + let mut glwe_in: GLWE> = GLWE::alloc_from_infos(&glwe_in_infos); + let mut glwe_out: GLWE> = GLWE::alloc_from_infos(&glwe_out_infos); let mut pt_ggsw: ScalarZnx> = ScalarZnx::alloc(n, 1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_in_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_in_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -116,12 +116,12 @@ where pt_ggsw.raw_mut()[k] = 1; // X^{k} let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, &ggsw_apply_infos) - | GLWECiphertext::encrypt_sk_scratch_space(module, &glwe_in_infos) - | GLWECiphertext::external_product_scratch_space(module, &glwe_out_infos, &glwe_in_infos, &ggsw_apply_infos), + GGSW::encrypt_sk_tmp_bytes(module, &ggsw_apply_infos) + | GLWE::encrypt_sk_tmp_bytes(module, &glwe_in_infos) + | GLWE::external_product_tmp_bytes(module, &glwe_out_infos, &glwe_in_infos, &ggsw_apply_infos), ); - let mut sk: GLWESecret> = GLWESecret::alloc_with(n.into(), rank.into()); + let mut sk: GLWESecret> = GLWESecret::alloc(n.into(), rank.into()); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); @@ -143,7 +143,7 @@ where scratch.borrow(), ); - let ct_ggsw_prepared: GGSWCiphertextPrepared, B> = ggsw_apply.prepare_alloc(module, scratch.borrow()); + let ct_ggsw_prepared: GGSWPrepared, B> = ggsw_apply.prepare_alloc(module, scratch.borrow()); glwe_out.external_product(module, &glwe_in, &ct_ggsw_prepared, scratch.borrow()); @@ -178,7 +178,7 @@ where #[allow(clippy::too_many_arguments)] pub fn test_glwe_external_product_inplace(module: &Module) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -191,9 +191,9 @@ where + VecZnxNormalize + VecZnxSub + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxNormalizeTmpBytes @@ -224,7 +224,7 @@ where let n: usize = module.n(); let dnum: usize = k_out.div_ceil(base2k * dsize); - let glwe_out_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_out_infos: GLWELayout = GLWELayout { n: n.into(), base2k: base2k.into(), k: k_out.into(), @@ -240,10 +240,10 @@ where rank: rank.into(), }; - let mut ggsw_apply: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_apply_infos); - let mut glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&glwe_out_infos); + let mut ggsw_apply: GGSW> = GGSW::alloc_from_infos(&ggsw_apply_infos); + let mut glwe_out: GLWE> = GLWE::alloc_from_infos(&glwe_out_infos); let mut pt_ggsw: ScalarZnx> = ScalarZnx::alloc(n, 1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_out_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -259,12 +259,12 @@ where pt_ggsw.raw_mut()[k] = 1; // X^{k} let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, &ggsw_apply_infos) - | GLWECiphertext::encrypt_sk_scratch_space(module, &glwe_out_infos) - | GLWECiphertext::external_product_inplace_scratch_space(module, &glwe_out_infos, &ggsw_apply_infos), + GGSW::encrypt_sk_tmp_bytes(module, &ggsw_apply_infos) + | GLWE::encrypt_sk_tmp_bytes(module, &glwe_out_infos) + | GLWE::external_product_inplace_tmp_bytes(module, &glwe_out_infos, &ggsw_apply_infos), ); - let mut sk: GLWESecret> = GLWESecret::alloc_with(n.into(), rank.into()); + let mut sk: GLWESecret> = GLWESecret::alloc(n.into(), rank.into()); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); @@ -286,7 +286,7 @@ where scratch.borrow(), ); - let ct_ggsw_prepared: GGSWCiphertextPrepared, B> = ggsw_apply.prepare_alloc(module, scratch.borrow()); + let ct_ggsw_prepared: GGSWPrepared, B> = ggsw_apply.prepare_alloc(module, scratch.borrow()); glwe_out.external_product_inplace(module, &ct_ggsw_prepared, scratch.borrow()); diff --git a/poulpy-core/src/tests/test_suite/keyswitch/gglwe_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/gglwe_ct.rs index f131556..c667848 100644 --- a/poulpy-core/src/tests/test_suite/keyswitch/gglwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/keyswitch/gglwe_ct.rs @@ -1,11 +1,11 @@ use poulpy_hal::{ api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VecZnxDftApply, - VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, - VecZnxSubInplace, VecZnxSubScalarInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, + VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, + VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, + VecZnxSubScalarInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, + VmpPrepare, }, layouts::{Backend, Module, ScratchOwned}, oep::{ @@ -18,15 +18,15 @@ use poulpy_hal::{ use crate::{ encryption::SIGMA, layouts::{ - GGLWESwitchingKey, GGLWESwitchingKeyLayout, GLWESecret, - prepared::{GGLWESwitchingKeyPrepared, GLWESecretPrepared, PrepareAlloc}, + GLWESecret, GLWESwitchingKey, GLWESwitchingKeyLayout, + prepared::{GLWESecretPrepared, GLWESwitchingKeyPrepared, PrepareAlloc}, }, noise::log2_std_noise_gglwe_product, }; pub fn test_gglwe_switching_key_keyswitch(module: &Module) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -39,9 +39,9 @@ where + VecZnxNormalize + VecZnxSub + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxNormalizeTmpBytes @@ -80,7 +80,7 @@ where let dnum_apply: usize = k_in.div_ceil(base2k * di); let dsize_in: usize = 1; - let gglwe_s0s1_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + let gglwe_s0s1_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), base2k: base2k.into(), k: k_in.into(), @@ -90,7 +90,7 @@ where rank_out: rank_out_s0s1.into(), }; - let gglwe_s1s2_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + let gglwe_s1s2_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), base2k: base2k.into(), k: k_ksk.into(), @@ -100,7 +100,7 @@ where rank_out: rank_out_s1s2.into(), }; - let gglwe_s0s2_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + let gglwe_s0s2_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), base2k: base2k.into(), k: k_out.into(), @@ -110,33 +110,33 @@ where rank_out: rank_out_s1s2.into(), }; - let mut gglwe_s0s1: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_s0s1_infos); - let mut gglwe_s1s2: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_s1s2_infos); - let mut gglwe_s0s2: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_s0s2_infos); + let mut gglwe_s0s1: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&gglwe_s0s1_infos); + let mut gglwe_s1s2: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&gglwe_s1s2_infos); + let mut gglwe_s0s2: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&gglwe_s0s2_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch_enc: ScratchOwned = ScratchOwned::alloc( - GGLWESwitchingKey::encrypt_sk_scratch_space(module, &gglwe_s0s1_infos) - | GGLWESwitchingKey::encrypt_sk_scratch_space(module, &gglwe_s1s2_infos) - | GGLWESwitchingKey::encrypt_sk_scratch_space(module, &gglwe_s0s2_infos), + GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &gglwe_s0s1_infos) + | GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &gglwe_s1s2_infos) + | GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &gglwe_s0s2_infos), ); - let mut scratch_apply: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKey::keyswitch_scratch_space( + let mut scratch_apply: ScratchOwned = ScratchOwned::alloc(GLWESwitchingKey::keyswitch_tmp_bytes( module, &gglwe_s0s1_infos, &gglwe_s0s2_infos, &gglwe_s1s2_infos, )); - let mut sk0: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_in_s0s1.into()); + let mut sk0: GLWESecret> = GLWESecret::alloc(n.into(), rank_in_s0s1.into()); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_out_s0s1.into()); + let mut sk1: GLWESecret> = GLWESecret::alloc(n.into(), rank_out_s0s1.into()); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk2: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_out_s1s2.into()); + let mut sk2: GLWESecret> = GLWESecret::alloc(n.into(), rank_out_s1s2.into()); sk2.fill_ternary_prob(0.5, &mut source_xs); let sk2_prepared: GLWESecretPrepared, B> = sk2.prepare_alloc(module, scratch_apply.borrow()); @@ -160,7 +160,7 @@ where scratch_enc.borrow(), ); - let gglwe_s1s2_prepared: GGLWESwitchingKeyPrepared, B> = + let gglwe_s1s2_prepared: GLWESwitchingKeyPrepared, B> = gglwe_s1s2.prepare_alloc(module, scratch_apply.borrow()); // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) @@ -196,7 +196,7 @@ where #[allow(clippy::too_many_arguments)] pub fn test_gglwe_switching_key_keyswitch_inplace(module: &Module) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -209,9 +209,9 @@ where + VecZnxNormalize + VecZnxSub + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxNormalizeTmpBytes @@ -246,7 +246,7 @@ where let dnum: usize = k_out.div_ceil(base2k * di); let dsize_in: usize = 1; - let gglwe_s0s1_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + let gglwe_s0s1_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), base2k: base2k.into(), k: k_out.into(), @@ -256,7 +256,7 @@ where rank_out: rank_out.into(), }; - let gglwe_s1s2_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + let gglwe_s1s2_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), base2k: base2k.into(), k: k_ksk.into(), @@ -266,18 +266,18 @@ where rank_out: rank_out.into(), }; - let mut gglwe_s0s1: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_s0s1_infos); - let mut gglwe_s1s2: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_s1s2_infos); + let mut gglwe_s0s1: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&gglwe_s0s1_infos); + let mut gglwe_s1s2: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&gglwe_s1s2_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch_enc: ScratchOwned = ScratchOwned::alloc( - GGLWESwitchingKey::encrypt_sk_scratch_space(module, &gglwe_s0s1_infos) - | GGLWESwitchingKey::encrypt_sk_scratch_space(module, &gglwe_s1s2_infos), + GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &gglwe_s0s1_infos) + | GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &gglwe_s1s2_infos), ); - let mut scratch_apply: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKey::keyswitch_inplace_scratch_space( + let mut scratch_apply: ScratchOwned = ScratchOwned::alloc(GLWESwitchingKey::keyswitch_inplace_tmp_bytes( module, &gglwe_s0s1_infos, &gglwe_s1s2_infos, @@ -285,13 +285,13 @@ where let var_xs: f64 = 0.5; - let mut sk0: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_in.into()); + let mut sk0: GLWESecret> = GLWESecret::alloc(n.into(), rank_in.into()); sk0.fill_ternary_prob(var_xs, &mut source_xs); - let mut sk1: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_out.into()); + let mut sk1: GLWESecret> = GLWESecret::alloc(n.into(), rank_out.into()); sk1.fill_ternary_prob(var_xs, &mut source_xs); - let mut sk2: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_out.into()); + let mut sk2: GLWESecret> = GLWESecret::alloc(n.into(), rank_out.into()); sk2.fill_ternary_prob(var_xs, &mut source_xs); let sk2_prepared: GLWESecretPrepared, B> = sk2.prepare_alloc(module, scratch_apply.borrow()); @@ -315,13 +315,13 @@ where scratch_enc.borrow(), ); - let gglwe_s1s2_prepared: GGLWESwitchingKeyPrepared, B> = + let gglwe_s1s2_prepared: GLWESwitchingKeyPrepared, B> = gglwe_s1s2.prepare_alloc(module, scratch_apply.borrow()); // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) gglwe_s0s1.keyswitch_inplace(module, &gglwe_s1s2_prepared, scratch_apply.borrow()); - let gglwe_s0s2: GGLWESwitchingKey> = gglwe_s0s1; + let gglwe_s0s2: GLWESwitchingKey> = gglwe_s0s1; let max_noise: f64 = log2_std_noise_gglwe_product( n as f64, diff --git a/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs index c7e7c82..4d3d556 100644 --- a/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs +++ b/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs @@ -1,9 +1,9 @@ use poulpy_hal::{ api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, - VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAddInplace, VecZnxDftAlloc, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, + VecZnxBigAlloc, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAddInplace, VecZnxDftAlloc, + VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftCopy, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, }, @@ -18,9 +18,8 @@ use poulpy_hal::{ use crate::{ encryption::SIGMA, layouts::{ - GGLWESwitchingKey, GGLWESwitchingKeyLayout, GGLWETensorKey, GGLWETensorKeyLayout, GGSWCiphertext, GGSWCiphertextLayout, - GLWESecret, - prepared::{GGLWESwitchingKeyPrepared, GGLWETensorKeyPrepared, GLWESecretPrepared, PrepareAlloc}, + GGSW, GGSWCiphertextLayout, GLWESecret, GLWESwitchingKey, GLWESwitchingKeyLayout, TensorKey, TensorKeyLayout, + prepared::{GLWESecretPrepared, GLWESwitchingKeyPrepared, PrepareAlloc, TensorKeyPrepared}, }, noise::noise_ggsw_keyswitch, }; @@ -28,7 +27,7 @@ use crate::{ #[allow(clippy::too_many_arguments)] pub fn test_ggsw_keyswitch(module: &Module) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -41,9 +40,9 @@ where + VecZnxNormalize + VecZnxSub + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxNormalizeTmpBytes @@ -103,7 +102,7 @@ where rank: rank.into(), }; - let tsk_infos: GGLWETensorKeyLayout = GGLWETensorKeyLayout { + let tsk_infos: TensorKeyLayout = TensorKeyLayout { n: n.into(), base2k: base2k.into(), k: k_tsk.into(), @@ -112,7 +111,7 @@ where rank: rank.into(), }; - let ksk_apply_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + let ksk_apply_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), base2k: base2k.into(), k: k_ksk.into(), @@ -122,10 +121,10 @@ where rank_out: rank.into(), }; - let mut ggsw_in: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_in_infos); - let mut ggsw_out: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_out_infos); - let mut tsk: GGLWETensorKey> = GGLWETensorKey::alloc(&tsk_infos); - let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&ksk_apply_infos); + let mut ggsw_in: GGSW> = GGSW::alloc_from_infos(&ggsw_in_infos); + let mut ggsw_out: GGSW> = GGSW::alloc_from_infos(&ggsw_out_infos); + let mut tsk: TensorKey> = TensorKey::alloc_from_infos(&tsk_infos); + let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&ksk_apply_infos); let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); let mut source_xs: Source = Source::new([0u8; 32]); @@ -133,10 +132,10 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, &ggsw_in_infos) - | GGLWESwitchingKey::encrypt_sk_scratch_space(module, &ksk_apply_infos) - | GGLWETensorKey::encrypt_sk_scratch_space(module, &tsk_infos) - | GGSWCiphertext::keyswitch_scratch_space( + GGSW::encrypt_sk_tmp_bytes(module, &ggsw_in_infos) + | GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &ksk_apply_infos) + | TensorKey::encrypt_sk_tmp_bytes(module, &tsk_infos) + | GGSW::keyswitch_tmp_bytes( module, &ggsw_out_infos, &ggsw_in_infos, @@ -147,11 +146,11 @@ where let var_xs: f64 = 0.5; - let mut sk_in: GLWESecret> = GLWESecret::alloc_with(n.into(), rank.into()); + let mut sk_in: GLWESecret> = GLWESecret::alloc(n.into(), rank.into()); sk_in.fill_ternary_prob(var_xs, &mut source_xs); let sk_in_dft: GLWESecretPrepared, B> = sk_in.prepare_alloc(module, scratch.borrow()); - let mut sk_out: GLWESecret> = GLWESecret::alloc_with(n.into(), rank.into()); + let mut sk_out: GLWESecret> = GLWESecret::alloc(n.into(), rank.into()); sk_out.fill_ternary_prob(var_xs, &mut source_xs); let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); @@ -182,8 +181,8 @@ where scratch.borrow(), ); - let ksk_prepared: GGLWESwitchingKeyPrepared, B> = ksk.prepare_alloc(module, scratch.borrow()); - let tsk_prepared: GGLWETensorKeyPrepared, B> = tsk.prepare_alloc(module, scratch.borrow()); + let ksk_prepared: GLWESwitchingKeyPrepared, B> = ksk.prepare_alloc(module, scratch.borrow()); + let tsk_prepared: TensorKeyPrepared, B> = tsk.prepare_alloc(module, scratch.borrow()); ggsw_out.keyswitch( module, @@ -217,7 +216,7 @@ where #[allow(clippy::too_many_arguments)] pub fn test_ggsw_keyswitch_inplace(module: &Module) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -230,9 +229,9 @@ where + VecZnxNormalize + VecZnxSub + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxNormalizeTmpBytes @@ -282,7 +281,7 @@ where rank: rank.into(), }; - let tsk_infos: GGLWETensorKeyLayout = GGLWETensorKeyLayout { + let tsk_infos: TensorKeyLayout = TensorKeyLayout { n: n.into(), base2k: base2k.into(), k: k_tsk.into(), @@ -291,7 +290,7 @@ where rank: rank.into(), }; - let ksk_apply_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + let ksk_apply_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), base2k: base2k.into(), k: k_ksk.into(), @@ -301,9 +300,9 @@ where rank_out: rank.into(), }; - let mut ggsw_out: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_out_infos); - let mut tsk: GGLWETensorKey> = GGLWETensorKey::alloc(&tsk_infos); - let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&ksk_apply_infos); + let mut ggsw_out: GGSW> = GGSW::alloc_from_infos(&ggsw_out_infos); + let mut tsk: TensorKey> = TensorKey::alloc_from_infos(&tsk_infos); + let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&ksk_apply_infos); let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); let mut source_xs: Source = Source::new([0u8; 32]); @@ -311,19 +310,19 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, &ggsw_out_infos) - | GGLWESwitchingKey::encrypt_sk_scratch_space(module, &ksk_apply_infos) - | GGLWETensorKey::encrypt_sk_scratch_space(module, &tsk_infos) - | GGSWCiphertext::keyswitch_inplace_scratch_space(module, &ggsw_out_infos, &ksk_apply_infos, &tsk_infos), + GGSW::encrypt_sk_tmp_bytes(module, &ggsw_out_infos) + | GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &ksk_apply_infos) + | TensorKey::encrypt_sk_tmp_bytes(module, &tsk_infos) + | GGSW::keyswitch_inplace_tmp_bytes(module, &ggsw_out_infos, &ksk_apply_infos, &tsk_infos), ); let var_xs: f64 = 0.5; - let mut sk_in: GLWESecret> = GLWESecret::alloc_with(n.into(), rank.into()); + let mut sk_in: GLWESecret> = GLWESecret::alloc(n.into(), rank.into()); sk_in.fill_ternary_prob(var_xs, &mut source_xs); let sk_in_dft: GLWESecretPrepared, B> = sk_in.prepare_alloc(module, scratch.borrow()); - let mut sk_out: GLWESecret> = GLWESecret::alloc_with(n.into(), rank.into()); + let mut sk_out: GLWESecret> = GLWESecret::alloc(n.into(), rank.into()); sk_out.fill_ternary_prob(var_xs, &mut source_xs); let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); @@ -354,8 +353,8 @@ where scratch.borrow(), ); - let ksk_prepared: GGLWESwitchingKeyPrepared, B> = ksk.prepare_alloc(module, scratch.borrow()); - let tsk_prepared: GGLWETensorKeyPrepared, B> = tsk.prepare_alloc(module, scratch.borrow()); + let ksk_prepared: GLWESwitchingKeyPrepared, B> = ksk.prepare_alloc(module, scratch.borrow()); + let tsk_prepared: TensorKeyPrepared, B> = tsk.prepare_alloc(module, scratch.borrow()); ggsw_out.keyswitch_inplace(module, &ksk_prepared, &tsk_prepared, scratch.borrow()); diff --git a/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs index d745f1d..2ea6e75 100644 --- a/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs @@ -1,11 +1,10 @@ use poulpy_hal::{ api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VecZnxDftApply, - VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, - VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, - VmpPrepare, + VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, + VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, + VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, }, layouts::{Backend, Module, ScratchOwned}, oep::{ @@ -18,8 +17,8 @@ use poulpy_hal::{ use crate::{ encryption::SIGMA, layouts::{ - GGLWESwitchingKey, GGLWESwitchingKeyLayout, GLWECiphertext, GLWECiphertextLayout, GLWEPlaintext, GLWESecret, - prepared::{GGLWESwitchingKeyPrepared, GLWESecretPrepared, PrepareAlloc}, + GLWE, GLWELayout, GLWEPlaintext, GLWESecret, GLWESwitchingKey, GLWESwitchingKeyLayout, + prepared::{GLWESecretPrepared, GLWESwitchingKeyPrepared, PrepareAlloc}, }, noise::log2_std_noise_gglwe_product, }; @@ -27,7 +26,7 @@ use crate::{ #[allow(clippy::too_many_arguments)] pub fn test_glwe_keyswitch(module: &Module) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -40,9 +39,9 @@ where + VecZnxNormalize + VecZnxSub + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxNormalizeTmpBytes @@ -77,21 +76,21 @@ where let n: usize = module.n(); let dnum: usize = k_in.div_ceil(base2k * dsize); - let glwe_in_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_in_infos: GLWELayout = GLWELayout { n: n.into(), base2k: base2k.into(), k: k_in.into(), rank: rank_in.into(), }; - let glwe_out_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_out_infos: GLWELayout = GLWELayout { n: n.into(), base2k: base2k.into(), k: k_out.into(), rank: rank_out.into(), }; - let key_apply: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + let key_apply: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), base2k: base2k.into(), k: k_ksk.into(), @@ -101,10 +100,10 @@ where rank_out: rank_out.into(), }; - let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&key_apply); - let mut glwe_in: GLWECiphertext> = GLWECiphertext::alloc(&glwe_in_infos); - let mut glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&glwe_out_infos); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_in_infos); + let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&key_apply); + let mut glwe_in: GLWE> = GLWE::alloc_from_infos(&glwe_in_infos); + let mut glwe_out: GLWE> = GLWE::alloc_from_infos(&glwe_out_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_in_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -113,16 +112,16 @@ where module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWESwitchingKey::encrypt_sk_scratch_space(module, &key_apply) - | GLWECiphertext::encrypt_sk_scratch_space(module, &glwe_in_infos) - | GLWECiphertext::keyswitch_scratch_space(module, &glwe_out_infos, &glwe_in_infos, &key_apply), + GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &key_apply) + | GLWE::encrypt_sk_tmp_bytes(module, &glwe_in_infos) + | GLWE::keyswitch_tmp_bytes(module, &glwe_out_infos, &glwe_in_infos, &key_apply), ); - let mut sk_in: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_in.into()); + let mut sk_in: GLWESecret> = GLWESecret::alloc(n.into(), rank_in.into()); sk_in.fill_ternary_prob(0.5, &mut source_xs); let sk_in_prepared: GLWESecretPrepared, B> = sk_in.prepare_alloc(module, scratch.borrow()); - let mut sk_out: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_out.into()); + let mut sk_out: GLWESecret> = GLWESecret::alloc(n.into(), rank_out.into()); sk_out.fill_ternary_prob(0.5, &mut source_xs); let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); @@ -144,7 +143,7 @@ where scratch.borrow(), ); - let ksk_prepared: GGLWESwitchingKeyPrepared, B> = ksk.prepare_alloc(module, scratch.borrow()); + let ksk_prepared: GLWESwitchingKeyPrepared, B> = ksk.prepare_alloc(module, scratch.borrow()); glwe_out.keyswitch(module, &glwe_in, &ksk_prepared, scratch.borrow()); @@ -169,7 +168,7 @@ where pub fn test_glwe_keyswitch_inplace(module: &Module) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -182,9 +181,9 @@ where + VecZnxNormalize + VecZnxSub + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxNormalizeTmpBytes @@ -217,14 +216,14 @@ where let n: usize = module.n(); let dnum: usize = k_out.div_ceil(base2k * dsize); - let glwe_out_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_out_infos: GLWELayout = GLWELayout { n: n.into(), base2k: base2k.into(), k: k_out.into(), rank: rank.into(), }; - let key_apply_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + let key_apply_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), base2k: base2k.into(), k: k_ksk.into(), @@ -234,9 +233,9 @@ where rank_out: rank.into(), }; - let mut key_apply: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&key_apply_infos); - let mut glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&glwe_out_infos); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_out_infos); + let mut key_apply: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&key_apply_infos); + let mut glwe_out: GLWE> = GLWE::alloc_from_infos(&glwe_out_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -245,16 +244,16 @@ where module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWESwitchingKey::encrypt_sk_scratch_space(module, &key_apply_infos) - | GLWECiphertext::encrypt_sk_scratch_space(module, &glwe_out_infos) - | GLWECiphertext::keyswitch_inplace_scratch_space(module, &glwe_out_infos, &key_apply_infos), + GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &key_apply_infos) + | GLWE::encrypt_sk_tmp_bytes(module, &glwe_out_infos) + | GLWE::keyswitch_inplace_tmp_bytes(module, &glwe_out_infos, &key_apply_infos), ); - let mut sk_in: GLWESecret> = GLWESecret::alloc_with(n.into(), rank.into()); + let mut sk_in: GLWESecret> = GLWESecret::alloc(n.into(), rank.into()); sk_in.fill_ternary_prob(0.5, &mut source_xs); let sk_in_prepared: GLWESecretPrepared, B> = sk_in.prepare_alloc(module, scratch.borrow()); - let mut sk_out: GLWESecret> = GLWESecret::alloc_with(n.into(), rank.into()); + let mut sk_out: GLWESecret> = GLWESecret::alloc(n.into(), rank.into()); sk_out.fill_ternary_prob(0.5, &mut source_xs); let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); @@ -276,7 +275,7 @@ where scratch.borrow(), ); - let ksk_prepared: GGLWESwitchingKeyPrepared, B> = key_apply.prepare_alloc(module, scratch.borrow()); + let ksk_prepared: GLWESwitchingKeyPrepared, B> = key_apply.prepare_alloc(module, scratch.borrow()); glwe_out.keyswitch_inplace(module, &ksk_prepared, scratch.borrow()); diff --git a/poulpy-core/src/tests/test_suite/keyswitch/lwe_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/lwe_ct.rs index 1c29c70..130c356 100644 --- a/poulpy-core/src/tests/test_suite/keyswitch/lwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/keyswitch/lwe_ct.rs @@ -1,9 +1,9 @@ use poulpy_hal::{ api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigAddInplace, - VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, + VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftApply, + VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, ZnAddNormal, ZnFillUniform, ZnNormalizeInplace, }, @@ -16,13 +16,13 @@ use poulpy_hal::{ }; use crate::layouts::{ - LWECiphertext, LWECiphertextLayout, LWEPlaintext, LWESecret, LWESwitchingKey, LWESwitchingKeyLayout, + LWE, LWECiphertextLayout, LWEPlaintext, LWESecret, LWESwitchingKey, LWESwitchingKeyLayout, prepared::{LWESwitchingKeyPrepared, PrepareAlloc}, }; pub fn test_lwe_keyswitch(module: &Module) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -35,9 +35,9 @@ where + VecZnxNormalize + VecZnxSub + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxNormalizeTmpBytes @@ -99,8 +99,8 @@ where }; let mut scratch: ScratchOwned = ScratchOwned::alloc( - LWESwitchingKey::encrypt_sk_scratch_space(module, &key_apply_infos) - | LWECiphertext::keyswitch_scratch_space(module, &lwe_out_infos, &lwe_in_infos, &key_apply_infos), + LWESwitchingKey::encrypt_sk_tmp_bytes(module, &key_apply_infos) + | LWE::keyswitch_tmp_bytes(module, &lwe_out_infos, &lwe_in_infos, &key_apply_infos), ); let mut sk_lwe_in: LWESecret> = LWESecret::alloc(n_lwe_in.into()); @@ -111,10 +111,10 @@ where let data: i64 = 17; - let mut lwe_pt_in: LWEPlaintext> = LWEPlaintext::alloc_with(base2k.into(), k_lwe_pt.into()); + let mut lwe_pt_in: LWEPlaintext> = LWEPlaintext::alloc(base2k.into(), k_lwe_pt.into()); lwe_pt_in.encode_i64(data, k_lwe_pt.into()); - let mut lwe_ct_in: LWECiphertext> = LWECiphertext::alloc(&lwe_in_infos); + let mut lwe_ct_in: LWE> = LWE::alloc_from_infos(&lwe_in_infos); lwe_ct_in.encrypt_sk( module, &lwe_pt_in, @@ -123,7 +123,7 @@ where &mut source_xe, ); - let mut ksk: LWESwitchingKey> = LWESwitchingKey::alloc(&key_apply_infos); + let mut ksk: LWESwitchingKey> = LWESwitchingKey::alloc_from_infos(&key_apply_infos); ksk.encrypt_sk( module, @@ -134,13 +134,13 @@ where scratch.borrow(), ); - let mut lwe_ct_out: LWECiphertext> = LWECiphertext::alloc(&lwe_out_infos); + let mut lwe_ct_out: LWE> = LWE::alloc_from_infos(&lwe_out_infos); let ksk_prepared: LWESwitchingKeyPrepared, B> = ksk.prepare_alloc(module, scratch.borrow()); lwe_ct_out.keyswitch(module, &lwe_ct_in, &ksk_prepared, scratch.borrow()); - let mut lwe_pt_out: LWEPlaintext> = LWEPlaintext::alloc(&lwe_out_infos); + let mut lwe_pt_out: LWEPlaintext> = LWEPlaintext::alloc_from_infos(&lwe_out_infos); lwe_ct_out.decrypt(module, &mut lwe_pt_out, &sk_lwe_out); assert_eq!(lwe_pt_in.data.at(0, 0)[0], lwe_pt_out.data.at(0, 0)[0]); diff --git a/poulpy-core/src/tests/test_suite/packing.rs b/poulpy-core/src/tests/test_suite/packing.rs index de7fc9d..fdfbd57 100644 --- a/poulpy-core/src/tests/test_suite/packing.rs +++ b/poulpy-core/src/tests/test_suite/packing.rs @@ -2,10 +2,10 @@ use std::collections::HashMap; use poulpy_hal::{ api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, - VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigAutomorphismInplace, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, + VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, VecZnxBigBytesOf, VecZnxBigNormalize, + VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, @@ -21,14 +21,14 @@ use poulpy_hal::{ use crate::{ GLWEOperations, GLWEPacker, layouts::{ - GGLWEAutomorphismKey, GGLWEAutomorphismKeyLayout, GLWECiphertext, GLWECiphertextLayout, GLWEPlaintext, GLWESecret, - prepared::{GGLWEAutomorphismKeyPrepared, GLWESecretPrepared, PrepareAlloc}, + AutomorphismKey, AutomorphismKeyLayout, GLWE, GLWELayout, GLWEPlaintext, GLWESecret, + prepared::{AutomorphismKeyPrepared, GLWESecretPrepared, PrepareAlloc}, }, }; pub fn test_glwe_packing(module: &Module) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VecZnxAutomorphism + VecZnxBigAutomorphismInplace + VecZnxBigSubSmallNegateInplace @@ -48,9 +48,9 @@ where + VecZnxNormalize + VecZnxSub + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxNormalizeTmpBytes @@ -88,14 +88,14 @@ where let dnum: usize = k_ct.div_ceil(base2k * dsize); - let glwe_out_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_out_infos: GLWELayout = GLWELayout { n: n.into(), base2k: base2k.into(), k: k_ct.into(), rank: rank.into(), }; - let key_infos: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + let key_infos: AutomorphismKeyLayout = AutomorphismKeyLayout { n: n.into(), base2k: base2k.into(), k: k_ksk.into(), @@ -105,16 +105,16 @@ where }; let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::encrypt_sk_scratch_space(module, &glwe_out_infos) - | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, &key_infos) - | GLWEPacker::scratch_space(module, &glwe_out_infos, &key_infos), + GLWE::encrypt_sk_tmp_bytes(module, &glwe_out_infos) + | AutomorphismKey::encrypt_sk_tmp_bytes(module, &key_infos) + | GLWEPacker::tmp_bytes(module, &glwe_out_infos, &key_infos), ); - let mut sk: GLWESecret> = GLWESecret::alloc(&glwe_out_infos); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_out_infos); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_dft: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_out_infos); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos); let mut data: Vec = vec![0i64; n]; data.iter_mut().enumerate().for_each(|(i, x)| { *x = i as i64; @@ -124,8 +124,8 @@ where let gal_els: Vec = GLWEPacker::galois_elements(module); - let mut auto_keys: HashMap, B>> = HashMap::new(); - let mut tmp: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&key_infos); + let mut auto_keys: HashMap, B>> = HashMap::new(); + let mut tmp: AutomorphismKey> = AutomorphismKey::alloc_from_infos(&key_infos); gal_els.iter().for_each(|gal_el| { tmp.encrypt_sk( module, @@ -135,7 +135,7 @@ where &mut source_xe, scratch.borrow(), ); - let atk_prepared: GGLWEAutomorphismKeyPrepared, B> = tmp.prepare_alloc(module, scratch.borrow()); + let atk_prepared: AutomorphismKeyPrepared, B> = tmp.prepare_alloc(module, scratch.borrow()); auto_keys.insert(*gal_el, atk_prepared); }); @@ -143,7 +143,7 @@ where let mut packer: GLWEPacker = GLWEPacker::new(&glwe_out_infos, log_batch); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&glwe_out_infos); + let mut ct: GLWE> = GLWE::alloc_from_infos(&glwe_out_infos); ct.encrypt_sk( module, @@ -171,19 +171,14 @@ where if reverse_bits_msb(i, log_n as u32).is_multiple_of(5) { packer.add(module, Some(&ct), &auto_keys, scratch.borrow()); } else { - packer.add( - module, - None::<&GLWECiphertext>>, - &auto_keys, - scratch.borrow(), - ) + packer.add(module, None::<&GLWE>>, &auto_keys, scratch.borrow()) } }); - let mut res: GLWECiphertext> = GLWECiphertext::alloc(&glwe_out_infos); + let mut res: GLWE> = GLWE::alloc_from_infos(&glwe_out_infos); packer.flush(module, &mut res); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_out_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos); let mut data: Vec = vec![0i64; n]; data.iter_mut().enumerate().for_each(|(i, x)| { if i.is_multiple_of(5) { diff --git a/poulpy-core/src/tests/test_suite/trace.rs b/poulpy-core/src/tests/test_suite/trace.rs index bf348ca..932b401 100644 --- a/poulpy-core/src/tests/test_suite/trace.rs +++ b/poulpy-core/src/tests/test_suite/trace.rs @@ -2,13 +2,13 @@ use std::collections::HashMap; use poulpy_hal::{ api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxBigAddInplace, - VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigAutomorphismInplace, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, - VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, - VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, - VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, + VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxBigSubSmallNegateInplace, VecZnxCopy, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, + VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, + VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, + VmpPrepare, }, layouts::{Backend, Module, ScratchOwned, ZnxView, ZnxViewMut}, oep::{ @@ -21,16 +21,15 @@ use poulpy_hal::{ use crate::{ encryption::SIGMA, layouts::{ - GGLWEAutomorphismKey, GGLWEAutomorphismKeyLayout, GLWECiphertext, GLWECiphertextLayout, GLWEPlaintext, GLWESecret, - LWEInfos, - prepared::{GGLWEAutomorphismKeyPrepared, GLWESecretPrepared, PrepareAlloc}, + AutomorphismKey, AutomorphismKeyLayout, GLWE, GLWELayout, GLWEPlaintext, GLWESecret, LWEInfos, + prepared::{AutomorphismKeyPrepared, GLWESecretPrepared, PrepareAlloc}, }, noise::var_noise_gglwe_product, }; pub fn test_glwe_trace_inplace(module: &Module) where - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VecZnxAutomorphism + VecZnxBigAutomorphismInplace + VecZnxBigSubSmallNegateInplace @@ -48,9 +47,9 @@ where + VecZnxNormalize + VecZnxSub + SvpPrepare - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxNormalizeTmpBytes @@ -83,14 +82,14 @@ where let dsize: usize = 1; let dnum: usize = k.div_ceil(base2k * dsize); - let glwe_out_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_out_infos: GLWELayout = GLWELayout { n: n.into(), base2k: base2k.into(), k: k.into(), rank: rank.into(), }; - let key_infos: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + let key_infos: AutomorphismKeyLayout = AutomorphismKeyLayout { n: n.into(), base2k: base2k.into(), k: k_autokey.into(), @@ -99,22 +98,22 @@ where dnum: dnum.into(), }; - let mut glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&glwe_out_infos); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_out_infos); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_out_infos); + let mut glwe_out: GLWE> = GLWE::alloc_from_infos(&glwe_out_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::encrypt_sk_scratch_space(module, &glwe_out_infos) - | GLWECiphertext::decrypt_scratch_space(module, &glwe_out_infos) - | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, &key_infos) - | GLWECiphertext::trace_inplace_scratch_space(module, &glwe_out_infos, &key_infos), + GLWE::encrypt_sk_tmp_bytes(module, &glwe_out_infos) + | GLWE::decrypt_tmp_bytes(module, &glwe_out_infos) + | AutomorphismKey::encrypt_sk_tmp_bytes(module, &key_infos) + | GLWE::trace_inplace_tmp_bytes(module, &glwe_out_infos, &key_infos), ); - let mut sk: GLWESecret> = GLWESecret::alloc(&glwe_out_infos); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_out_infos); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_dft: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); @@ -135,9 +134,9 @@ where scratch.borrow(), ); - let mut auto_keys: HashMap, B>> = HashMap::new(); - let gal_els: Vec = GLWECiphertext::trace_galois_elements(module); - let mut tmp: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&key_infos); + let mut auto_keys: HashMap, B>> = HashMap::new(); + let gal_els: Vec = GLWE::trace_galois_elements(module); + let mut tmp: AutomorphismKey> = AutomorphismKey::alloc_from_infos(&key_infos); gal_els.iter().for_each(|gal_el| { tmp.encrypt_sk( module, @@ -147,7 +146,7 @@ where &mut source_xe, scratch.borrow(), ); - let atk_prepared: GGLWEAutomorphismKeyPrepared, B> = tmp.prepare_alloc(module, scratch.borrow()); + let atk_prepared: AutomorphismKeyPrepared, B> = tmp.prepare_alloc(module, scratch.borrow()); auto_keys.insert(*gal_el, atk_prepared); }); diff --git a/poulpy-hal/src/api/module.rs b/poulpy-hal/src/api/module.rs index 6e5faed..3dd6176 100644 --- a/poulpy-hal/src/api/module.rs +++ b/poulpy-hal/src/api/module.rs @@ -4,3 +4,7 @@ use crate::layouts::Backend; pub trait ModuleNew { fn new(n: u64) -> Self; } + +pub trait ModuleN { + fn n(&self) -> usize; +} diff --git a/poulpy-hal/src/api/scratch.rs b/poulpy-hal/src/api/scratch.rs index 38901bf..ee4a080 100644 --- a/poulpy-hal/src/api/scratch.rs +++ b/poulpy-hal/src/api/scratch.rs @@ -1,4 +1,7 @@ -use crate::layouts::{Backend, MatZnx, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}; +use crate::{ + api::{ModuleN, SvpPPolBytesOf, VecZnxBigBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf}, + layouts::{Backend, MatZnx, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}, +}; /// Allocates a new [crate::layouts::ScratchOwned] of `size` aligned bytes. pub trait ScratchOwnedAlloc { @@ -25,76 +28,130 @@ pub trait TakeSlice { fn take_slice(&mut self, len: usize) -> (&mut [T], &mut Self); } -/// Take a slice of bytes from a [Scratch], wraps it into a [ScalarZnx] and returns it -/// as well as a new [Scratch] minus the taken array of bytes. -pub trait TakeScalarZnx { - fn take_scalar_znx(&mut self, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self); -} +pub trait ScratchTakeBasic +where + Self: TakeSlice, +{ + fn take_scalar_znx(&mut self, module: &M, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) + where + M: ModuleN, + { + let (take_slice, rem_slice) = self.take_slice(ScalarZnx::bytes_of(module.n(), cols)); + ( + ScalarZnx::from_data(take_slice, module.n(), cols), + rem_slice, + ) + } -/// Take a slice of bytes from a [Scratch], wraps it into a [SvpPPol] and returns it -/// as well as a new [Scratch] minus the taken array of bytes. -pub trait TakeSvpPPol { - fn take_svp_ppol(&mut self, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self); -} + fn take_svp_ppol(&mut self, module: &M, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self) + where + M: SvpPPolBytesOf + ModuleN, + { + let (take_slice, rem_slice) = self.take_slice(module.bytes_of_svp_ppol(cols)); + (SvpPPol::from_data(take_slice, module.n(), cols), rem_slice) + } -/// Take a slice of bytes from a [Scratch], wraps it into a [VecZnx] and returns it -/// as well as a new [Scratch] minus the taken array of bytes. -pub trait TakeVecZnx { - fn take_vec_znx(&mut self, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self); -} + fn take_vec_znx(&mut self, module: &M, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) + where + M: ModuleN, + { + let (take_slice, rem_slice) = self.take_slice(VecZnx::bytes_of(module.n(), cols, size)); + ( + VecZnx::from_data(take_slice, module.n(), cols, size), + rem_slice, + ) + } -/// Take a slice of bytes from a [Scratch], slices it into a vector of [VecZnx] aand returns it -/// as well as a new [Scratch] minus the taken array of bytes. -pub trait TakeVecZnxSlice { - fn take_vec_znx_slice(&mut self, len: usize, n: usize, cols: usize, size: usize) -> (Vec>, &mut Self); -} + fn take_vec_znx_big(&mut self, module: &M, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self) + where + M: VecZnxBigBytesOf + ModuleN, + { + let (take_slice, rem_slice) = self.take_slice(module.bytes_of_vec_znx_big(cols, size)); + ( + VecZnxBig::from_data(take_slice, module.n(), cols, size), + rem_slice, + ) + } -/// Take a slice of bytes from a [Scratch], wraps it into a [VecZnxBig] and returns it -/// as well as a new [Scratch] minus the taken array of bytes. -pub trait TakeVecZnxBig { - fn take_vec_znx_big(&mut self, n: usize, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self); -} + fn take_vec_znx_dft(&mut self, module: &M, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self) + where + M: VecZnxDftBytesOf + ModuleN, + { + let (take_slice, rem_slice) = self.take_slice(module.bytes_of_vec_znx_dft(cols, size)); -/// Take a slice of bytes from a [Scratch], wraps it into a [VecZnxDft] and returns it -/// as well as a new [Scratch] minus the taken array of bytes. -pub trait TakeVecZnxDft { - fn take_vec_znx_dft(&mut self, n: usize, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self); -} + ( + VecZnxDft::from_data(take_slice, module.n(), cols, size), + rem_slice, + ) + } -/// Take a slice of bytes from a [Scratch], slices it into a vector of [VecZnxDft] and returns it -/// as well as a new [Scratch] minus the taken array of bytes. -pub trait TakeVecZnxDftSlice { - fn take_vec_znx_dft_slice( + fn take_vec_znx_dft_slice( &mut self, + module: &M, len: usize, - n: usize, cols: usize, size: usize, - ) -> (Vec>, &mut Self); -} + ) -> (Vec>, &mut Self) + where + M: VecZnxDftBytesOf + ModuleN, + { + let mut scratch: &mut Self = self; + let mut slice: Vec> = Vec::with_capacity(len); + for _ in 0..len { + let (znx, new_scratch) = scratch.take_vec_znx_dft(module, cols, size); + scratch = new_scratch; + slice.push(znx); + } + (slice, scratch) + } -/// Take a slice of bytes from a [Scratch], wraps it into a [VmpPMat] and returns it -/// as well as a new [Scratch] minus the taken array of bytes. -pub trait TakeVmpPMat { - fn take_vmp_pmat( + fn take_vec_znx_slice(&mut self, module: &M, len: usize, cols: usize, size: usize) -> (Vec>, &mut Self) + where + M: ModuleN, + { + let mut scratch: &mut Self = self; + let mut slice: Vec> = Vec::with_capacity(len); + for _ in 0..len { + let (znx, new_scratch) = scratch.take_vec_znx(module, cols, size); + scratch = new_scratch; + slice.push(znx); + } + (slice, scratch) + } + + fn take_vmp_pmat( &mut self, - n: usize, + module: &M, rows: usize, cols_in: usize, cols_out: usize, size: usize, - ) -> (VmpPMat<&mut [u8], B>, &mut Self); -} + ) -> (VmpPMat<&mut [u8], B>, &mut Self) + where + M: VmpPMatBytesOf + ModuleN, + { + let (take_slice, rem_slice) = self.take_slice(module.bytes_of_vmp_pmat(rows, cols_in, cols_out, size)); + ( + VmpPMat::from_data(take_slice, module.n(), rows, cols_in, cols_out, size), + rem_slice, + ) + } -/// Take a slice of bytes from a [Scratch], wraps it into a [MatZnx] and returns it -/// as well as a new [Scratch] minus the taken array of bytes. -pub trait TakeMatZnx { - fn take_mat_znx( + fn take_mat_znx( &mut self, - n: usize, + module: &M, rows: usize, cols_in: usize, cols_out: usize, size: usize, - ) -> (MatZnx<&mut [u8]>, &mut Self); + ) -> (MatZnx<&mut [u8]>, &mut Self) + where + M: ModuleN, + { + let (take_slice, rem_slice) = self.take_slice(MatZnx::bytes_of(module.n(), rows, cols_in, cols_out, size)); + ( + MatZnx::from_data(take_slice, module.n(), rows, cols_in, cols_out, size), + rem_slice, + ) + } } diff --git a/poulpy-hal/src/api/svp_ppol.rs b/poulpy-hal/src/api/svp_ppol.rs index 5a72367..6678584 100644 --- a/poulpy-hal/src/api/svp_ppol.rs +++ b/poulpy-hal/src/api/svp_ppol.rs @@ -8,8 +8,8 @@ pub trait SvpPPolAlloc { } /// Returns the size in bytes to allocate a [crate::layouts::SvpPPol]. -pub trait SvpPPolAllocBytes { - fn svp_ppol_alloc_bytes(&self, cols: usize) -> usize; +pub trait SvpPPolBytesOf { + fn bytes_of_svp_ppol(&self, cols: usize) -> usize; } /// Consume a vector of bytes into a [crate::layouts::MatZnx]. diff --git a/poulpy-hal/src/api/vec_znx_big.rs b/poulpy-hal/src/api/vec_znx_big.rs index 08159bb..8cb5105 100644 --- a/poulpy-hal/src/api/vec_znx_big.rs +++ b/poulpy-hal/src/api/vec_znx_big.rs @@ -16,8 +16,8 @@ pub trait VecZnxBigAlloc { } /// Returns the size in bytes to allocate a [crate::layouts::VecZnxBig]. -pub trait VecZnxBigAllocBytes { - fn vec_znx_big_alloc_bytes(&self, cols: usize, size: usize) -> usize; +pub trait VecZnxBigBytesOf { + fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize; } /// Consume a vector of bytes into a [crate::layouts::VecZnxBig]. diff --git a/poulpy-hal/src/api/vec_znx_dft.rs b/poulpy-hal/src/api/vec_znx_dft.rs index 58589c3..3a003a9 100644 --- a/poulpy-hal/src/api/vec_znx_dft.rs +++ b/poulpy-hal/src/api/vec_znx_dft.rs @@ -10,8 +10,8 @@ pub trait VecZnxDftFromBytes { fn vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxDftOwned; } -pub trait VecZnxDftAllocBytes { - fn vec_znx_dft_alloc_bytes(&self, cols: usize, size: usize) -> usize; +pub trait VecZnxDftBytesOf { + fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize; } pub trait VecZnxDftApply { diff --git a/poulpy-hal/src/api/vmp_pmat.rs b/poulpy-hal/src/api/vmp_pmat.rs index 3d0e248..de3433a 100644 --- a/poulpy-hal/src/api/vmp_pmat.rs +++ b/poulpy-hal/src/api/vmp_pmat.rs @@ -6,8 +6,8 @@ pub trait VmpPMatAlloc { fn vmp_pmat_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned; } -pub trait VmpPMatAllocBytes { - fn vmp_pmat_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; +pub trait VmpPMatBytesOf { + fn bytes_of_vmp_pmat(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; } pub trait VmpPMatFromBytes { diff --git a/poulpy-hal/src/delegates/scratch.rs b/poulpy-hal/src/delegates/scratch.rs index ac022e3..b91afbb 100644 --- a/poulpy-hal/src/delegates/scratch.rs +++ b/poulpy-hal/src/delegates/scratch.rs @@ -1,14 +1,7 @@ use crate::{ - api::{ - ScratchAvailable, ScratchFromBytes, ScratchOwnedAlloc, ScratchOwnedBorrow, TakeMatZnx, TakeScalarZnx, TakeSlice, - TakeSvpPPol, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, TakeVecZnxSlice, TakeVmpPMat, - }, - layouts::{Backend, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}, - oep::{ - ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeMatZnxImpl, - TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxDftSliceImpl, - TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl, - }, + api::{ScratchAvailable, ScratchFromBytes, ScratchOwnedAlloc, ScratchOwnedBorrow, TakeSlice}, + layouts::{Backend, Scratch, ScratchOwned}, + oep::{ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeSliceImpl}, }; impl ScratchOwnedAlloc for ScratchOwned @@ -55,104 +48,3 @@ where B::take_slice_impl(self, len) } } - -impl TakeScalarZnx for Scratch -where - B: Backend + TakeScalarZnxImpl, -{ - fn take_scalar_znx(&mut self, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) { - B::take_scalar_znx_impl(self, n, cols) - } -} - -impl TakeSvpPPol for Scratch -where - B: Backend + TakeSvpPPolImpl, -{ - fn take_svp_ppol(&mut self, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self) { - B::take_svp_ppol_impl(self, n, cols) - } -} - -impl TakeVecZnx for Scratch -where - B: Backend + TakeVecZnxImpl, -{ - fn take_vec_znx(&mut self, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) { - B::take_vec_znx_impl(self, n, cols, size) - } -} - -impl TakeVecZnxSlice for Scratch -where - B: Backend + TakeVecZnxSliceImpl, -{ - fn take_vec_znx_slice(&mut self, len: usize, n: usize, cols: usize, size: usize) -> (Vec>, &mut Self) { - B::take_vec_znx_slice_impl(self, len, n, cols, size) - } -} - -impl TakeVecZnxBig for Scratch -where - B: Backend + TakeVecZnxBigImpl, -{ - fn take_vec_znx_big(&mut self, n: usize, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self) { - B::take_vec_znx_big_impl(self, n, cols, size) - } -} - -impl TakeVecZnxDft for Scratch -where - B: Backend + TakeVecZnxDftImpl, -{ - fn take_vec_znx_dft(&mut self, n: usize, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self) { - B::take_vec_znx_dft_impl(self, n, cols, size) - } -} - -impl TakeVecZnxDftSlice for Scratch -where - B: Backend + TakeVecZnxDftSliceImpl, -{ - fn take_vec_znx_dft_slice( - &mut self, - len: usize, - n: usize, - cols: usize, - size: usize, - ) -> (Vec>, &mut Self) { - B::take_vec_znx_dft_slice_impl(self, len, n, cols, size) - } -} - -impl TakeVmpPMat for Scratch -where - B: Backend + TakeVmpPMatImpl, -{ - fn take_vmp_pmat( - &mut self, - n: usize, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - ) -> (VmpPMat<&mut [u8], B>, &mut Self) { - B::take_vmp_pmat_impl(self, n, rows, cols_in, cols_out, size) - } -} - -impl TakeMatZnx for Scratch -where - B: Backend + TakeMatZnxImpl, -{ - fn take_mat_znx( - &mut self, - n: usize, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - ) -> (MatZnx<&mut [u8]>, &mut Self) { - B::take_mat_znx_impl(self, n, rows, cols_in, cols_out, size) - } -} diff --git a/poulpy-hal/src/delegates/svp_ppol.rs b/poulpy-hal/src/delegates/svp_ppol.rs index 54a99b2..de36fd0 100644 --- a/poulpy-hal/src/delegates/svp_ppol.rs +++ b/poulpy-hal/src/delegates/svp_ppol.rs @@ -1,6 +1,6 @@ use crate::{ api::{ - SvpApplyDft, SvpApplyDftToDft, SvpApplyDftToDftAdd, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, + SvpApplyDft, SvpApplyDftToDft, SvpApplyDftToDftAdd, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPPolFromBytes, SvpPrepare, }, layouts::{ @@ -30,12 +30,12 @@ where } } -impl SvpPPolAllocBytes for Module +impl SvpPPolBytesOf for Module where B: Backend + SvpPPolAllocBytesImpl, { - fn svp_ppol_alloc_bytes(&self, cols: usize) -> usize { - B::svp_ppol_alloc_bytes_impl(self.n(), cols) + fn bytes_of_svp_ppol(&self, cols: usize) -> usize { + B::svp_ppol_bytes_of_impl(self.n(), cols) } } diff --git a/poulpy-hal/src/delegates/vec_znx_big.rs b/poulpy-hal/src/delegates/vec_znx_big.rs index 1556a87..a1cc307 100644 --- a/poulpy-hal/src/delegates/vec_znx_big.rs +++ b/poulpy-hal/src/delegates/vec_znx_big.rs @@ -1,7 +1,7 @@ use crate::{ api::{ VecZnxBigAdd, VecZnxBigAddInplace, VecZnxBigAddNormal, VecZnxBigAddSmall, VecZnxBigAddSmallInplace, VecZnxBigAlloc, - VecZnxBigAllocBytes, VecZnxBigAutomorphism, VecZnxBigAutomorphismInplace, VecZnxBigAutomorphismInplaceTmpBytes, + VecZnxBigAutomorphism, VecZnxBigAutomorphismInplace, VecZnxBigAutomorphismInplaceTmpBytes, VecZnxBigBytesOf, VecZnxBigFromBytes, VecZnxBigFromSmall, VecZnxBigNegate, VecZnxBigNegateInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSub, VecZnxBigSubInplace, VecZnxBigSubNegateInplace, VecZnxBigSubSmallA, VecZnxBigSubSmallB, VecZnxBigSubSmallInplace, VecZnxBigSubSmallNegateInplace, @@ -49,12 +49,12 @@ where } } -impl VecZnxBigAllocBytes for Module +impl VecZnxBigBytesOf for Module where B: Backend + VecZnxBigAllocBytesImpl, { - fn vec_znx_big_alloc_bytes(&self, cols: usize, size: usize) -> usize { - B::vec_znx_big_alloc_bytes_impl(self.n(), cols, size) + fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize { + B::vec_znx_big_bytes_of_impl(self.n(), cols, size) } } diff --git a/poulpy-hal/src/delegates/vec_znx_dft.rs b/poulpy-hal/src/delegates/vec_znx_dft.rs index 3736e34..16a583f 100644 --- a/poulpy-hal/src/delegates/vec_znx_dft.rs +++ b/poulpy-hal/src/delegates/vec_znx_dft.rs @@ -1,8 +1,8 @@ use crate::{ api::{ - VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, - VecZnxDftFromBytes, VecZnxDftSub, VecZnxDftSubInplace, VecZnxDftSubNegateInplace, VecZnxDftZero, VecZnxIdftApply, - VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxIdftApplyTmpBytes, + VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftCopy, VecZnxDftFromBytes, + VecZnxDftSub, VecZnxDftSubInplace, VecZnxDftSubNegateInplace, VecZnxDftZero, VecZnxIdftApply, VecZnxIdftApplyConsume, + VecZnxIdftApplyTmpA, VecZnxIdftApplyTmpBytes, }, layouts::{ Backend, Data, Module, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef, @@ -24,12 +24,12 @@ where } } -impl VecZnxDftAllocBytes for Module +impl VecZnxDftBytesOf for Module where B: Backend + VecZnxDftAllocBytesImpl, { - fn vec_znx_dft_alloc_bytes(&self, cols: usize, size: usize) -> usize { - B::vec_znx_dft_alloc_bytes_impl(self.n(), cols, size) + fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize { + B::vec_znx_dft_bytes_of_impl(self.n(), cols, size) } } diff --git a/poulpy-hal/src/delegates/vmp_pmat.rs b/poulpy-hal/src/delegates/vmp_pmat.rs index a875a40..2c65508 100644 --- a/poulpy-hal/src/delegates/vmp_pmat.rs +++ b/poulpy-hal/src/delegates/vmp_pmat.rs @@ -1,7 +1,7 @@ use crate::{ api::{ VmpApplyDft, VmpApplyDftTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftAddTmpBytes, - VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPMatAllocBytes, VmpPMatFromBytes, VmpPrepare, VmpPrepareTmpBytes, + VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPMatBytesOf, VmpPMatFromBytes, VmpPrepare, VmpPrepareTmpBytes, }, layouts::{ Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, VmpPMatOwned, VmpPMatToMut, @@ -23,12 +23,12 @@ where } } -impl VmpPMatAllocBytes for Module +impl VmpPMatBytesOf for Module where B: Backend + VmpPMatAllocBytesImpl, { - fn vmp_pmat_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { - B::vmp_pmat_alloc_bytes_impl(self.n(), rows, cols_in, cols_out, size) + fn bytes_of_vmp_pmat(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + B::vmp_pmat_bytes_of_impl(self.n(), rows, cols_in, cols_out, size) } } diff --git a/poulpy-hal/src/layouts/mat_znx.rs b/poulpy-hal/src/layouts/mat_znx.rs index 01be1a1..5ccb860 100644 --- a/poulpy-hal/src/layouts/mat_znx.rs +++ b/poulpy-hal/src/layouts/mat_znx.rs @@ -114,12 +114,12 @@ impl MatZnx { } impl MatZnx> { - pub fn alloc_bytes(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { - rows * cols_in * VecZnx::>::alloc_bytes(n, cols_out, size) + pub fn bytes_of(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + rows * cols_in * VecZnx::>::bytes_of(n, cols_out, size) } pub fn alloc(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { - let data: Vec = alloc_aligned(Self::alloc_bytes(n, rows, cols_in, cols_out, size)); + let data: Vec = alloc_aligned(Self::bytes_of(n, rows, cols_in, cols_out, size)); Self { data, n, @@ -132,7 +132,7 @@ impl MatZnx> { pub fn from_bytes(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); - assert!(data.len() == Self::alloc_bytes(n, rows, cols_in, cols_out, size)); + assert!(data.len() == Self::bytes_of(n, rows, cols_in, cols_out, size)); Self { data, n, @@ -153,7 +153,7 @@ impl MatZnx { } let self_ref: MatZnx<&[u8]> = self.to_ref(); - let nb_bytes: usize = VecZnx::>::alloc_bytes(self.n, self.cols_out, self.size); + let nb_bytes: usize = VecZnx::>::bytes_of(self.n, self.cols_out, self.size); let start: usize = nb_bytes * self.cols() * row + col * nb_bytes; let end: usize = start + nb_bytes; @@ -181,7 +181,7 @@ impl MatZnx { let size: usize = self.size(); let self_ref: MatZnx<&mut [u8]> = self.to_mut(); - let nb_bytes: usize = VecZnx::>::alloc_bytes(n, cols_out, size); + let nb_bytes: usize = VecZnx::>::bytes_of(n, cols_out, size); let start: usize = nb_bytes * cols_in * row + col * nb_bytes; let end: usize = start + nb_bytes; diff --git a/poulpy-hal/src/layouts/scalar_znx.rs b/poulpy-hal/src/layouts/scalar_znx.rs index 296071b..bdf2159 100644 --- a/poulpy-hal/src/layouts/scalar_znx.rs +++ b/poulpy-hal/src/layouts/scalar_znx.rs @@ -132,18 +132,18 @@ impl ScalarZnx { } impl ScalarZnx> { - pub fn alloc_bytes(n: usize, cols: usize) -> usize { + pub fn bytes_of(n: usize, cols: usize) -> usize { n * cols * size_of::() } pub fn alloc(n: usize, cols: usize) -> Self { - let data: Vec = alloc_aligned::(Self::alloc_bytes(n, cols)); + let data: Vec = alloc_aligned::(Self::bytes_of(n, cols)); Self { data, n, cols } } pub fn from_bytes(n: usize, cols: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); - assert!(data.len() == Self::alloc_bytes(n, cols)); + assert!(data.len() == Self::bytes_of(n, cols)); Self { data, n, cols } } } diff --git a/poulpy-hal/src/layouts/svp_ppol.rs b/poulpy-hal/src/layouts/svp_ppol.rs index 50523fc..234ebca 100644 --- a/poulpy-hal/src/layouts/svp_ppol.rs +++ b/poulpy-hal/src/layouts/svp_ppol.rs @@ -77,7 +77,7 @@ where B: SvpPPolAllocBytesImpl, { pub fn alloc(n: usize, cols: usize) -> Self { - let data: Vec = alloc_aligned::(B::svp_ppol_alloc_bytes_impl(n, cols)); + let data: Vec = alloc_aligned::(B::svp_ppol_bytes_of_impl(n, cols)); Self { data: data.into(), n, @@ -88,7 +88,7 @@ where pub fn from_bytes(n: usize, cols: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); - assert!(data.len() == B::svp_ppol_alloc_bytes_impl(n, cols)); + assert!(data.len() == B::svp_ppol_bytes_of_impl(n, cols)); Self { data: data.into(), n, diff --git a/poulpy-hal/src/layouts/vec_znx.rs b/poulpy-hal/src/layouts/vec_znx.rs index d40ef4b..c084934 100644 --- a/poulpy-hal/src/layouts/vec_znx.rs +++ b/poulpy-hal/src/layouts/vec_znx.rs @@ -110,7 +110,7 @@ impl ZnxView for VecZnx { } impl VecZnx> { - pub fn rsh_scratch_space(n: usize) -> usize { + pub fn rsh_tmp_bytes(n: usize) -> usize { n * std::mem::size_of::() } } @@ -125,12 +125,12 @@ impl ZnxZero for VecZnx { } impl VecZnx> { - pub fn alloc_bytes(n: usize, cols: usize, size: usize) -> usize { + pub fn bytes_of(n: usize, cols: usize, size: usize) -> usize { n * cols * size * size_of::() } pub fn alloc(n: usize, cols: usize, size: usize) -> Self { - let data: Vec = alloc_aligned::(Self::alloc_bytes(n, cols, size)); + let data: Vec = alloc_aligned::(Self::bytes_of(n, cols, size)); Self { data, n, @@ -142,7 +142,7 @@ impl VecZnx> { pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); - assert!(data.len() == Self::alloc_bytes(n, cols, size)); + assert!(data.len() == Self::bytes_of(n, cols, size)); Self { data, n, diff --git a/poulpy-hal/src/layouts/vec_znx_big.rs b/poulpy-hal/src/layouts/vec_znx_big.rs index c50cf66..73a3e0f 100644 --- a/poulpy-hal/src/layouts/vec_znx_big.rs +++ b/poulpy-hal/src/layouts/vec_znx_big.rs @@ -96,7 +96,7 @@ where B: VecZnxBigAllocBytesImpl, { pub fn alloc(n: usize, cols: usize, size: usize) -> Self { - let data = alloc_aligned::(B::vec_znx_big_alloc_bytes_impl(n, cols, size)); + let data = alloc_aligned::(B::vec_znx_big_bytes_of_impl(n, cols, size)); Self { data: data.into(), n, @@ -109,7 +109,7 @@ where pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); - assert!(data.len() == B::vec_znx_big_alloc_bytes_impl(n, cols, size)); + assert!(data.len() == B::vec_znx_big_bytes_of_impl(n, cols, size)); Self { data: data.into(), n, diff --git a/poulpy-hal/src/layouts/vec_znx_dft.rs b/poulpy-hal/src/layouts/vec_znx_dft.rs index 3dc92d5..19d28e1 100644 --- a/poulpy-hal/src/layouts/vec_znx_dft.rs +++ b/poulpy-hal/src/layouts/vec_znx_dft.rs @@ -116,7 +116,7 @@ where B: VecZnxDftAllocBytesImpl, { pub fn alloc(n: usize, cols: usize, size: usize) -> Self { - let data: Vec = alloc_aligned::(B::vec_znx_dft_alloc_bytes_impl(n, cols, size)); + let data: Vec = alloc_aligned::(B::vec_znx_dft_bytes_of_impl(n, cols, size)); Self { data: data.into(), n, @@ -129,7 +129,7 @@ where pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); - assert!(data.len() == B::vec_znx_dft_alloc_bytes_impl(n, cols, size)); + assert!(data.len() == B::vec_znx_dft_bytes_of_impl(n, cols, size)); Self { data: data.into(), n, diff --git a/poulpy-hal/src/layouts/vmp_pmat.rs b/poulpy-hal/src/layouts/vmp_pmat.rs index ce83458..bd469ec 100644 --- a/poulpy-hal/src/layouts/vmp_pmat.rs +++ b/poulpy-hal/src/layouts/vmp_pmat.rs @@ -88,9 +88,7 @@ where B: VmpPMatAllocBytesImpl, { pub fn alloc(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { - let data: Vec = alloc_aligned(B::vmp_pmat_alloc_bytes_impl( - n, rows, cols_in, cols_out, size, - )); + let data: Vec = alloc_aligned(B::vmp_pmat_bytes_of_impl(n, rows, cols_in, cols_out, size)); Self { data: data.into(), n, @@ -104,7 +102,7 @@ where pub fn from_bytes(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); - assert!(data.len() == B::vmp_pmat_alloc_bytes_impl(n, rows, cols_in, cols_out, size)); + assert!(data.len() == B::vmp_pmat_bytes_of_impl(n, rows, cols_in, cols_out, size)); Self { data: data.into(), n, diff --git a/poulpy-hal/src/layouts/zn.rs b/poulpy-hal/src/layouts/zn.rs index 40f5622..00f8067 100644 --- a/poulpy-hal/src/layouts/zn.rs +++ b/poulpy-hal/src/layouts/zn.rs @@ -98,7 +98,7 @@ impl ZnxView for Zn { } impl Zn> { - pub fn rsh_scratch_space(n: usize) -> usize { + pub fn rsh_tmp_bytes(n: usize) -> usize { n * std::mem::size_of::() } } @@ -113,12 +113,12 @@ impl ZnxZero for Zn { } impl Zn> { - pub fn alloc_bytes(n: usize, cols: usize, size: usize) -> usize { + pub fn bytes_of(n: usize, cols: usize, size: usize) -> usize { n * cols * size * size_of::() } pub fn alloc(n: usize, cols: usize, size: usize) -> Self { - let data: Vec = alloc_aligned::(Self::alloc_bytes(n, cols, size)); + let data: Vec = alloc_aligned::(Self::bytes_of(n, cols, size)); Self { data, n, @@ -130,7 +130,7 @@ impl Zn> { pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); - assert!(data.len() == Self::alloc_bytes(n, cols, size)); + assert!(data.len() == Self::bytes_of(n, cols, size)); Self { data, n, diff --git a/poulpy-hal/src/oep/scratch.rs b/poulpy-hal/src/oep/scratch.rs index 51a9c56..c973b04 100644 --- a/poulpy-hal/src/oep/scratch.rs +++ b/poulpy-hal/src/oep/scratch.rs @@ -1,4 +1,4 @@ -use crate::layouts::{Backend, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}; +use crate::layouts::{Backend, Scratch, ScratchOwned}; /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. @@ -39,111 +39,3 @@ pub unsafe trait ScratchAvailableImpl { pub unsafe trait TakeSliceImpl { fn take_slice_impl(scratch: &mut Scratch, len: usize) -> (&mut [T], &mut Scratch); } - -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. -/// * See [crate::api::TakeScalarZnx] for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait TakeScalarZnxImpl { - fn take_scalar_znx_impl(scratch: &mut Scratch, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Scratch); -} - -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. -/// * See [crate::api::TakeSvpPPol] for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait TakeSvpPPolImpl { - fn take_svp_ppol_impl(scratch: &mut Scratch, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Scratch); -} - -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. -/// * See [crate::api::TakeVecZnx] for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait TakeVecZnxImpl { - fn take_vec_znx_impl(scratch: &mut Scratch, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Scratch); -} - -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. -/// * See [crate::api::TakeVecZnxSlice] for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait TakeVecZnxSliceImpl { - fn take_vec_znx_slice_impl( - scratch: &mut Scratch, - len: usize, - n: usize, - cols: usize, - size: usize, - ) -> (Vec>, &mut Scratch); -} - -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. -/// * See [crate::api::TakeVecZnxBig] for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait TakeVecZnxBigImpl { - fn take_vec_znx_big_impl( - scratch: &mut Scratch, - n: usize, - cols: usize, - size: usize, - ) -> (VecZnxBig<&mut [u8], B>, &mut Scratch); -} - -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. -/// * See [crate::api::TakeVecZnxDft] for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait TakeVecZnxDftImpl { - fn take_vec_znx_dft_impl( - scratch: &mut Scratch, - n: usize, - cols: usize, - size: usize, - ) -> (VecZnxDft<&mut [u8], B>, &mut Scratch); -} - -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. -/// * See [crate::api::TakeVecZnxDftSlice] for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait TakeVecZnxDftSliceImpl { - fn take_vec_znx_dft_slice_impl( - scratch: &mut Scratch, - len: usize, - n: usize, - cols: usize, - size: usize, - ) -> (Vec>, &mut Scratch); -} - -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. -/// * See [crate::api::TakeVmpPMat] for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait TakeVmpPMatImpl { - fn take_vmp_pmat_impl( - scratch: &mut Scratch, - n: usize, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - ) -> (VmpPMat<&mut [u8], B>, &mut Scratch); -} - -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. -/// * See [crate::api::TakeMatZnx] for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait TakeMatZnxImpl { - fn take_mat_znx_impl( - scratch: &mut Scratch, - n: usize, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - ) -> (MatZnx<&mut [u8]>, &mut Scratch); -} diff --git a/poulpy-hal/src/oep/svp_ppol.rs b/poulpy-hal/src/oep/svp_ppol.rs index 6550b6f..42c50ea 100644 --- a/poulpy-hal/src/oep/svp_ppol.rs +++ b/poulpy-hal/src/oep/svp_ppol.rs @@ -23,7 +23,7 @@ pub unsafe trait SvpPPolAllocImpl { /// * See [crate::api::SvpPPolAllocBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait SvpPPolAllocBytesImpl { - fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize; + fn svp_ppol_bytes_of_impl(n: usize, cols: usize) -> usize; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) diff --git a/poulpy-hal/src/oep/vec_znx_big.rs b/poulpy-hal/src/oep/vec_znx_big.rs index b2bd4c8..4c12e6a 100644 --- a/poulpy-hal/src/oep/vec_znx_big.rs +++ b/poulpy-hal/src/oep/vec_znx_big.rs @@ -35,7 +35,7 @@ pub unsafe trait VecZnxBigFromBytesImpl { /// * See [crate::api::VecZnxBigAllocBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxBigAllocBytesImpl { - fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize; + fn vec_znx_big_bytes_of_impl(n: usize, cols: usize, size: usize) -> usize; } #[allow(clippy::too_many_arguments)] diff --git a/poulpy-hal/src/oep/vec_znx_dft.rs b/poulpy-hal/src/oep/vec_znx_dft.rs index e5a2bcb..0f9288b 100644 --- a/poulpy-hal/src/oep/vec_znx_dft.rs +++ b/poulpy-hal/src/oep/vec_znx_dft.rs @@ -42,7 +42,7 @@ pub unsafe trait VecZnxDftApplyImpl { /// * See [crate::api::VecZnxDftAllocBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxDftAllocBytesImpl { - fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize; + fn vec_znx_dft_bytes_of_impl(n: usize, cols: usize, size: usize) -> usize; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) diff --git a/poulpy-hal/src/oep/vmp_pmat.rs b/poulpy-hal/src/oep/vmp_pmat.rs index e399f00..bdca416 100644 --- a/poulpy-hal/src/oep/vmp_pmat.rs +++ b/poulpy-hal/src/oep/vmp_pmat.rs @@ -15,7 +15,7 @@ pub unsafe trait VmpPMatAllocImpl { /// * See [crate::api::VmpPMatAllocBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VmpPMatAllocBytesImpl { - fn vmp_pmat_alloc_bytes_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; + fn vmp_pmat_bytes_of_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) diff --git a/poulpy-hal/src/reference/fft64/vmp.rs b/poulpy-hal/src/reference/fft64/vmp.rs index f6fb73c..07e1a8d 100644 --- a/poulpy-hal/src/reference/fft64/vmp.rs +++ b/poulpy-hal/src/reference/fft64/vmp.rs @@ -140,7 +140,7 @@ where assert!(a.cols() <= cols); } - let (data, tmp_bytes) = tmp_bytes.split_at_mut(BE::vec_znx_dft_alloc_bytes_impl(n, cols, size)); + let (data, tmp_bytes) = tmp_bytes.split_at_mut(BE::vec_znx_dft_bytes_of_impl(n, cols, size)); let mut a_dft: VecZnxDft<&mut [u8], BE> = VecZnxDft::from_data(cast_mut(data), n, cols, size); diff --git a/poulpy-schemes/benches/circuit_bootstrapping.rs b/poulpy-schemes/benches/circuit_bootstrapping.rs index 47c848e..c2ce4ad 100644 --- a/poulpy-schemes/benches/circuit_bootstrapping.rs +++ b/poulpy-schemes/benches/circuit_bootstrapping.rs @@ -3,16 +3,16 @@ use std::hint::black_box; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; use poulpy_backend::{FFT64Avx, FFT64Ref, FFT64Spqlios}; use poulpy_core::layouts::{ - Dsize, GGLWEAutomorphismKeyLayout, GGLWETensorKeyLayout, GGSWCiphertext, GGSWCiphertextLayout, GLWESecret, LWECiphertext, - LWECiphertextLayout, LWESecret, prepared::PrepareAlloc, + AutomorphismKeyLayout, Dsize, GGSW, GGSWLayout, GLWESecret, LWE, LWELayout, LWESecret, TensorKeyLayout, + prepared::PrepareAlloc, }; use poulpy_hal::{ api::{ ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, - SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, - VecZnxAutomorphismInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, - VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy, - VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, VecZnxFillUniform, + SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, + VecZnxAutomorphismInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAutomorphismInplace, + VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy, + VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftCopy, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, @@ -42,7 +42,7 @@ where + VecZnxFillUniform + VecZnxAddNormal + VecZnxNormalizeInplace - + VecZnxDftAllocBytes + + VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -55,7 +55,7 @@ where + VecZnxAddScalarInplace + VecZnxAutomorphism + VecZnxSwitchRing - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxIdftApplyTmpA + SvpApplyDftToDft + VecZnxBigAddInplace @@ -70,7 +70,7 @@ where + VmpApplyDftToDftTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + SvpPPolAllocBytes + + SvpPPolBytesOf + VecZnxRotateInplace + VecZnxBigAutomorphismInplace + VecZnxRshInplace @@ -80,7 +80,7 @@ where + VecZnxAutomorphismInplace + VecZnxBigSubSmallNegateInplace + VecZnxRotateInplaceTmpBytes - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxDftAddInplace + VecZnxRotate + ZnFillUniform @@ -113,8 +113,8 @@ where extension_factor: usize, k_pt: usize, block_size: usize, - lwe_infos: LWECiphertextLayout, - ggsw_infos: GGSWCiphertextLayout, + lwe_infos: LWELayout, + ggsw_infos: GGSWLayout, cbt_infos: CircuitBootstrappingKeyLayout, } @@ -124,7 +124,7 @@ where + VecZnxFillUniform + VecZnxAddNormal + VecZnxNormalizeInplace - + VecZnxDftAllocBytes + + VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -137,7 +137,7 @@ where + VecZnxAddScalarInplace + VecZnxAutomorphism + VecZnxSwitchRing - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxIdftApplyTmpA + SvpApplyDftToDft + VecZnxBigAddInplace @@ -152,7 +152,7 @@ where + VmpApplyDftToDftTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + SvpPPolAllocBytes + + SvpPPolBytesOf + VecZnxRotateInplace + VecZnxBigAutomorphismInplace + VecZnxRshInplace @@ -162,7 +162,7 @@ where + VecZnxAutomorphismInplace + VecZnxBigSubSmallNegateInplace + VecZnxRotateInplaceTmpBytes - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxDftAddInplace + VecZnxRotate + ZnFillUniform @@ -188,8 +188,8 @@ where // Scratch space (4MB) let mut scratch: ScratchOwned = ScratchOwned::alloc(1 << 22); - let n_glwe: poulpy_core::layouts::Degree = params.cbt_infos.layout_brk.n_glwe(); - let n_lwe: poulpy_core::layouts::Degree = params.cbt_infos.layout_brk.n_lwe(); + let n_glwe: poulpy_core::layouts::RingDegree = params.cbt_infos.layout_brk.n_glwe(); + let n_lwe: poulpy_core::layouts::RingDegree = params.cbt_infos.layout_brk.n_lwe(); let rank: poulpy_core::layouts::Rank = params.cbt_infos.layout_brk.rank; let module: Module = Module::::new(n_glwe.as_u32() as u64); @@ -202,10 +202,10 @@ where sk_lwe.fill_binary_block(params.block_size, &mut source_xs); sk_lwe.fill_zero(); - let mut sk_glwe: GLWESecret> = GLWESecret::alloc_with(n_glwe, rank); + let mut sk_glwe: GLWESecret> = GLWESecret::alloc(n_glwe, rank); sk_glwe.fill_ternary_prob(0.5, &mut source_xs); - let ct_lwe: LWECiphertext> = LWECiphertext::alloc(¶ms.lwe_infos); + let ct_lwe: LWE> = LWE::alloc_from_infos(¶ms.lwe_infos); // Circuit bootstrapping evaluation key let cbt_key: CircuitBootstrappingKey, BRA> = CircuitBootstrappingKey::encrypt_sk( @@ -218,7 +218,7 @@ where scratch.borrow(), ); - let mut res: GGSWCiphertext> = GGSWCiphertext::alloc(¶ms.ggsw_infos); + let mut res: GGSW> = GGSW::alloc_from_infos(¶ms.ggsw_infos); let cbt_prepared: CircuitBootstrappingKeyPrepared, BRA, B> = cbt_key.prepare_alloc(&module, scratch.borrow()); move || { @@ -238,13 +238,13 @@ where name: String::from("1-bit"), extension_factor: 1, k_pt: 1, - lwe_infos: LWECiphertextLayout { + lwe_infos: LWELayout { n: 574_u32.into(), k: 13_u32.into(), base2k: 13_u32.into(), }, block_size: 7, - ggsw_infos: GGSWCiphertextLayout { + ggsw_infos: GGSWLayout { n: 1024_u32.into(), base2k: 13_u32.into(), k: 26_u32.into(), @@ -261,7 +261,7 @@ where dnum: 3_u32.into(), rank: 2_u32.into(), }, - layout_atk: GGLWEAutomorphismKeyLayout { + layout_atk: AutomorphismKeyLayout { n: 1024_u32.into(), base2k: 13_u32.into(), k: 52_u32.into(), @@ -269,7 +269,7 @@ where dsize: Dsize(1), rank: 2_u32.into(), }, - layout_tsk: GGLWETensorKeyLayout { + layout_tsk: TensorKeyLayout { n: 1024_u32.into(), base2k: 13_u32.into(), k: 52_u32.into(), diff --git a/poulpy-schemes/examples/circuit_bootstrapping.rs b/poulpy-schemes/examples/circuit_bootstrapping.rs index 4fec699..3383114 100644 --- a/poulpy-schemes/examples/circuit_bootstrapping.rs +++ b/poulpy-schemes/examples/circuit_bootstrapping.rs @@ -1,9 +1,9 @@ use poulpy_core::{ GLWEOperations, layouts::{ - GGLWEAutomorphismKeyLayout, GGLWETensorKeyLayout, GGSWCiphertext, GGSWCiphertextLayout, GLWECiphertext, - GLWECiphertextLayout, GLWEPlaintext, GLWESecret, LWECiphertext, LWECiphertextLayout, LWEInfos, LWEPlaintext, LWESecret, - prepared::{GGSWCiphertextPrepared, GLWESecretPrepared, PrepareAlloc}, + AutomorphismKeyLayout, GGSW, GGSWLayout, GLWE, GLWELayout, GLWEPlaintext, GLWESecret, LWE, LWEInfos, LWELayout, + LWEPlaintext, LWESecret, TensorKeyLayout, + prepared::{GGSWPrepared, GLWESecretPrepared, PrepareAlloc}, }, }; use std::time::Instant; @@ -89,7 +89,7 @@ fn main() { dnum: rows_brk.into(), rank: rank.into(), }, - layout_atk: GGLWEAutomorphismKeyLayout { + layout_atk: AutomorphismKeyLayout { n: n_glwe.into(), base2k: base2k.into(), k: k_trace.into(), @@ -97,7 +97,7 @@ fn main() { dsize: 1_u32.into(), rank: rank.into(), }, - layout_tsk: GGLWETensorKeyLayout { + layout_tsk: TensorKeyLayout { n: n_glwe.into(), base2k: base2k.into(), k: k_tsk.into(), @@ -107,7 +107,7 @@ fn main() { }, }; - let ggsw_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + let ggsw_infos: GGSWLayout = GGSWLayout { n: n_glwe.into(), base2k: base2k.into(), k: k_ggsw_res.into(), @@ -116,7 +116,7 @@ fn main() { rank: rank.into(), }; - let lwe_infos = LWECiphertextLayout { + let lwe_infos = LWELayout { n: n_lwe.into(), k: k_lwe_ct.into(), base2k: base2k.into(), @@ -140,7 +140,7 @@ fn main() { sk_lwe.fill_zero(); // GLWE secret - let mut sk_glwe: GLWESecret> = GLWESecret::alloc_with(n_glwe.into(), rank.into()); + let mut sk_glwe: GLWESecret> = GLWESecret::alloc(n_glwe.into(), rank.into()); sk_glwe.fill_ternary_prob(0.5, &mut source_xs); // sk_glwe.fill_zero(); @@ -151,7 +151,7 @@ fn main() { let data: i64 = 1 % (1 << k_lwe_pt); // LWE plaintext - let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc_with(base2k.into(), k_lwe_pt.into()); + let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc(base2k.into(), k_lwe_pt.into()); // LWE plaintext(data * 2^{- (k_lwe_pt - 1)}) pt_lwe.encode_i64(data, (k_lwe_pt + 1).into()); // +1 for padding bit @@ -167,7 +167,7 @@ fn main() { println!("pt_lwe: {pt_lwe}"); // LWE ciphertext - let mut ct_lwe: LWECiphertext> = LWECiphertext::alloc(&lwe_infos); + let mut ct_lwe: LWE> = LWE::alloc_from_infos(&lwe_infos); // Encrypt LWE Plaintext ct_lwe.encrypt_sk(&module, &pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe); @@ -187,7 +187,7 @@ fn main() { println!("CBT-KGEN: {} ms", now.elapsed().as_millis()); // Output GGSW - let mut res: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_infos); + let mut res: GGSW> = GGSW::alloc_from_infos(&ggsw_infos); // Circuit bootstrapping key prepared (opaque backend dependant write only struct) let cbt_prepared: CircuitBootstrappingKeyPrepared, CGGI, BackendImpl> = @@ -214,7 +214,7 @@ fn main() { // Tests RLWE(1) * GGSW(data) - let glwe_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_infos: GLWELayout = GLWELayout { n: n_glwe.into(), base2k: base2k.into(), k: (k_ggsw_res - base2k).into(), @@ -222,11 +222,11 @@ fn main() { }; // GLWE ciphertext modulus - let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&glwe_infos); + let mut ct_glwe: GLWE> = GLWE::alloc_from_infos(&glwe_infos); // Some GLWE plaintext with signed data let k_glwe_pt: usize = 3; - let mut pt_glwe: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_infos); + let mut pt_glwe: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_infos); let mut data_vec: Vec = vec![0i64; n_glwe]; data_vec .iter_mut() @@ -249,13 +249,13 @@ fn main() { ); // Prepare GGSW output of circuit bootstrapping (opaque backend dependant write only struct) - let res_prepared: GGSWCiphertextPrepared, BackendImpl> = res.prepare_alloc(&module, scratch.borrow()); + let res_prepared: GGSWPrepared, BackendImpl> = res.prepare_alloc(&module, scratch.borrow()); // Apply GLWE x GGSW ct_glwe.external_product_inplace(&module, &res_prepared, scratch.borrow()); // Decrypt - let mut pt_res: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_infos); + let mut pt_res: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_infos); ct_glwe.decrypt(&module, &mut pt_res, &sk_glwe_prepared, scratch.borrow()); println!("pt_res: {:?}", &pt_res.data.at(0, 0)[..64]); diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs index 938c45e..ba0d38f 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs @@ -1,5 +1,5 @@ use itertools::Itertools; -use poulpy_core::layouts::prepared::GGSWCiphertextPreparedToRef; +use poulpy_core::layouts::prepared::GGSWPreparedToRef; use poulpy_hal::layouts::{Backend, DataMut, DataRef, Module, Scratch}; use crate::tfhe::bdd_arithmetic::{ @@ -60,15 +60,11 @@ pub fn eval_bdd_2w_to_1w> = a + let inputs: Vec<&dyn GGSWPreparedToRef> = a .blocks .iter() - .map(|x| x as &dyn GGSWCiphertextPreparedToRef) - .chain( - b.blocks - .iter() - .map(|x| x as &dyn GGSWCiphertextPreparedToRef), - ) + .map(|x| x as &dyn GGSWPreparedToRef) + .chain(b.blocks.iter().map(|x| x as &dyn GGSWPreparedToRef)) .collect_vec(); // Evaluates out[i] = circuit[i](a, b) diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block.rs index 145ce6b..de4e30d 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block.rs @@ -1,9 +1,9 @@ use std::marker::PhantomData; -use poulpy_core::layouts::{Base2K, GLWECiphertext, GLWEInfos, GLWEPlaintextLayout, LWEInfos, Rank, TorusPrecision}; +use poulpy_core::layouts::{Base2K, GLWE, GLWEInfos, GLWEPlaintextLayout, LWEInfos, Rank, TorusPrecision}; -use poulpy_core::{TakeGLWEPt, layouts::prepared::GLWESecretPrepared}; -use poulpy_hal::api::VecZnxBigAllocBytes; +use poulpy_core::{TakeGLWEPlaintext, layouts::prepared::GLWESecretPrepared}; +use poulpy_hal::api::VecZnxBigBytesOf; #[cfg(test)] use poulpy_hal::api::{ ScratchAvailable, TakeVecZnx, VecZnxAddInplace, VecZnxAddNormal, VecZnxFillUniform, VecZnxNormalize, VecZnxSub, @@ -12,8 +12,8 @@ use poulpy_hal::api::{ use poulpy_hal::source::Source; use poulpy_hal::{ api::{ - TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, - VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalizeTmpBytes, + TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxDftApply, + VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalizeTmpBytes, }, layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}, }; @@ -24,7 +24,7 @@ use crate::tfhe::bdd_arithmetic::{FromBits, ToBits, UnsignedInteger}; /// An FHE ciphertext encrypting the bits of an [UnsignedInteger]. pub struct FheUintBlocks { - pub(crate) blocks: Vec>, + pub(crate) blocks: Vec>, pub(crate) _base: u8, pub(crate) _phantom: PhantomData, } @@ -38,7 +38,7 @@ impl LWEInfos for FheUintBlocks { self.blocks[0].k() } - fn n(&self) -> poulpy_core::layouts::Degree { + fn n(&self) -> poulpy_core::layouts::RingDegree { self.blocks[0].n() } } @@ -62,7 +62,7 @@ impl FheUintBlocks, T> { pub(crate) fn alloc_with(module: &Module, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self { Self { blocks: (0..T::WORD_SIZE) - .map(|_| GLWECiphertext::alloc_with(module.n().into(), base2k, k, rank)) + .map(|_| GLWE::alloc(module.n().into(), base2k, k, rank)) .collect(), _base: 1, _phantom: PhantomData, @@ -83,7 +83,7 @@ impl FheUintBlocks { scratch: &mut Scratch, ) where S: DataRef, - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -96,7 +96,7 @@ impl FheUintBlocks { + VecZnxAddNormal + VecZnxNormalize + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeGLWEPt, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeGLWEPlaintext, { use poulpy_core::layouts::GLWEPlaintextLayout; @@ -136,7 +136,7 @@ impl FheUintBlocks { + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxBigNormalize, - Scratch: TakeVecZnxDft + TakeVecZnxBig + TakeGLWEPt, + Scratch: TakeVecZnxDft + TakeVecZnxBig + TakeGLWEPlaintext, { #[cfg(debug_assertions)] { @@ -175,8 +175,8 @@ impl FheUintBlocks { scratch: &mut Scratch, ) -> Vec where - Module: VecZnxDftAllocBytes - + VecZnxBigAllocBytes + Module: VecZnxDftBytesOf + + VecZnxBigBytesOf + VecZnxDftApply + SvpApplyDftToDftInplace + VecZnxIdftApplyConsume @@ -186,7 +186,7 @@ impl FheUintBlocks { + VecZnxNormalizeTmpBytes + VecZnxSubInplace + VecZnxNormalizeInplace, - Scratch: TakeGLWEPt + TakeVecZnxDft + TakeVecZnxBig, + Scratch: TakeGLWEPlaintext + TakeVecZnxDft + TakeVecZnxBig, { #[cfg(debug_assertions)] { diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block_prepared.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block_prepared.rs index 3cd50a6..2f4bc44 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block_prepared.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block_prepared.rs @@ -1,12 +1,10 @@ use std::marker::PhantomData; -use poulpy_core::layouts::{ - Base2K, Dnum, Dsize, GGSWInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision, prepared::GGSWCiphertextPrepared, -}; +use poulpy_core::layouts::{Base2K, Dnum, Dsize, GGSWInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision, prepared::GGSWPrepared}; #[cfg(test)] use poulpy_core::{ TakeGGSW, - layouts::{GGSWCiphertext, prepared::GLWESecretPrepared}, + layouts::{GGSW, prepared::GLWESecretPrepared}, }; use poulpy_hal::{ api::VmpPMatAlloc, @@ -16,8 +14,8 @@ use poulpy_hal::{ use poulpy_hal::{ api::{ ScratchAvailable, SvpApplyDftToDftInplace, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, - VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, - VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, + VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigBytesOf, + VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAlloc, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VmpPrepare, }, @@ -29,7 +27,7 @@ use crate::tfhe::bdd_arithmetic::{FheUintBlocks, FheUintPrepare, ToBits, Unsigne #[cfg(test)] pub(crate) struct FheUintBlocksPrepDebug { - pub(crate) blocks: Vec>, + pub(crate) blocks: Vec>, pub(crate) _base: u8, pub(crate) _phantom: PhantomData, } @@ -62,7 +60,7 @@ impl FheUintBlocksPrepDebug, T> { ) -> Self { Self { blocks: (0..T::WORD_SIZE) - .map(|_| GGSWCiphertext::alloc_with(module.n().into(), base2k, k, rank, dnum, dsize)) + .map(|_| GGSW::alloc(module.n().into(), base2k, k, rank, dnum, dsize)) .collect(), _base: 1, _phantom: PhantomData, @@ -72,7 +70,7 @@ impl FheUintBlocksPrepDebug, T> { /// A prepared FHE ciphertext encrypting the bits of an [UnsignedInteger]. pub struct FheUintBlocksPrep { - pub(crate) blocks: Vec>, + pub(crate) blocks: Vec>, pub(crate) _base: u8, pub(crate) _phantom: PhantomData, } @@ -103,7 +101,7 @@ where { Self { blocks: (0..T::WORD_SIZE) - .map(|_| GGSWCiphertextPrepared::alloc_with(module, base2k, k, dnum, dsize, rank)) + .map(|_| GGSWPrepared::alloc(module, base2k, k, dnum, dsize, rank)) .collect(), _base: 1, _phantom: PhantomData, @@ -125,7 +123,7 @@ impl FheUintBlocksPrep: VecZnxAddScalarInplace - + VecZnxDftAllocBytes + + VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -192,8 +190,8 @@ impl FheUintBlocksPrepDebug { #[allow(dead_code)] pub(crate) fn noise(&self, module: &Module, sk: &GLWESecretPrepared, want: T) where - Module: VecZnxDftAllocBytes - + VecZnxBigAllocBytes + Module: VecZnxDftBytesOf + + VecZnxBigBytesOf + VecZnxDftApply + SvpApplyDftToDftInplace + VecZnxIdftApplyConsume @@ -227,7 +225,7 @@ impl LWEInfos for FheUintBlocksPrep< self.blocks[0].k() } - fn n(&self) -> poulpy_core::layouts::Degree { + fn n(&self) -> poulpy_core::layouts::RingDegree { self.blocks[0].n() } } @@ -258,7 +256,7 @@ impl LWEInfos for FheUintBlocksPrepDebug { self.blocks[0].k() } - fn n(&self) -> poulpy_core::layouts::Degree { + fn n(&self) -> poulpy_core::layouts::RingDegree { self.blocks[0].n() } } diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/word.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/word.rs index cc754bc..38ea57d 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/word.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/word.rs @@ -1,9 +1,9 @@ use itertools::Itertools; use poulpy_core::{ - GLWEOperations, TakeGLWECtSlice, TakeGLWEPt, glwe_packing, + GLWEOperations, TakeGLWEPlaintext, TakeGLWESlice, glwe_packing, layouts::{ - GLWECiphertext, GLWEInfos, GLWEPlaintextLayout, LWEInfos, TorusPrecision, - prepared::{GGLWEAutomorphismKeyPrepared, GLWESecretPrepared}, + GLWE, GLWEInfos, GLWEPlaintextLayout, LWEInfos, TorusPrecision, + prepared::{AutomorphismKeyPrepared, GLWESecretPrepared}, }, }; use poulpy_hal::{ @@ -11,7 +11,7 @@ use poulpy_hal::{ ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, + VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftCopy, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, @@ -24,22 +24,22 @@ use std::{collections::HashMap, marker::PhantomData}; use crate::tfhe::bdd_arithmetic::{FromBits, ToBits, UnsignedInteger}; /// A FHE ciphertext encrypting a [UnsignedInteger]. -pub struct FheUintWord(pub(crate) GLWECiphertext, pub(crate) PhantomData); +pub struct FheUintWord(pub(crate) GLWE, pub(crate) PhantomData); impl FheUintWord { #[allow(dead_code)] fn post_process( &mut self, module: &Module, - mut tmp_res: Vec>, - auto_keys: &HashMap>, + mut tmp_res: Vec>, + auto_keys: &HashMap>, scratch: &mut Scratch, ) where ATK: DataRef, Module: VecZnxSub + VecZnxCopy + VecZnxNegateInplace - + VecZnxDftAllocBytes + + VecZnxDftBytesOf + VecZnxAddInplace + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes @@ -62,12 +62,12 @@ impl FheUintWord { + VecZnxAutomorphismInplace + VecZnxBigSubSmallNegateInplace + VecZnxRotate, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeGLWECtSlice, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeGLWESlice, { // Repacks the GLWE ciphertexts bits let gap: usize = module.n() / T::WORD_SIZE; let log_gap: usize = (usize::BITS - (gap - 1).leading_zeros()) as usize; - let mut cts: HashMap> = HashMap::new(); + let mut cts: HashMap> = HashMap::new(); for (i, ct) in tmp_res.iter_mut().enumerate().take(T::WORD_SIZE) { cts.insert(i * gap, ct); } @@ -87,7 +87,7 @@ impl LWEInfos for FheUintWord { self.0.k() } - fn n(&self) -> poulpy_core::layouts::Degree { + fn n(&self) -> poulpy_core::layouts::RingDegree { self.0.n() } } @@ -109,7 +109,7 @@ impl FheUintWord { scratch: &mut Scratch, ) where Module: VecZnxAddScalarInplace - + VecZnxDftAllocBytes + + VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -122,7 +122,7 @@ impl FheUintWord { + VecZnxAddNormal + VecZnxNormalize + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeGLWEPt, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeGLWEPlaintext, { #[cfg(debug_assertions)] { @@ -167,7 +167,7 @@ impl FheUintWord { + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxBigNormalize, - Scratch: TakeVecZnxDft + TakeVecZnxBig + TakeGLWEPt, + Scratch: TakeVecZnxDft + TakeVecZnxBig + TakeGLWEPlaintext, { #[cfg(debug_assertions)] { diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs index 98e1083..245fe03 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs @@ -1,9 +1,9 @@ use itertools::Itertools; use poulpy_core::{ - GLWEExternalProductInplace, GLWEOperations, TakeGLWECtSlice, + GLWEExternalProductInplace, GLWEOperations, TakeGLWESlice, layouts::{ - GLWECiphertext, GLWECiphertextToMut, LWEInfos, - prepared::{GGSWCiphertextPrepared, GGSWCiphertextPreparedToRef}, + GLWE, GLWEToMut, LWEInfos, + prepared::{GGSWPrepared, GGSWPreparedToRef}, }, }; use poulpy_hal::{ @@ -38,8 +38,8 @@ where fn execute( &self, module: &Module, - out: &mut [GLWECiphertext], - inputs: &[&dyn GGSWCiphertextPreparedToRef], + out: &mut [GLWE], + inputs: &[&dyn GGSWPreparedToRef], scratch: &mut Scratch, ) where O: DataMut; @@ -49,13 +49,13 @@ impl Circuit where Self: GetBitCircuitInfo, Module: Cmux + VecZnxCopy, - Scratch: TakeGLWECtSlice, + Scratch: TakeGLWESlice, { fn execute( &self, module: &Module, - out: &mut [GLWECiphertext], - inputs: &[&dyn GGSWCiphertextPreparedToRef], + out: &mut [GLWE], + inputs: &[&dyn GGSWPreparedToRef], scratch: &mut Scratch, ) where O: DataMut, @@ -159,14 +159,8 @@ impl Node { } pub trait Cmux { - fn cmux( - &self, - out: &mut GLWECiphertext, - t: &GLWECiphertext, - f: &GLWECiphertext, - s: &GGSWCiphertextPrepared, - scratch: &mut Scratch, - ) where + fn cmux(&self, out: &mut GLWE, t: &GLWE, f: &GLWE, s: &GGSWPrepared, scratch: &mut Scratch) + where O: DataMut, T: DataRef, F: DataRef, @@ -177,14 +171,8 @@ impl Cmux for Module where Module: GLWEExternalProductInplace + VecZnxSub + VecZnxCopy + VecZnxNegateInplace + VecZnxAddInplace, { - fn cmux( - &self, - out: &mut GLWECiphertext, - t: &GLWECiphertext, - f: &GLWECiphertext, - s: &GGSWCiphertextPrepared, - scratch: &mut Scratch, - ) where + fn cmux(&self, out: &mut GLWE, t: &GLWE, f: &GLWE, s: &GGSWPrepared, scratch: &mut Scratch) + where O: DataMut, T: DataRef, F: DataRef, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs index febabed..fdcf5b3 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs @@ -9,18 +9,18 @@ use crate::tfhe::{ }, }; use poulpy_core::{ - TakeGGSW, TakeGLWECt, + TakeGGSW, TakeGLWE, layouts::{ - GLWESecret, GLWEToLWEKey, GLWEToLWEKeyLayout, LWECiphertext, LWESecret, + GLWESecret, GLWEToLWEKeyLayout, GLWEToLWESwitchingKey, LWE, LWESecret, prepared::{GLWEToLWESwitchingKeyPrepared, Prepare, PrepareAlloc}, }, }; use poulpy_hal::{ api::{ - ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, + ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, TakeScalarZnx, TakeSvpPPol, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, + VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPrepare, }, @@ -56,7 +56,7 @@ where BRA: BlindRotationAlgo, { cbt: CircuitBootstrappingKey, - ks: GLWEToLWEKey, + ks: GLWEToLWESwitchingKey, } impl BDDKey, Vec, BRA> { @@ -77,7 +77,7 @@ impl BDDKey, Vec, BRA> { Module: SvpApplyDftToDft + VecZnxIdftApplyTmpA + VecZnxAddScalarInplace - + VecZnxDftAllocBytes + + VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -92,13 +92,13 @@ impl BDDKey, Vec, BRA> { + VecZnxSub + SvpPrepare + VecZnxSwitchRing - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc + VecZnxAutomorphism + VecZnxAutomorphismInplace, Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeSvpPPol + TakeVecZnxBig, { - let mut ks: GLWEToLWEKey> = GLWEToLWEKey::alloc(&infos.ks_infos()); + let mut ks: GLWEToLWESwitchingKey> = GLWEToLWESwitchingKey::alloc(&infos.ks_infos()); ks.encrypt_sk(module, sk_lwe, sk_glwe, source_xa, source_xe, scratch); Self { @@ -131,7 +131,7 @@ impl PrepareAll for BDDKey where CircuitBootstrappingKey: PrepareAlloc>, - GLWEToLWEKey: PrepareAlloc>, + GLWEToLWESwitchingKey: PrepareAlloc>, { fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> BDDKeyPrepared { BDDKeyPrepared { @@ -157,7 +157,7 @@ where BE: Backend, Module: VmpPrepare + VecZnxRotate - + VecZnxDftAllocBytes + + VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft @@ -168,7 +168,7 @@ where + VecZnxBigNormalize + VecZnxNormalize + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeGLWECt + TakeVecZnx + TakeGGSW, + Scratch: ScratchAvailable + TakeVecZnxDft + TakeGLWE + TakeVecZnx + TakeGGSW, CircuitBootstrappingKeyPrepared: CirtuitBootstrappingExecute, { fn prepare( @@ -182,7 +182,7 @@ where { assert_eq!(out.blocks.len(), bits.blocks.len()); } - let mut lwe: LWECiphertext> = LWECiphertext::alloc(&bits.blocks[0]); //TODO: add TakeLWE + let mut lwe: LWE> = LWE::alloc(&bits.blocks[0]); //TODO: add TakeLWE let (mut tmp_ggsw, scratch_1) = scratch.take_ggsw(out); for (dst, src) in out.blocks.iter_mut().zip(bits.blocks.iter()) { lwe.from_glwe(module, src, &self.ks, scratch_1); @@ -206,7 +206,7 @@ where BE: Backend, Module: VmpPrepare + VecZnxRotate - + VecZnxDftAllocBytes + + VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft @@ -217,7 +217,7 @@ where + VecZnxBigNormalize + VecZnxNormalize + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeGLWECt + TakeVecZnx + TakeGGSW, + Scratch: ScratchAvailable + TakeVecZnxDft + TakeGLWE + TakeVecZnx + TakeGGSW, CircuitBootstrappingKeyPrepared: CirtuitBootstrappingExecute, { fn prepare( @@ -231,7 +231,7 @@ where { assert_eq!(out.blocks.len(), bits.blocks.len()); } - let mut lwe: LWECiphertext> = LWECiphertext::alloc(&bits.blocks[0]); //TODO: add TakeLWE + let mut lwe: LWE> = LWE::alloc(&bits.blocks[0]); //TODO: add TakeLWE for (dst, src) in out.blocks.iter_mut().zip(bits.blocks.iter()) { lwe.from_glwe(module, src, &self.ks, scratch); self.cbt diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/parameters.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/parameters.rs index 6b56f79..924d53a 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/parameters.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/parameters.rs @@ -1,7 +1,7 @@ #[cfg(test)] use poulpy_core::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEAutomorphismKeyLayout, GGLWETensorKeyLayout, GGSWCiphertextLayout, GLWECiphertextLayout, - GLWEToLWEKeyLayout, Rank, TorusPrecision, + AutomorphismKeyLayout, Base2K, Dnum, Dsize, GGSWLayout, GLWELayout, GLWEToLWEKeyLayout, Rank, RingDegree, TensorKeyLayout, + TorusPrecision, }; #[cfg(test)] @@ -25,16 +25,16 @@ pub(crate) const TEST_BLOCK_SIZE: u32 = 7; pub(crate) const TEST_RANK: u32 = 2; #[cfg(test)] -pub(crate) static TEST_GLWE_INFOS: GLWECiphertextLayout = GLWECiphertextLayout { - n: Degree(TEST_N_GLWE), +pub(crate) static TEST_GLWE_INFOS: GLWELayout = GLWELayout { + n: RingDegree(TEST_N_GLWE), base2k: Base2K(TEST_BASE2K), k: TorusPrecision(TEST_K_GLWE), rank: Rank(TEST_RANK), }; #[cfg(test)] -pub(crate) static TEST_GGSW_INFOS: GGSWCiphertextLayout = GGSWCiphertextLayout { - n: Degree(TEST_N_GLWE), +pub(crate) static TEST_GGSW_INFOS: GGSWLayout = GGSWLayout { + n: RingDegree(TEST_N_GLWE), base2k: Base2K(TEST_BASE2K), k: TorusPrecision(TEST_K_GGSW), rank: Rank(TEST_RANK), @@ -46,23 +46,23 @@ pub(crate) static TEST_GGSW_INFOS: GGSWCiphertextLayout = GGSWCiphertextLayout { pub(crate) static TEST_BDD_KEY_LAYOUT: BDDKeyLayout = BDDKeyLayout { cbt: CircuitBootstrappingKeyLayout { layout_brk: BlindRotationKeyLayout { - n_glwe: Degree(TEST_N_GLWE), - n_lwe: Degree(TEST_N_LWE), + n_glwe: RingDegree(TEST_N_GLWE), + n_lwe: RingDegree(TEST_N_LWE), base2k: Base2K(TEST_BASE2K), k: TorusPrecision(52), dnum: Dnum(3), rank: Rank(TEST_RANK), }, - layout_atk: GGLWEAutomorphismKeyLayout { - n: Degree(TEST_N_GLWE), + layout_atk: AutomorphismKeyLayout { + n: RingDegree(TEST_N_GLWE), base2k: Base2K(TEST_BASE2K), k: TorusPrecision(52), rank: Rank(TEST_RANK), dnum: Dnum(3), dsize: Dsize(1), }, - layout_tsk: GGLWETensorKeyLayout { - n: Degree(TEST_N_GLWE), + layout_tsk: TensorKeyLayout { + n: RingDegree(TEST_N_GLWE), base2k: Base2K(TEST_BASE2K), k: TorusPrecision(52), rank: Rank(TEST_RANK), @@ -71,7 +71,7 @@ pub(crate) static TEST_BDD_KEY_LAYOUT: BDDKeyLayout = BDDKeyLayout { }, }, ks: GLWEToLWEKeyLayout { - n: Degree(TEST_N_GLWE), + n: RingDegree(TEST_N_GLWE), base2k: Base2K(TEST_BASE2K), k: TorusPrecision(39), rank_in: Rank(TEST_RANK), diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/test.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/test.rs index 1889e02..1c2e69d 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/test.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/test.rs @@ -2,20 +2,20 @@ use std::time::Instant; use poulpy_backend::FFT64Ref; use poulpy_core::{ - TakeGGSW, TakeGLWEPt, + TakeGGSW, TakeGLWEPlaintext, layouts::{ - GGSWCiphertextLayout, GLWECiphertextLayout, GLWESecret, LWEInfos, LWESecret, + GGSWLayout, GLWELayout, GLWESecret, LWEInfos, LWESecret, prepared::{GLWESecretPrepared, PrepareAlloc}, }, }; use poulpy_hal::{ api::{ ModuleNew, ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, - SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeSlice, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, + SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, TakeScalarZnx, TakeSlice, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, - VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigAutomorphismInplace, + VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAutomorphismInplace, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy, VecZnxDftAddInplace, - VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, VecZnxFillUniform, VecZnxIdftApplyConsume, + VecZnxDftAlloc, VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftCopy, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, ZnAddNormal, ZnFillUniform, @@ -51,7 +51,7 @@ where Module: ModuleNew + SvpPPolAlloc + SvpPrepare + VmpPMatAlloc, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Module: VecZnxAddScalarInplace - + VecZnxDftAllocBytes + + VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -68,16 +68,16 @@ where Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeGGSW + TakeScalarZnx + TakeSlice, Module: VecZnxCopy + VecZnxNegateInplace + VmpApplyDftToDftTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd, Module: VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxBigNormalize, - Scratch: TakeVecZnxDft + TakeVecZnxBig + TakeGLWEPt, + Scratch: TakeVecZnxDft + TakeVecZnxBig + TakeGLWEPlaintext, Module: VecZnxAutomorphism + VecZnxSwitchRing - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxIdftApplyTmpA + SvpApplyDftToDft + VecZnxBigAlloc + VecZnxDftAlloc + VecZnxBigNormalizeTmpBytes - + SvpPPolAllocBytes + + SvpPPolBytesOf + VecZnxRotateInplace + VecZnxBigAutomorphismInplace + VecZnxRshInplace @@ -85,7 +85,7 @@ where + VecZnxAutomorphismInplace + VecZnxBigSubSmallNegateInplace + VecZnxRotateInplaceTmpBytes - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxDftAddInplace + VecZnxRotate + ZnFillUniform @@ -107,8 +107,8 @@ where BlindRotationKeyPrepared, BRA, BE>: BlincRotationExecute, BlindRotationKey, BRA>: BlindRotationKeyAlloc + BlindRotationKeyEncryptSk, { - let glwe_infos: GLWECiphertextLayout = TEST_GLWE_INFOS; - let ggsw_infos: GGSWCiphertextLayout = TEST_GGSW_INFOS; + let glwe_infos: GLWELayout = TEST_GLWE_INFOS; + let ggsw_infos: GGSWLayout = TEST_GGSW_INFOS; let n_glwe: usize = glwe_infos.n().into(); @@ -120,7 +120,7 @@ where let mut scratch: ScratchOwned = ScratchOwned::alloc(1 << 22); - let mut sk_glwe: GLWESecret> = GLWESecret::alloc(&glwe_infos); + let mut sk_glwe: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_infos); sk_glwe.fill_ternary_prob(0.5, &mut source_xs); let sk_glwe_prep: GLWESecretPrepared, BE> = sk_glwe.prepare_alloc(&module, scratch.borrow()); diff --git a/poulpy-schemes/src/tfhe/blind_rotation/cggi_algo.rs b/poulpy-schemes/src/tfhe/blind_rotation/cggi_algo.rs index 03f36e9..8bf0a9e 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/cggi_algo.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/cggi_algo.rs @@ -1,9 +1,9 @@ use itertools::izip; use poulpy_hal::{ api::{ - ScratchAvailable, SvpApplyDftToDft, SvpPPolAllocBytes, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, - TakeVecZnxSlice, VecZnxAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftApply, + ScratchAvailable, SvpApplyDftToDft, SvpPPolBytesOf, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, + TakeVecZnxSlice, VecZnxAddInplace, VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize, + VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftSubInplace, VecZnxDftZero, VecZnxIdftApply, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpBytes, VecZnxMulXpMinusOneInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxSubInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, @@ -12,8 +12,8 @@ use poulpy_hal::{ }; use poulpy_core::{ - Distribution, GLWEOperations, TakeGLWECt, - layouts::{GGSWInfos, GLWECiphertext, GLWECiphertextToMut, GLWEInfos, LWECiphertext, LWECiphertextToRef, LWEInfos}, + Distribution, GLWEOperations, TakeGLWE, + layouts::{GGSWInfos, GLWE, GLWEInfos, GLWEToMut, LWE, LWEInfos, LWEToRef}, }; use crate::tfhe::blind_rotation::{ @@ -21,7 +21,7 @@ use crate::tfhe::blind_rotation::{ }; #[allow(clippy::too_many_arguments)] -pub fn cggi_blind_rotate_scratch_space( +pub fn cggi_blind_rotate_tmp_bytes( module: &Module, block_size: usize, extension_factor: usize, @@ -31,10 +31,10 @@ pub fn cggi_blind_rotate_scratch_space( where OUT: GLWEInfos, GGSW: GGSWInfos, - Module: VecZnxDftAllocBytes + Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxIdftApplyTmpBytes + VecZnxBigNormalizeTmpBytes, { @@ -43,14 +43,14 @@ where if block_size > 1 { let cols: usize = (brk_infos.rank() + 1).into(); let dnum: usize = brk_infos.dnum().into(); - let acc_dft: usize = module.vec_znx_dft_alloc_bytes(cols, dnum) * extension_factor; - let acc_big: usize = module.vec_znx_big_alloc_bytes(1, brk_size); - let vmp_res: usize = module.vec_znx_dft_alloc_bytes(cols, brk_size) * extension_factor; - let vmp_xai: usize = module.vec_znx_dft_alloc_bytes(1, brk_size); + let acc_dft: usize = module.bytes_of_vec_znx_dft(cols, dnum) * extension_factor; + let acc_big: usize = module.bytes_of_vec_znx_big(1, brk_size); + let vmp_res: usize = module.bytes_of_vec_znx_dft(cols, brk_size) * extension_factor; + let vmp_xai: usize = module.bytes_of_vec_znx_dft(1, brk_size); let acc_dft_add: usize = vmp_res; let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes(brk_size, dnum, dnum, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2) let acc: usize = if extension_factor > 1 { - VecZnx::alloc_bytes(module.n(), cols, glwe_infos.size()) * extension_factor + VecZnx::bytes_of(module.n(), cols, glwe_infos.size()) * extension_factor } else { 0 }; @@ -61,16 +61,15 @@ where + vmp_xai + (vmp | (acc_big + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_apply_tmp_bytes()))) } else { - GLWECiphertext::alloc_bytes(glwe_infos) - + GLWECiphertext::external_product_inplace_scratch_space(module, glwe_infos, brk_infos) + GLWE::bytes_of(glwe_infos) + GLWE::external_product_inplace_tmp_bytes(module, glwe_infos, brk_infos) } } impl BlincRotationExecute for BlindRotationKeyPrepared where - Module: VecZnxBigAllocBytes - + VecZnxDftAllocBytes - + SvpPPolAllocBytes + Module: VecZnxBigBytesOf + + VecZnxDftBytesOf + + SvpPPolBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxIdftApplyTmpBytes @@ -99,8 +98,8 @@ where fn execute( &self, module: &Module, - res: &mut GLWECiphertext, - lwe: &LWECiphertext, + res: &mut GLWE, + lwe: &LWE, lut: &LookUpTable, scratch: &mut Scratch, ) { @@ -121,8 +120,8 @@ where fn execute_block_binary_extended( module: &Module, - res: &mut GLWECiphertext, - lwe: &LWECiphertext, + res: &mut GLWE, + lwe: &LWE, lut: &LookUpTable, brk: &BlindRotationKeyPrepared, scratch: &mut Scratch, @@ -130,9 +129,9 @@ fn execute_block_binary_extended( DataRes: DataMut, DataIn: DataRef, DataBrk: DataRef, - Module: VecZnxBigAllocBytes - + VecZnxDftAllocBytes - + SvpPPolAllocBytes + Module: VecZnxBigBytesOf + + VecZnxDftBytesOf + + SvpPPolBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxIdftApplyTmpBytes @@ -179,7 +178,7 @@ fn execute_block_binary_extended( } let mut lwe_2n: Vec = vec![0i64; (lwe.n() + 1).as_usize()]; // TODO: from scratch space - let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref(); + let lwe_ref: LWE<&[u8]> = lwe.to_ref(); let two_n: usize = 2 * n_glwe; let two_n_ext: usize = 2 * lut.domain_size(); @@ -288,8 +287,8 @@ fn execute_block_binary_extended( fn execute_block_binary( module: &Module, - res: &mut GLWECiphertext, - lwe: &LWECiphertext, + res: &mut GLWE, + lwe: &LWE, lut: &LookUpTable, brk: &BlindRotationKeyPrepared, scratch: &mut Scratch, @@ -297,9 +296,9 @@ fn execute_block_binary( DataRes: DataMut, DataIn: DataRef, DataBrk: DataRef, - Module: VecZnxBigAllocBytes - + VecZnxDftAllocBytes - + SvpPPolAllocBytes + Module: VecZnxBigBytesOf + + VecZnxDftBytesOf + + SvpPPolBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxIdftApplyTmpBytes @@ -324,8 +323,8 @@ fn execute_block_binary( { let n_glwe: usize = brk.n_glwe().into(); let mut lwe_2n: Vec = vec![0i64; (lwe.n() + 1).into()]; // TODO: from scratch space - let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut(); - let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref(); + let mut out_mut: GLWE<&mut [u8]> = res.to_mut(); + let lwe_ref: LWE<&[u8]> = lwe.to_ref(); let two_n: usize = n_glwe << 1; let base2k: usize = brk.base2k().into(); let dnum: usize = brk.dnum().into(); @@ -410,8 +409,8 @@ fn execute_block_binary( fn execute_standard( module: &Module, - res: &mut GLWECiphertext, - lwe: &LWECiphertext, + res: &mut GLWE, + lwe: &LWE, lut: &LookUpTable, brk: &BlindRotationKeyPrepared, scratch: &mut Scratch, @@ -419,9 +418,9 @@ fn execute_standard( DataRes: DataMut, DataIn: DataRef, DataBrk: DataRef, - Module: VecZnxBigAllocBytes - + VecZnxDftAllocBytes - + SvpPPolAllocBytes + Module: VecZnxBigBytesOf + + VecZnxDftBytesOf + + SvpPPolBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxIdftApplyTmpBytes @@ -480,8 +479,8 @@ fn execute_standard( } let mut lwe_2n: Vec = vec![0i64; (lwe.n() + 1).into()]; // TODO: from scratch space - let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut(); - let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref(); + let mut out_mut: GLWE<&mut [u8]> = res.to_mut(); + let lwe_ref: LWE<&[u8]> = lwe.to_ref(); mod_switch_2n( 2 * lut.domain_size(), @@ -519,7 +518,7 @@ fn execute_standard( out_mut.normalize_inplace(module, scratch_1); } -pub fn mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]>, rot_dir: LookUpTableRotationDirection) { +pub fn mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWE<&[u8]>, rot_dir: LookUpTableRotationDirection) { let base2k: usize = lwe.base2k().into(); let log2n: usize = usize::BITS as usize - (n - 1).leading_zeros() as usize + 1; diff --git a/poulpy-schemes/src/tfhe/blind_rotation/cggi_key.rs b/poulpy-schemes/src/tfhe/blind_rotation/cggi_key.rs index fbf506b..3114ea1 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/cggi_key.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/cggi_key.rs @@ -1,9 +1,8 @@ use poulpy_hal::{ api::{ ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, - VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, - VmpPMatAlloc, VmpPrepare, + VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, + VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VmpPMatAlloc, VmpPrepare, }, layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxView, ZnxViewMut}, source::Source, @@ -14,9 +13,9 @@ use std::marker::PhantomData; use poulpy_core::{ Distribution, layouts::{ - GGSWCiphertext, GGSWInfos, LWESecret, - compressed::GGSWCiphertextCompressed, - prepared::{GGSWCiphertextPrepared, GLWESecretPrepared}, + GGSW, GGSWInfos, LWESecret, + compressed::GGSWCompressed, + prepared::{GGSWPrepared, GLWESecretPrepared}, }, }; @@ -30,9 +29,9 @@ impl BlindRotationKeyAlloc for BlindRotationKey, CGGI> { where A: BlindRotationKeyInfos, { - let mut data: Vec>> = Vec::with_capacity(infos.n_lwe().into()); + let mut data: Vec>> = Vec::with_capacity(infos.n_lwe().into()); for _ in 0..infos.n_lwe().as_usize() { - data.push(GGSWCiphertext::alloc(infos)); + data.push(GGSW::alloc_from_infos(infos)); } Self { @@ -44,19 +43,19 @@ impl BlindRotationKeyAlloc for BlindRotationKey, CGGI> { } impl BlindRotationKey, CGGI> { - pub fn generate_from_sk_scratch_space(module: &Module, infos: &A) -> usize + pub fn generate_from_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGSWInfos, - Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes, + Module: VecZnxNormalizeTmpBytes + VecZnxDftBytesOf, { - GGSWCiphertext::encrypt_sk_scratch_space(module, infos) + GGSW::encrypt_sk_tmp_bytes(module, infos) } } impl BlindRotationKeyEncryptSk for BlindRotationKey where Module: VecZnxAddScalarInplace - + VecZnxDftAllocBytes + + VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -121,8 +120,8 @@ where where A: BlindRotationKeyInfos, { - let mut data: Vec, B>> = Vec::with_capacity(infos.n_lwe().into()); - (0..infos.n_lwe().as_usize()).for_each(|_| data.push(GGSWCiphertextPrepared::alloc(module, infos))); + let mut data: Vec, B>> = Vec::with_capacity(infos.n_lwe().into()); + (0..infos.n_lwe().as_usize()).for_each(|_| data.push(GGSWPrepared::alloc_from_infos(module, infos))); Self { data, dist: Distribution::NONE, @@ -137,8 +136,8 @@ impl BlindRotationKeyCompressed, CGGI> { where A: BlindRotationKeyInfos, { - let mut data: Vec>> = Vec::with_capacity(infos.n_lwe().into()); - (0..infos.n_lwe().as_usize()).for_each(|_| data.push(GGSWCiphertextCompressed::alloc(infos))); + let mut data: Vec>> = Vec::with_capacity(infos.n_lwe().into()); + (0..infos.n_lwe().as_usize()).for_each(|_| data.push(GGSWCompressed::alloc_from_infos(infos))); Self { keys: data, dist: Distribution::NONE, @@ -146,12 +145,12 @@ impl BlindRotationKeyCompressed, CGGI> { } } - pub fn generate_from_sk_scratch_space(module: &Module, infos: &A) -> usize + pub fn generate_from_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGSWInfos, - Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes, + Module: VecZnxNormalizeTmpBytes + VecZnxDftBytesOf, { - GGSWCiphertextCompressed::encrypt_sk_scratch_space(module, infos) + GGSWCompressed::encrypt_sk_tmp_bytes(module, infos) } } @@ -169,7 +168,7 @@ impl BlindRotationKeyCompressed { DataSkGLWE: DataRef, DataSkLWE: DataRef, Module: VecZnxAddScalarInplace - + VecZnxDftAllocBytes + + VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace diff --git a/poulpy-schemes/src/tfhe/blind_rotation/key.rs b/poulpy-schemes/src/tfhe/blind_rotation/key.rs index 86dbd21..276539f 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/key.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/key.rs @@ -8,7 +8,7 @@ use std::{fmt, marker::PhantomData}; use poulpy_core::{ Distribution, layouts::{ - Base2K, Degree, Dnum, Dsize, GGSWCiphertext, GGSWInfos, GLWEInfos, LWEInfos, LWESecret, Rank, TorusPrecision, + Base2K, Dnum, Dsize, GGSW, GGSWInfos, GLWEInfos, LWEInfos, LWESecret, Rank, RingDegree, TorusPrecision, prepared::GLWESecretPrepared, }, }; @@ -19,8 +19,8 @@ use crate::tfhe::blind_rotation::BlindRotationAlgo; #[derive(PartialEq, Eq, Copy, Clone, Debug)] pub struct BlindRotationKeyLayout { - pub n_glwe: Degree, - pub n_lwe: Degree, + pub n_glwe: RingDegree, + pub n_lwe: RingDegree, pub base2k: Base2K, pub k: TorusPrecision, pub dnum: Dnum, @@ -28,11 +28,11 @@ pub struct BlindRotationKeyLayout { } impl BlindRotationKeyInfos for BlindRotationKeyLayout { - fn n_glwe(&self) -> Degree { + fn n_glwe(&self) -> RingDegree { self.n_glwe } - fn n_lwe(&self) -> Degree { + fn n_lwe(&self) -> RingDegree { self.n_lwe } } @@ -62,7 +62,7 @@ impl LWEInfos for BlindRotationKeyLayout { self.k } - fn n(&self) -> Degree { + fn n(&self) -> RingDegree { self.n_glwe } } @@ -71,8 +71,8 @@ pub trait BlindRotationKeyInfos where Self: GGSWInfos, { - fn n_glwe(&self) -> Degree; - fn n_lwe(&self) -> Degree; + fn n_glwe(&self) -> RingDegree; + fn n_lwe(&self) -> RingDegree; } pub trait BlindRotationKeyAlloc { @@ -98,7 +98,7 @@ pub trait BlindRotationKeyEncryptSk { #[derive(Clone)] pub struct BlindRotationKey { - pub(crate) keys: Vec>, + pub(crate) keys: Vec>, pub(crate) dist: Distribution, pub(crate) _phantom: PhantomData, } @@ -178,12 +178,12 @@ impl WriterTo for BlindRotationKey { } impl BlindRotationKeyInfos for BlindRotationKey { - fn n_glwe(&self) -> Degree { + fn n_glwe(&self) -> RingDegree { self.n() } - fn n_lwe(&self) -> Degree { - Degree(self.keys.len() as u32) + fn n_lwe(&self) -> RingDegree { + RingDegree(self.keys.len() as u32) } } @@ -206,7 +206,7 @@ impl LWEInfos for BlindRotationKey { self.keys[0].k() } - fn n(&self) -> Degree { + fn n(&self) -> RingDegree { self.keys[0].n() } diff --git a/poulpy-schemes/src/tfhe/blind_rotation/key_compressed.rs b/poulpy-schemes/src/tfhe/blind_rotation/key_compressed.rs index 51ff139..af7e40a 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/key_compressed.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/key_compressed.rs @@ -8,14 +8,14 @@ use std::{fmt, marker::PhantomData}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use poulpy_core::{ Distribution, - layouts::{Base2K, Degree, Dsize, GGSWInfos, GLWEInfos, LWEInfos, TorusPrecision, compressed::GGSWCiphertextCompressed}, + layouts::{Base2K, Dsize, GGSWInfos, GLWEInfos, LWEInfos, RingDegree, TorusPrecision, compressed::GGSWCompressed}, }; use crate::tfhe::blind_rotation::{BlindRotationAlgo, BlindRotationKeyInfos}; #[derive(Clone)] pub struct BlindRotationKeyCompressed { - pub(crate) keys: Vec>, + pub(crate) keys: Vec>, pub(crate) dist: Distribution, pub(crate) _phantom: PhantomData, } @@ -94,17 +94,17 @@ impl WriterTo for BlindRotationKeyCompressed } impl BlindRotationKeyInfos for BlindRotationKeyCompressed { - fn n_glwe(&self) -> Degree { + fn n_glwe(&self) -> RingDegree { self.n() } - fn n_lwe(&self) -> Degree { - Degree(self.keys.len() as u32) + fn n_lwe(&self) -> RingDegree { + RingDegree(self.keys.len() as u32) } } impl LWEInfos for BlindRotationKeyCompressed { - fn n(&self) -> Degree { + fn n(&self) -> RingDegree { self.keys[0].n() } diff --git a/poulpy-schemes/src/tfhe/blind_rotation/key_prepared.rs b/poulpy-schemes/src/tfhe/blind_rotation/key_prepared.rs index 8719de4..1f99ef5 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/key_prepared.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/key_prepared.rs @@ -8,8 +8,8 @@ use std::marker::PhantomData; use poulpy_core::{ Distribution, layouts::{ - Base2K, Degree, Dnum, Dsize, GGSWInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision, - prepared::{GGSWCiphertextPrepared, Prepare, PrepareAlloc}, + Base2K, Dnum, Dsize, GGSWInfos, GLWEInfos, LWEInfos, Rank, RingDegree, TorusPrecision, + prepared::{GGSWPrepared, Prepare, PrepareAlloc}, }, }; @@ -23,19 +23,19 @@ pub trait BlindRotationKeyPreparedAlloc { #[derive(PartialEq, Eq)] pub struct BlindRotationKeyPrepared { - pub(crate) data: Vec>, + pub(crate) data: Vec>, pub(crate) dist: Distribution, pub(crate) x_pow_a: Option, B>>>, pub(crate) _phantom: PhantomData, } impl BlindRotationKeyInfos for BlindRotationKeyPrepared { - fn n_glwe(&self) -> Degree { + fn n_glwe(&self) -> RingDegree { self.n() } - fn n_lwe(&self) -> Degree { - Degree(self.data.len() as u32) + fn n_lwe(&self) -> RingDegree { + RingDegree(self.data.len() as u32) } } @@ -48,7 +48,7 @@ impl LWEInfos for BlindRotationKeyP self.data[0].k() } - fn n(&self) -> Degree { + fn n(&self) -> RingDegree { self.data[0].n() } diff --git a/poulpy-schemes/src/tfhe/blind_rotation/mod.rs b/poulpy-schemes/src/tfhe/blind_rotation/mod.rs index bd83a08..8cc262b 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/mod.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/mod.rs @@ -14,7 +14,7 @@ pub use lut::*; pub mod tests; -use poulpy_core::layouts::{GLWECiphertext, LWECiphertext}; +use poulpy_core::layouts::{GLWE, LWE}; use poulpy_hal::layouts::{Backend, DataMut, DataRef, Module, Scratch}; pub trait BlindRotationAlgo {} @@ -27,8 +27,8 @@ pub trait BlincRotationExecute { fn execute( &self, module: &Module, - res: &mut GLWECiphertext, - lwe: &LWECiphertext, + res: &mut GLWE, + lwe: &LWE, lut: &LookUpTable, scratch: &mut Scratch, ); diff --git a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs index 190adff..4b8131a 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs @@ -1,9 +1,9 @@ use poulpy_hal::{ api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAdd, VecZnxDftAddInplace, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftSubInplace, VecZnxDftZero, VecZnxFillUniform, VecZnxIdftApply, + VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAdd, VecZnxDftAddInplace, + VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftSubInplace, VecZnxDftZero, VecZnxFillUniform, VecZnxIdftApply, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpBytes, VecZnxMulXpMinusOneInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, ZnAddNormal, @@ -19,20 +19,19 @@ use poulpy_hal::{ use crate::tfhe::blind_rotation::{ BlincRotationExecute, BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyEncryptSk, BlindRotationKeyLayout, - BlindRotationKeyPrepared, CGGI, LookUpTable, cggi_blind_rotate_scratch_space, mod_switch_2n, + BlindRotationKeyPrepared, CGGI, LookUpTable, cggi_blind_rotate_tmp_bytes, mod_switch_2n, }; use poulpy_core::layouts::{ - GLWECiphertext, GLWECiphertextLayout, GLWEPlaintext, GLWESecret, LWECiphertext, LWECiphertextLayout, LWECiphertextToRef, - LWEInfos, LWEPlaintext, LWESecret, + GLWE, GLWELayout, GLWEPlaintext, GLWESecret, LWE, LWEInfos, LWELayout, LWEPlaintext, LWESecret, LWEToRef, prepared::{GLWESecretPrepared, PrepareAlloc}, }; pub fn test_blind_rotation(module: &Module, n_lwe: usize, block_size: usize, extension_factor: usize) where - Module: VecZnxBigAllocBytes - + VecZnxDftAllocBytes - + SvpPPolAllocBytes + Module: VecZnxBigBytesOf + + VecZnxDftBytesOf + + SvpPPolBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxIdftApplyTmpBytes @@ -111,31 +110,31 @@ where rank: rank.into(), }; - let glwe_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_infos: GLWELayout = GLWELayout { n: n_glwe.into(), base2k: base2k.into(), k: k_res.into(), rank: rank.into(), }; - let lwe_infos: LWECiphertextLayout = LWECiphertextLayout { + let lwe_infos: LWELayout = LWELayout { n: n_lwe.into(), k: k_lwe.into(), base2k: base2k.into(), }; - let mut scratch: ScratchOwned = ScratchOwned::::alloc(BlindRotationKey::generate_from_sk_scratch_space( + let mut scratch: ScratchOwned = ScratchOwned::::alloc(BlindRotationKey::generate_from_sk_tmp_bytes( module, &brk_infos, )); - let mut sk_glwe: GLWESecret> = GLWESecret::alloc(&glwe_infos); + let mut sk_glwe: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_infos); sk_glwe.fill_ternary_prob(0.5, &mut source_xs); let sk_glwe_dft: GLWESecretPrepared, B> = sk_glwe.prepare_alloc(module, scratch.borrow()); let mut sk_lwe: LWESecret> = LWESecret::alloc(n_lwe.into()); sk_lwe.fill_binary_block(block_size, &mut source_xs); - let mut scratch_br: ScratchOwned = ScratchOwned::::alloc(cggi_blind_rotate_scratch_space( + let mut scratch_br: ScratchOwned = ScratchOwned::::alloc(cggi_blind_rotate_tmp_bytes( module, block_size, extension_factor, @@ -154,9 +153,9 @@ where scratch.borrow(), ); - let mut lwe: LWECiphertext> = LWECiphertext::alloc(&lwe_infos); + let mut lwe: LWE> = LWE::alloc_from_infos(&lwe_infos); - let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc(&lwe_infos); + let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc_from_infos(&lwe_infos); let x: i64 = 15 % (message_modulus as i64); @@ -175,13 +174,13 @@ where let mut lut: LookUpTable = LookUpTable::alloc(module, base2k, k_lut, extension_factor); lut.set(module, &f_vec, log_message_modulus + 1); - let mut res: GLWECiphertext> = GLWECiphertext::alloc(&glwe_infos); + let mut res: GLWE> = GLWE::alloc_from_infos(&glwe_infos); let brk_prepared: BlindRotationKeyPrepared, CGGI, B> = brk.prepare_alloc(module, scratch.borrow()); brk_prepared.execute(module, &mut res, &lwe, &lut, scratch_br.borrow()); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_infos); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_infos); res.decrypt(module, &mut pt_have, &sk_glwe_dft, scratch.borrow()); diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs index 45b9717..1799127 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs @@ -3,9 +3,9 @@ use std::collections::HashMap; use poulpy_hal::{ api::{ ScratchAvailable, TakeMatZnx, TakeSlice, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, TakeVecZnxSlice, - VecZnxAddInplace, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigAutomorphismInplace, + VecZnxAddInplace, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy, VecZnxDftAddInplace, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNegateInplace, + VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, @@ -15,12 +15,12 @@ use poulpy_hal::{ }; use poulpy_core::{ - GLWEOperations, TakeGGLWE, TakeGLWECt, - layouts::{Dsize, GGLWECiphertextLayout, GGSWInfos, GLWEInfos, LWEInfos}, + GLWEOperations, TakeGGLWE, TakeGLWE, + layouts::{Dsize, GGLWELayout, GGSWInfos, GLWEInfos, LWEInfos}, }; use poulpy_core::glwe_packing; -use poulpy_core::layouts::{GGSWCiphertext, GLWECiphertext, LWECiphertext, prepared::GGLWEAutomorphismKeyPrepared}; +use poulpy_core::layouts::{GGSW, GLWE, LWE, prepared::AutomorphismKeyPrepared}; use crate::tfhe::{ blind_rotation::{ @@ -44,7 +44,7 @@ where + VecZnxNegateInplace + VecZnxCopy + VecZnxSubInplace - + VecZnxDftAllocBytes + + VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft @@ -56,7 +56,7 @@ where + VecZnxAutomorphismInplace + VecZnxBigSubSmallNegateInplace + VecZnxRotateInplaceTmpBytes - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxDftAddInplace + VecZnxRotate + VecZnxNormalize, @@ -74,8 +74,8 @@ where fn execute_to_constant( &self, module: &Module, - res: &mut GGSWCiphertext, - lwe: &LWECiphertext, + res: &mut GGSW, + lwe: &LWE, log_domain: usize, extension_factor: usize, scratch: &mut Scratch, @@ -97,8 +97,8 @@ where &self, module: &Module, log_gap_out: usize, - res: &mut GGSWCiphertext, - lwe: &LWECiphertext, + res: &mut GGSW, + lwe: &LWE, log_domain: usize, extension_factor: usize, scratch: &mut Scratch, @@ -122,8 +122,8 @@ pub fn circuit_bootstrap_core( to_exponent: bool, module: &Module, log_gap_out: usize, - res: &mut GGSWCiphertext, - lwe: &LWECiphertext, + res: &mut GGSW, + lwe: &LWE, log_domain: usize, extension_factor: usize, key: &CircuitBootstrappingKeyPrepared, @@ -145,7 +145,7 @@ pub fn circuit_bootstrap_core( + VecZnxNegateInplace + VecZnxCopy + VecZnxSubInplace - + VecZnxDftAllocBytes + + VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft @@ -156,7 +156,7 @@ pub fn circuit_bootstrap_core( + VecZnxBigNormalize + VecZnxAutomorphismInplace + VecZnxBigSubSmallNegateInplace - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxDftAddInplace + VecZnxRotateInplaceTmpBytes + VecZnxRotate @@ -214,7 +214,7 @@ pub fn circuit_bootstrap_core( // TODO: separate GGSW k from output of blind rotation k let (mut res_glwe, scratch_1) = scratch.take_glwe_ct(res); - let gglwe_infos: GGLWECiphertextLayout = GGLWECiphertextLayout { + let gglwe_infos: GGLWELayout = GGLWELayout { n: n.into(), base2k: base2k.into(), k: k.into(), @@ -233,7 +233,7 @@ pub fn circuit_bootstrap_core( let log_gap_in: usize = (usize::BITS - (gap * alpha - 1).leading_zeros()) as _; (0..dnum).for_each(|i| { - let mut tmp_glwe: GLWECiphertext<&mut [u8]> = tmp_gglwe.at_mut(i, 0); + let mut tmp_glwe: GLWE<&mut [u8]> = tmp_gglwe.at_mut(i, 0); if to_exponent { // Isolates i-th LUT and moves coefficients according to requested gap. @@ -263,12 +263,12 @@ pub fn circuit_bootstrap_core( #[allow(clippy::too_many_arguments)] fn post_process( module: &Module, - res: &mut GLWECiphertext, - a: &GLWECiphertext, + res: &mut GLWE, + a: &GLWE, log_gap_in: usize, log_gap_out: usize, log_domain: usize, - auto_keys: &HashMap, B>>, + auto_keys: &HashMap, B>>, scratch: &mut Scratch, ) where DataRes: DataMut, @@ -286,7 +286,7 @@ fn post_process( + VecZnxNegateInplace + VecZnxCopy + VecZnxSubInplace - + VecZnxDftAllocBytes + + VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft @@ -303,7 +303,7 @@ fn post_process( { let log_n: usize = module.log_n(); - let mut cts: HashMap>> = HashMap::new(); + let mut cts: HashMap>> = HashMap::new(); // First partial trace, vanishes all coefficients which are not multiples of gap_in // [1, 1, 1, 1, 0, 0, 0, ..., 0, 0, -1, -1, -1, -1] -> [1, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0] @@ -322,7 +322,7 @@ fn post_process( let steps: usize = 1 << log_domain; // TODO: from Scratch - let mut cts_vec: Vec>> = Vec::new(); + let mut cts_vec: Vec>> = Vec::new(); for i in 0..steps { if i != 0 { @@ -336,7 +336,7 @@ fn post_process( } glwe_packing(module, &mut cts, log_gap_out, auto_keys, scratch); - let packed: &mut GLWECiphertext> = cts.remove(&0).unwrap(); + let packed: &mut GLWE> = cts.remove(&0).unwrap(); res.trace( module, log_n - log_gap_out, diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs index 427cf75..1e52a76 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs @@ -1,15 +1,15 @@ use poulpy_core::layouts::{ - GGLWEAutomorphismKey, GGLWEAutomorphismKeyLayout, GGLWEInfos, GGLWETensorKey, GGLWETensorKeyLayout, GGSWInfos, - GLWECiphertext, GLWEInfos, GLWESecret, LWEInfos, LWESecret, - prepared::{GGLWEAutomorphismKeyPrepared, GGLWETensorKeyPrepared, GLWESecretPrepared, PrepareAlloc}, + AutomorphismKey, AutomorphismKeyLayout, GGLWEInfos, GGSWInfos, GLWE, GLWEInfos, GLWESecret, LWEInfos, LWESecret, TensorKey, + TensorKeyLayout, + prepared::{AutomorphismKeyPrepared, GLWESecretPrepared, PrepareAlloc, TensorKeyPrepared}, }; use std::collections::HashMap; use poulpy_hal::{ api::{ - ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, + ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, TakeScalarZnx, TakeSvpPPol, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, - VecZnxAutomorphism, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, + VecZnxAutomorphism, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpPMatAlloc, VmpPrepare, }, @@ -24,19 +24,19 @@ use crate::tfhe::blind_rotation::{ pub trait CircuitBootstrappingKeyInfos { fn brk_infos(&self) -> BlindRotationKeyLayout; - fn atk_infos(&self) -> GGLWEAutomorphismKeyLayout; - fn tsk_infos(&self) -> GGLWETensorKeyLayout; + fn atk_infos(&self) -> AutomorphismKeyLayout; + fn tsk_infos(&self) -> TensorKeyLayout; } #[derive(Debug, Clone, Copy)] pub struct CircuitBootstrappingKeyLayout { pub layout_brk: BlindRotationKeyLayout, - pub layout_atk: GGLWEAutomorphismKeyLayout, - pub layout_tsk: GGLWETensorKeyLayout, + pub layout_atk: AutomorphismKeyLayout, + pub layout_tsk: TensorKeyLayout, } impl CircuitBootstrappingKeyInfos for CircuitBootstrappingKeyLayout { - fn atk_infos(&self) -> GGLWEAutomorphismKeyLayout { + fn atk_infos(&self) -> AutomorphismKeyLayout { self.layout_atk } @@ -44,7 +44,7 @@ impl CircuitBootstrappingKeyInfos for CircuitBootstrappingKeyLayout { self.layout_brk } - fn tsk_infos(&self) -> GGLWETensorKeyLayout { + fn tsk_infos(&self) -> TensorKeyLayout { self.layout_tsk } } @@ -68,8 +68,8 @@ pub trait CircuitBootstrappingKeyEncryptSk { pub struct CircuitBootstrappingKey { pub(crate) brk: BlindRotationKey, - pub(crate) tsk: GGLWETensorKey>, - pub(crate) atk: HashMap>>, + pub(crate) tsk: TensorKey>, + pub(crate) atk: HashMap>>, } impl CircuitBootstrappingKeyEncryptSk for CircuitBootstrappingKey, BRA> @@ -78,7 +78,7 @@ where Module: SvpApplyDftToDft + VecZnxIdftApplyTmpA + VecZnxAddScalarInplace - + VecZnxDftAllocBytes + + VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -93,7 +93,7 @@ where + VecZnxSub + SvpPrepare + VecZnxSwitchRing - + SvpPPolAllocBytes + + SvpPPolBytesOf + SvpPPolAlloc + VecZnxAutomorphism, Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeSvpPPol + TakeVecZnxBig, @@ -117,14 +117,14 @@ where assert_eq!(sk_glwe.n(), cbt_infos.atk_infos().n()); assert_eq!(sk_glwe.n(), cbt_infos.tsk_infos().n()); - let atk_infos: GGLWEAutomorphismKeyLayout = cbt_infos.atk_infos(); + let atk_infos: AutomorphismKeyLayout = cbt_infos.atk_infos(); let brk_infos: BlindRotationKeyLayout = cbt_infos.brk_infos(); - let trk_infos: GGLWETensorKeyLayout = cbt_infos.tsk_infos(); + let trk_infos: TensorKeyLayout = cbt_infos.tsk_infos(); - let mut auto_keys: HashMap>> = HashMap::new(); - let gal_els: Vec = GLWECiphertext::trace_galois_elements(module); + let mut auto_keys: HashMap>> = HashMap::new(); + let gal_els: Vec = GLWE::trace_galois_elements(module); gal_els.iter().for_each(|gal_el| { - let mut key: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&atk_infos); + let mut key: AutomorphismKey> = AutomorphismKey::alloc_from_infos(&atk_infos); key.encrypt_sk(module, *gal_el, sk_glwe, source_xa, source_xe, scratch); auto_keys.insert(*gal_el, key); }); @@ -141,7 +141,7 @@ where scratch, ); - let mut tsk: GGLWETensorKey> = GGLWETensorKey::alloc(&trk_infos); + let mut tsk: TensorKey> = TensorKey::alloc_from_infos(&trk_infos); tsk.encrypt_sk(module, sk_glwe, source_xa, source_xe, scratch); Self { @@ -154,14 +154,14 @@ where pub struct CircuitBootstrappingKeyPrepared { pub(crate) brk: BlindRotationKeyPrepared, - pub(crate) tsk: GGLWETensorKeyPrepared, B>, - pub(crate) atk: HashMap, B>>, + pub(crate) tsk: TensorKeyPrepared, B>, + pub(crate) atk: HashMap, B>>, } impl CircuitBootstrappingKeyInfos for CircuitBootstrappingKeyPrepared { - fn atk_infos(&self) -> GGLWEAutomorphismKeyLayout { + fn atk_infos(&self) -> AutomorphismKeyLayout { let (_, atk) = self.atk.iter().next().expect("atk is empty"); - GGLWEAutomorphismKeyLayout { + AutomorphismKeyLayout { n: atk.n(), base2k: atk.base2k(), k: atk.k(), @@ -182,8 +182,8 @@ impl CircuitBootstrappingKeyInfo } } - fn tsk_infos(&self) -> GGLWETensorKeyLayout { - GGLWETensorKeyLayout { + fn tsk_infos(&self) -> TensorKeyLayout { + TensorKeyLayout { n: self.tsk.n(), base2k: self.tsk.base2k(), k: self.tsk.k(), @@ -199,13 +199,13 @@ impl PrepareAlloc: VmpPMatAlloc + VmpPrepare, BlindRotationKey: PrepareAlloc, BRA, B>>, - GGLWETensorKey: PrepareAlloc, B>>, - GGLWEAutomorphismKey: PrepareAlloc, B>>, + TensorKey: PrepareAlloc, B>>, + AutomorphismKey: PrepareAlloc, B>>, { fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> CircuitBootstrappingKeyPrepared, BRA, B> { let brk: BlindRotationKeyPrepared, BRA, B> = self.brk.prepare_alloc(module, scratch); - let tsk: GGLWETensorKeyPrepared, B> = self.tsk.prepare_alloc(module, scratch); - let mut atk: HashMap, B>> = HashMap::new(); + let tsk: TensorKeyPrepared, B> = self.tsk.prepare_alloc(module, scratch); + let mut atk: HashMap, B>> = HashMap::new(); for (key, value) in &self.atk { atk.insert(*key, value.prepare_alloc(module, scratch)); } diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/mod.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/mod.rs index 86dcf5e..5835765 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/mod.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/mod.rs @@ -5,7 +5,7 @@ pub mod tests; pub use circuit::*; pub use key::*; -use poulpy_core::layouts::{GGSWCiphertext, LWECiphertext}; +use poulpy_core::layouts::{GGSW, LWE}; use poulpy_hal::layouts::{Backend, DataMut, DataRef, Module, Scratch}; @@ -13,8 +13,8 @@ pub trait CirtuitBootstrappingExecute { fn execute_to_constant( &self, module: &Module, - res: &mut GGSWCiphertext, - lwe: &LWECiphertext, + res: &mut GGSW, + lwe: &LWE, log_domain: usize, extension_factor: usize, scratch: &mut Scratch, @@ -25,8 +25,8 @@ pub trait CirtuitBootstrappingExecute { &self, module: &Module, log_gap_out: usize, - res: &mut GGSWCiphertext, - lwe: &LWECiphertext, + res: &mut GGSW, + lwe: &LWE, log_domain: usize, extension_factor: usize, scratch: &mut Scratch, diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs index 40f5448..cfe238f 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs @@ -2,11 +2,11 @@ use std::time::Instant; use poulpy_hal::{ api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, - VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigAutomorphismInplace, + VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAutomorphismInplace, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy, VecZnxDftAddInplace, - VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, VecZnxFillUniform, VecZnxIdftApplyConsume, + VecZnxDftAlloc, VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftCopy, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, ZnAddNormal, ZnFillUniform, @@ -31,13 +31,11 @@ use crate::tfhe::{ }, }; -use poulpy_core::layouts::{ - Dsize, GGLWEAutomorphismKeyLayout, GGLWETensorKeyLayout, GGSWCiphertextLayout, LWECiphertextLayout, prepared::PrepareAlloc, -}; +use poulpy_core::layouts::{AutomorphismKeyLayout, Dsize, GGSWLayout, LWELayout, TensorKeyLayout, prepared::PrepareAlloc}; use poulpy_core::layouts::{ - GGSWCiphertext, GLWECiphertext, GLWEPlaintext, GLWESecret, LWECiphertext, LWEPlaintext, LWESecret, - prepared::{GGSWCiphertextPrepared, GLWESecretPrepared}, + GGSW, GLWE, GLWEPlaintext, GLWESecret, LWE, LWEPlaintext, LWESecret, + prepared::{GGSWPrepared, GLWESecretPrepared}, }; pub fn test_circuit_bootstrapping_to_exponent(module: &Module) @@ -45,7 +43,7 @@ where Module: VecZnxFillUniform + VecZnxAddNormal + VecZnxNormalizeInplace - + VecZnxDftAllocBytes + + VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -58,7 +56,7 @@ where + VecZnxAddScalarInplace + VecZnxAutomorphism + VecZnxSwitchRing - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxIdftApplyTmpA + SvpApplyDftToDft + VecZnxBigAddInplace @@ -73,7 +71,7 @@ where + VmpApplyDftToDftTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + SvpPPolAllocBytes + + SvpPPolBytesOf + VecZnxRotateInplace + VecZnxBigAutomorphismInplace + VecZnxRshInplace @@ -83,7 +81,7 @@ where + VecZnxAutomorphismInplace + VecZnxBigSubSmallNegateInplace + VecZnxRotateInplaceTmpBytes - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxDftAddInplace + VecZnxRotate + ZnFillUniform @@ -128,7 +126,7 @@ where let k_ggsw_res: usize = 4 * base2k; let rows_ggsw_res: usize = 2; - let lwe_infos: LWECiphertextLayout = LWECiphertextLayout { + let lwe_infos: LWELayout = LWELayout { n: n_lwe.into(), k: k_lwe_ct.into(), base2k: base2k.into(), @@ -143,7 +141,7 @@ where dnum: rows_brk.into(), rank: rank.into(), }, - layout_atk: GGLWEAutomorphismKeyLayout { + layout_atk: AutomorphismKeyLayout { n: n_glwe.into(), base2k: base2k.into(), k: k_atk.into(), @@ -151,7 +149,7 @@ where rank: rank.into(), dsize: Dsize(1), }, - layout_tsk: GGLWETensorKeyLayout { + layout_tsk: TensorKeyLayout { n: n_glwe.into(), base2k: base2k.into(), k: k_tsk.into(), @@ -161,7 +159,7 @@ where }, }; - let ggsw_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + let ggsw_infos: GGSWLayout = GGSWLayout { n: n_glwe.into(), base2k: base2k.into(), k: k_ggsw_res.into(), @@ -179,19 +177,19 @@ where let mut sk_lwe: LWESecret> = LWESecret::alloc(n_lwe.into()); sk_lwe.fill_binary_block(block_size, &mut source_xs); - let mut sk_glwe: GLWESecret> = GLWESecret::alloc_with(n_glwe.into(), rank.into()); + let mut sk_glwe: GLWESecret> = GLWESecret::alloc(n_glwe.into(), rank.into()); sk_glwe.fill_ternary_prob(0.5, &mut source_xs); let sk_glwe_prepared: GLWESecretPrepared, B> = sk_glwe.prepare_alloc(module, scratch.borrow()); let data: i64 = 1; - let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc_with(base2k.into(), k_lwe_pt.into()); + let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc(base2k.into(), k_lwe_pt.into()); pt_lwe.encode_i64(data, (k_lwe_pt + 1).into()); println!("pt_lwe: {pt_lwe}"); - let mut ct_lwe: LWECiphertext> = LWECiphertext::alloc(&lwe_infos); + let mut ct_lwe: LWE> = LWE::alloc_from_infos(&lwe_infos); ct_lwe.encrypt_sk(module, &pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe); let now: Instant = Instant::now(); @@ -206,7 +204,7 @@ where ); println!("CBT-KGEN: {} ms", now.elapsed().as_millis()); - let mut res: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_infos); + let mut res: GGSW> = GGSW::alloc_from_infos(&ggsw_infos); let log_gap_out = 1; @@ -236,8 +234,8 @@ where res.print_noise(module, &sk_glwe_prepared, &pt_ggsw); - let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&ggsw_infos); - let mut pt_glwe: GLWEPlaintext> = GLWEPlaintext::alloc(&ggsw_infos); + let mut ct_glwe: GLWE> = GLWE::alloc_from_infos(&ggsw_infos); + let mut pt_glwe: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&ggsw_infos); pt_glwe.data.at_mut(0, 0)[0] = 1 << (base2k - 2); ct_glwe.encrypt_sk( @@ -249,11 +247,11 @@ where scratch.borrow(), ); - let res_prepared: GGSWCiphertextPrepared, B> = res.prepare_alloc(module, scratch.borrow()); + let res_prepared: GGSWPrepared, B> = res.prepare_alloc(module, scratch.borrow()); ct_glwe.external_product_inplace(module, &res_prepared, scratch.borrow()); - let mut pt_res: GLWEPlaintext> = GLWEPlaintext::alloc(&ggsw_infos); + let mut pt_res: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&ggsw_infos); ct_glwe.decrypt(module, &mut pt_res, &sk_glwe_prepared, scratch.borrow()); // Parameters are set such that the first limb should be noiseless. @@ -267,7 +265,7 @@ where Module: VecZnxFillUniform + VecZnxAddNormal + VecZnxNormalizeInplace - + VecZnxDftAllocBytes + + VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply + SvpApplyDftToDftInplace @@ -280,7 +278,7 @@ where + VecZnxAddScalarInplace + VecZnxAutomorphism + VecZnxSwitchRing - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxIdftApplyTmpA + SvpApplyDftToDft + VecZnxBigAddInplace @@ -295,7 +293,7 @@ where + VmpApplyDftToDftTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + SvpPPolAllocBytes + + SvpPPolBytesOf + VecZnxRotateInplace + VecZnxBigAutomorphismInplace + VecZnxRotateInplaceTmpBytes @@ -305,7 +303,7 @@ where + VecZnxCopy + VecZnxAutomorphismInplace + VecZnxBigSubSmallNegateInplace - + VecZnxBigAllocBytes + + VecZnxBigBytesOf + VecZnxDftAddInplace + VecZnxRotate + ZnFillUniform @@ -350,7 +348,7 @@ where let k_ggsw_res: usize = 4 * base2k; let rows_ggsw_res: usize = 3; - let lwe_infos: LWECiphertextLayout = LWECiphertextLayout { + let lwe_infos: LWELayout = LWELayout { n: n_lwe.into(), k: k_lwe_ct.into(), base2k: base2k.into(), @@ -365,7 +363,7 @@ where dnum: rows_brk.into(), rank: rank.into(), }, - layout_atk: GGLWEAutomorphismKeyLayout { + layout_atk: AutomorphismKeyLayout { n: n_glwe.into(), base2k: base2k.into(), k: k_atk.into(), @@ -373,7 +371,7 @@ where rank: rank.into(), dsize: Dsize(1), }, - layout_tsk: GGLWETensorKeyLayout { + layout_tsk: TensorKeyLayout { n: n_glwe.into(), base2k: base2k.into(), k: k_tsk.into(), @@ -383,7 +381,7 @@ where }, }; - let ggsw_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + let ggsw_infos: GGSWLayout = GGSWLayout { n: n_glwe.into(), base2k: base2k.into(), k: k_ggsw_res.into(), @@ -401,19 +399,19 @@ where let mut sk_lwe: LWESecret> = LWESecret::alloc(n_lwe.into()); sk_lwe.fill_binary_block(block_size, &mut source_xs); - let mut sk_glwe: GLWESecret> = GLWESecret::alloc_with(n_glwe.into(), rank.into()); + let mut sk_glwe: GLWESecret> = GLWESecret::alloc(n_glwe.into(), rank.into()); sk_glwe.fill_ternary_prob(0.5, &mut source_xs); let sk_glwe_prepared: GLWESecretPrepared, B> = sk_glwe.prepare_alloc(module, scratch.borrow()); let data: i64 = 1; - let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc_with(base2k.into(), k_lwe_pt.into()); + let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc(base2k.into(), k_lwe_pt.into()); pt_lwe.encode_i64(data, (k_lwe_pt + 1).into()); println!("pt_lwe: {pt_lwe}"); - let mut ct_lwe: LWECiphertext> = LWECiphertext::alloc(&lwe_infos); + let mut ct_lwe: LWE> = LWE::alloc_from_infos(&lwe_infos); ct_lwe.encrypt_sk(module, &pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe); let now: Instant = Instant::now(); @@ -428,7 +426,7 @@ where ); println!("CBT-KGEN: {} ms", now.elapsed().as_millis()); - let mut res: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_infos); + let mut res: GGSW> = GGSW::alloc_from_infos(&ggsw_infos); let cbt_prepared: CircuitBootstrappingKeyPrepared, BRA, B> = cbt_key.prepare_alloc(module, scratch.borrow()); @@ -449,8 +447,8 @@ where res.print_noise(module, &sk_glwe_prepared, &pt_ggsw); - let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&ggsw_infos); - let mut pt_glwe: GLWEPlaintext> = GLWEPlaintext::alloc(&ggsw_infos); + let mut ct_glwe: GLWE> = GLWE::alloc_from_infos(&ggsw_infos); + let mut pt_glwe: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&ggsw_infos); pt_glwe.data.at_mut(0, 0)[0] = 1 << (base2k - k_lwe_pt - 1); ct_glwe.encrypt_sk( @@ -462,11 +460,11 @@ where scratch.borrow(), ); - let res_prepared: GGSWCiphertextPrepared, B> = res.prepare_alloc(module, scratch.borrow()); + let res_prepared: GGSWPrepared, B> = res.prepare_alloc(module, scratch.borrow()); ct_glwe.external_product_inplace(module, &res_prepared, scratch.borrow()); - let mut pt_res: GLWEPlaintext> = GLWEPlaintext::alloc(&ggsw_infos); + let mut pt_res: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&ggsw_infos); ct_glwe.decrypt(module, &mut pt_res, &sk_glwe_prepared, scratch.borrow()); // Parameters are set such that the first limb should be noiseless.