diff --git a/poulpy-backend/src/cpu_fft64_avx/scratch.rs b/poulpy-backend/src/cpu_fft64_avx/scratch.rs index 922166b..f6595f9 100644 --- a/poulpy-backend/src/cpu_fft64_avx/scratch.rs +++ b/poulpy-backend/src/cpu_fft64_avx/scratch.rs @@ -3,13 +3,8 @@ use std::marker::PhantomData; use poulpy_hal::{ DEFAULTALIGN, alloc_aligned, api::ScratchFromBytes, - layouts::{Backend, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}, - oep::{ - ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, SvpPPolAllocBytesImpl, - TakeMatZnxImpl, TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl, - TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl, VecZnxBigAllocBytesImpl, - VecZnxDftAllocBytesImpl, VmpPMatAllocBytesImpl, - }, + layouts::{Backend, Scratch, ScratchOwned}, + oep::{ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeSliceImpl}, }; use crate::cpu_fft64_avx::FFT64Avx; @@ -64,178 +59,6 @@ where } } -unsafe impl TakeScalarZnxImpl for FFT64Avx -where - B: ScratchFromBytesImpl, -{ - fn take_scalar_znx_impl(scratch: &mut Scratch, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, ScalarZnx::alloc_bytes(n, cols)); - ( - ScalarZnx::from_data(take_slice, n, cols), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeSvpPPolImpl for FFT64Avx -where - B: SvpPPolAllocBytesImpl + ScratchFromBytesImpl, -{ - fn take_svp_ppol_impl(scratch: &mut Scratch, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, B::svp_ppol_alloc_bytes_impl(n, cols)); - ( - SvpPPol::from_data(take_slice, n, cols), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeVecZnxImpl for FFT64Avx -where - B: ScratchFromBytesImpl, -{ - fn take_vec_znx_impl(scratch: &mut Scratch, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, VecZnx::alloc_bytes(n, cols, size)); - ( - VecZnx::from_data(take_slice, n, cols, size), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeVecZnxBigImpl for FFT64Avx -where - B: VecZnxBigAllocBytesImpl + ScratchFromBytesImpl, -{ - fn take_vec_znx_big_impl( - scratch: &mut Scratch, - n: usize, - cols: usize, - size: usize, - ) -> (VecZnxBig<&mut [u8], B>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned( - &mut scratch.data, - B::vec_znx_big_alloc_bytes_impl(n, cols, size), - ); - ( - VecZnxBig::from_data(take_slice, n, cols, size), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeVecZnxDftImpl for FFT64Avx -where - B: VecZnxDftAllocBytesImpl + ScratchFromBytesImpl, -{ - fn take_vec_znx_dft_impl( - scratch: &mut Scratch, - n: usize, - cols: usize, - size: usize, - ) -> (VecZnxDft<&mut [u8], B>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned( - &mut scratch.data, - B::vec_znx_dft_alloc_bytes_impl(n, cols, size), - ); - - ( - VecZnxDft::from_data(take_slice, n, cols, size), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeVecZnxDftSliceImpl for FFT64Avx -where - B: VecZnxDftAllocBytesImpl + ScratchFromBytesImpl + TakeVecZnxDftImpl, -{ - fn take_vec_znx_dft_slice_impl( - scratch: &mut Scratch, - len: usize, - n: usize, - cols: usize, - size: usize, - ) -> (Vec>, &mut Scratch) { - let mut scratch: &mut Scratch = scratch; - let mut slice: Vec> = Vec::with_capacity(len); - for _ in 0..len { - let (znx, new_scratch) = B::take_vec_znx_dft_impl(scratch, n, cols, size); - scratch = new_scratch; - slice.push(znx); - } - (slice, scratch) - } -} - -unsafe impl TakeVecZnxSliceImpl for FFT64Avx -where - B: ScratchFromBytesImpl + TakeVecZnxImpl, -{ - fn take_vec_znx_slice_impl( - scratch: &mut Scratch, - len: usize, - n: usize, - cols: usize, - size: usize, - ) -> (Vec>, &mut Scratch) { - let mut scratch: &mut Scratch = scratch; - let mut slice: Vec> = Vec::with_capacity(len); - for _ in 0..len { - let (znx, new_scratch) = B::take_vec_znx_impl(scratch, n, cols, size); - scratch = new_scratch; - slice.push(znx); - } - (slice, scratch) - } -} - -unsafe impl TakeVmpPMatImpl for FFT64Avx -where - B: VmpPMatAllocBytesImpl + ScratchFromBytesImpl, -{ - fn take_vmp_pmat_impl( - scratch: &mut Scratch, - n: usize, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - ) -> (VmpPMat<&mut [u8], B>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned( - &mut scratch.data, - B::vmp_pmat_alloc_bytes_impl(n, rows, cols_in, cols_out, size), - ); - ( - VmpPMat::from_data(take_slice, n, rows, cols_in, cols_out, size), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeMatZnxImpl for FFT64Avx -where - B: ScratchFromBytesImpl, -{ - fn take_mat_znx_impl( - scratch: &mut Scratch, - n: usize, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - ) -> (MatZnx<&mut [u8]>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned( - &mut scratch.data, - MatZnx::alloc_bytes(n, rows, cols_in, cols_out, size), - ); - ( - MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size), - Scratch::from_bytes(rem_slice), - ) - } -} - fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) { let ptr: *mut u8 = data.as_mut_ptr(); let self_len: usize = data.len(); diff --git a/poulpy-backend/src/cpu_fft64_avx/svp.rs b/poulpy-backend/src/cpu_fft64_avx/svp.rs index f505597..1c2c999 100644 --- a/poulpy-backend/src/cpu_fft64_avx/svp.rs +++ b/poulpy-backend/src/cpu_fft64_avx/svp.rs @@ -22,7 +22,7 @@ unsafe impl SvpPPolAllocImpl for FFT64Avx { } unsafe impl SvpPPolAllocBytesImpl for FFT64Avx { - fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize { + fn svp_ppol_bytes_of_impl(n: usize, cols: usize) -> usize { Self::layout_prep_word_count() * n * cols * size_of::() } } diff --git a/poulpy-backend/src/cpu_fft64_avx/vec_znx_big.rs b/poulpy-backend/src/cpu_fft64_avx/vec_znx_big.rs index 99a39fd..08ec98a 100644 --- a/poulpy-backend/src/cpu_fft64_avx/vec_znx_big.rs +++ b/poulpy-backend/src/cpu_fft64_avx/vec_znx_big.rs @@ -27,7 +27,7 @@ use poulpy_hal::{ }; unsafe impl VecZnxBigAllocBytesImpl for FFT64Avx { - fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { + fn vec_znx_big_bytes_of_impl(n: usize, cols: usize, size: usize) -> usize { Self::layout_big_word_count() * n * cols * size * size_of::() } } diff --git a/poulpy-backend/src/cpu_fft64_avx/vec_znx_dft.rs b/poulpy-backend/src/cpu_fft64_avx/vec_znx_dft.rs index 862f623..063ee26 100644 --- a/poulpy-backend/src/cpu_fft64_avx/vec_znx_dft.rs +++ b/poulpy-backend/src/cpu_fft64_avx/vec_znx_dft.rs @@ -24,7 +24,7 @@ unsafe impl VecZnxDftFromBytesImpl for FFT64Avx { } unsafe impl VecZnxDftAllocBytesImpl for FFT64Avx { - fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { + fn vec_znx_dft_bytes_of_impl(n: usize, cols: usize, size: usize) -> usize { Self::layout_prep_word_count() * n * cols * size * size_of::<::ScalarPrep>() } } diff --git a/poulpy-backend/src/cpu_fft64_avx/vmp.rs b/poulpy-backend/src/cpu_fft64_avx/vmp.rs index 6b87ce1..fcb6236 100644 --- a/poulpy-backend/src/cpu_fft64_avx/vmp.rs +++ b/poulpy-backend/src/cpu_fft64_avx/vmp.rs @@ -16,7 +16,7 @@ use poulpy_hal::{ use crate::cpu_fft64_avx::{FFT64Avx, module::FFT64ModuleHandle}; unsafe impl VmpPMatAllocBytesImpl for FFT64Avx { - fn vmp_pmat_alloc_bytes_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + fn vmp_pmat_bytes_of_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { Self::layout_prep_word_count() * n * rows * cols_in * cols_out * size * size_of::() } } diff --git a/poulpy-backend/src/cpu_fft64_ref/scratch.rs b/poulpy-backend/src/cpu_fft64_ref/scratch.rs index 80b228d..1593370 100644 --- a/poulpy-backend/src/cpu_fft64_ref/scratch.rs +++ b/poulpy-backend/src/cpu_fft64_ref/scratch.rs @@ -3,13 +3,8 @@ use std::marker::PhantomData; use poulpy_hal::{ DEFAULTALIGN, alloc_aligned, api::ScratchFromBytes, - layouts::{Backend, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}, - oep::{ - ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, SvpPPolAllocBytesImpl, - TakeMatZnxImpl, TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl, - TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl, VecZnxBigAllocBytesImpl, - VecZnxDftAllocBytesImpl, VmpPMatAllocBytesImpl, - }, + layouts::{Backend, Scratch, ScratchOwned}, + oep::{ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeSliceImpl}, }; use crate::cpu_fft64_ref::FFT64Ref; @@ -64,178 +59,6 @@ where } } -unsafe impl TakeScalarZnxImpl for FFT64Ref -where - B: ScratchFromBytesImpl, -{ - fn take_scalar_znx_impl(scratch: &mut Scratch, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, ScalarZnx::alloc_bytes(n, cols)); - ( - ScalarZnx::from_data(take_slice, n, cols), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeSvpPPolImpl for FFT64Ref -where - B: SvpPPolAllocBytesImpl + ScratchFromBytesImpl, -{ - fn take_svp_ppol_impl(scratch: &mut Scratch, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, B::svp_ppol_alloc_bytes_impl(n, cols)); - ( - SvpPPol::from_data(take_slice, n, cols), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeVecZnxImpl for FFT64Ref -where - B: ScratchFromBytesImpl, -{ - fn take_vec_znx_impl(scratch: &mut Scratch, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, VecZnx::alloc_bytes(n, cols, size)); - ( - VecZnx::from_data(take_slice, n, cols, size), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeVecZnxBigImpl for FFT64Ref -where - B: VecZnxBigAllocBytesImpl + ScratchFromBytesImpl, -{ - fn take_vec_znx_big_impl( - scratch: &mut Scratch, - n: usize, - cols: usize, - size: usize, - ) -> (VecZnxBig<&mut [u8], B>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned( - &mut scratch.data, - B::vec_znx_big_alloc_bytes_impl(n, cols, size), - ); - ( - VecZnxBig::from_data(take_slice, n, cols, size), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeVecZnxDftImpl for FFT64Ref -where - B: VecZnxDftAllocBytesImpl + ScratchFromBytesImpl, -{ - fn take_vec_znx_dft_impl( - scratch: &mut Scratch, - n: usize, - cols: usize, - size: usize, - ) -> (VecZnxDft<&mut [u8], B>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned( - &mut scratch.data, - B::vec_znx_dft_alloc_bytes_impl(n, cols, size), - ); - - ( - VecZnxDft::from_data(take_slice, n, cols, size), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeVecZnxDftSliceImpl for FFT64Ref -where - B: VecZnxDftAllocBytesImpl + ScratchFromBytesImpl + TakeVecZnxDftImpl, -{ - fn take_vec_znx_dft_slice_impl( - scratch: &mut Scratch, - len: usize, - n: usize, - cols: usize, - size: usize, - ) -> (Vec>, &mut Scratch) { - let mut scratch: &mut Scratch = scratch; - let mut slice: Vec> = Vec::with_capacity(len); - for _ in 0..len { - let (znx, new_scratch) = B::take_vec_znx_dft_impl(scratch, n, cols, size); - scratch = new_scratch; - slice.push(znx); - } - (slice, scratch) - } -} - -unsafe impl TakeVecZnxSliceImpl for FFT64Ref -where - B: ScratchFromBytesImpl + TakeVecZnxImpl, -{ - fn take_vec_znx_slice_impl( - scratch: &mut Scratch, - len: usize, - n: usize, - cols: usize, - size: usize, - ) -> (Vec>, &mut Scratch) { - let mut scratch: &mut Scratch = scratch; - let mut slice: Vec> = Vec::with_capacity(len); - for _ in 0..len { - let (znx, new_scratch) = B::take_vec_znx_impl(scratch, n, cols, size); - scratch = new_scratch; - slice.push(znx); - } - (slice, scratch) - } -} - -unsafe impl TakeVmpPMatImpl for FFT64Ref -where - B: VmpPMatAllocBytesImpl + ScratchFromBytesImpl, -{ - fn take_vmp_pmat_impl( - scratch: &mut Scratch, - n: usize, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - ) -> (VmpPMat<&mut [u8], B>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned( - &mut scratch.data, - B::vmp_pmat_alloc_bytes_impl(n, rows, cols_in, cols_out, size), - ); - ( - VmpPMat::from_data(take_slice, n, rows, cols_in, cols_out, size), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeMatZnxImpl for FFT64Ref -where - B: ScratchFromBytesImpl, -{ - fn take_mat_znx_impl( - scratch: &mut Scratch, - n: usize, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - ) -> (MatZnx<&mut [u8]>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned( - &mut scratch.data, - MatZnx::alloc_bytes(n, rows, cols_in, cols_out, size), - ); - ( - MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size), - Scratch::from_bytes(rem_slice), - ) - } -} - fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) { let ptr: *mut u8 = data.as_mut_ptr(); let self_len: usize = data.len(); diff --git a/poulpy-backend/src/cpu_fft64_ref/svp.rs b/poulpy-backend/src/cpu_fft64_ref/svp.rs index 06dad9e..37ca14a 100644 --- a/poulpy-backend/src/cpu_fft64_ref/svp.rs +++ b/poulpy-backend/src/cpu_fft64_ref/svp.rs @@ -22,7 +22,7 @@ unsafe impl SvpPPolAllocImpl for FFT64Ref { } unsafe impl SvpPPolAllocBytesImpl for FFT64Ref { - fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize { + fn svp_ppol_bytes_of_impl(n: usize, cols: usize) -> usize { Self::layout_prep_word_count() * n * cols * size_of::() } } diff --git a/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs b/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs index bb75c8f..348c9b6 100644 --- a/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs +++ b/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs @@ -27,7 +27,7 @@ use poulpy_hal::{ }; unsafe impl VecZnxBigAllocBytesImpl for FFT64Ref { - fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { + fn vec_znx_big_bytes_of_impl(n: usize, cols: usize, size: usize) -> usize { Self::layout_big_word_count() * n * cols * size * size_of::() } } diff --git a/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs b/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs index a08b728..a2a743d 100644 --- a/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs +++ b/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs @@ -24,7 +24,7 @@ unsafe impl VecZnxDftFromBytesImpl for FFT64Ref { } unsafe impl VecZnxDftAllocBytesImpl for FFT64Ref { - fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { + fn vec_znx_dft_bytes_of_impl(n: usize, cols: usize, size: usize) -> usize { Self::layout_prep_word_count() * n * cols * size * size_of::<::ScalarPrep>() } } diff --git a/poulpy-backend/src/cpu_fft64_ref/vmp.rs b/poulpy-backend/src/cpu_fft64_ref/vmp.rs index 2286de5..34cbf07 100644 --- a/poulpy-backend/src/cpu_fft64_ref/vmp.rs +++ b/poulpy-backend/src/cpu_fft64_ref/vmp.rs @@ -16,7 +16,7 @@ use poulpy_hal::{ use crate::cpu_fft64_ref::{FFT64Ref, module::FFT64ModuleHandle}; unsafe impl VmpPMatAllocBytesImpl for FFT64Ref { - fn vmp_pmat_alloc_bytes_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + fn vmp_pmat_bytes_of_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { Self::layout_prep_word_count() * n * rows * cols_in * cols_out * size * size_of::() } } diff --git a/poulpy-backend/src/cpu_spqlios/fft64/scratch.rs b/poulpy-backend/src/cpu_spqlios/fft64/scratch.rs index 9bddcb3..d32b9f4 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/scratch.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/scratch.rs @@ -3,13 +3,8 @@ use std::marker::PhantomData; use poulpy_hal::{ DEFAULTALIGN, alloc_aligned, api::ScratchFromBytes, - layouts::{Backend, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}, - oep::{ - ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, SvpPPolAllocBytesImpl, - TakeMatZnxImpl, TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl, - TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl, VecZnxBigAllocBytesImpl, - VecZnxDftAllocBytesImpl, VmpPMatAllocBytesImpl, - }, + layouts::{Backend, Scratch, ScratchOwned}, + oep::{ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeSliceImpl}, }; use crate::cpu_spqlios::FFT64Spqlios; @@ -64,178 +59,6 @@ where } } -unsafe impl TakeScalarZnxImpl for FFT64Spqlios -where - B: ScratchFromBytesImpl, -{ - fn take_scalar_znx_impl(scratch: &mut Scratch, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, ScalarZnx::alloc_bytes(n, cols)); - ( - ScalarZnx::from_data(take_slice, n, cols), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeSvpPPolImpl for FFT64Spqlios -where - B: SvpPPolAllocBytesImpl + ScratchFromBytesImpl, -{ - fn take_svp_ppol_impl(scratch: &mut Scratch, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, B::svp_ppol_alloc_bytes_impl(n, cols)); - ( - SvpPPol::from_data(take_slice, n, cols), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeVecZnxImpl for FFT64Spqlios -where - B: ScratchFromBytesImpl, -{ - fn take_vec_znx_impl(scratch: &mut Scratch, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, VecZnx::alloc_bytes(n, cols, size)); - ( - VecZnx::from_data(take_slice, n, cols, size), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeVecZnxBigImpl for FFT64Spqlios -where - B: VecZnxBigAllocBytesImpl + ScratchFromBytesImpl, -{ - fn take_vec_znx_big_impl( - scratch: &mut Scratch, - n: usize, - cols: usize, - size: usize, - ) -> (VecZnxBig<&mut [u8], B>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned( - &mut scratch.data, - B::vec_znx_big_alloc_bytes_impl(n, cols, size), - ); - ( - VecZnxBig::from_data(take_slice, n, cols, size), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeVecZnxDftImpl for FFT64Spqlios -where - B: VecZnxDftAllocBytesImpl + ScratchFromBytesImpl, -{ - fn take_vec_znx_dft_impl( - scratch: &mut Scratch, - n: usize, - cols: usize, - size: usize, - ) -> (VecZnxDft<&mut [u8], B>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned( - &mut scratch.data, - B::vec_znx_dft_alloc_bytes_impl(n, cols, size), - ); - - ( - VecZnxDft::from_data(take_slice, n, cols, size), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeVecZnxDftSliceImpl for FFT64Spqlios -where - B: VecZnxDftAllocBytesImpl + ScratchFromBytesImpl + TakeVecZnxDftImpl, -{ - fn take_vec_znx_dft_slice_impl( - scratch: &mut Scratch, - len: usize, - n: usize, - cols: usize, - size: usize, - ) -> (Vec>, &mut Scratch) { - let mut scratch: &mut Scratch = scratch; - let mut slice: Vec> = Vec::with_capacity(len); - for _ in 0..len { - let (znx, new_scratch) = B::take_vec_znx_dft_impl(scratch, n, cols, size); - scratch = new_scratch; - slice.push(znx); - } - (slice, scratch) - } -} - -unsafe impl TakeVecZnxSliceImpl for FFT64Spqlios -where - B: ScratchFromBytesImpl + TakeVecZnxImpl, -{ - fn take_vec_znx_slice_impl( - scratch: &mut Scratch, - len: usize, - n: usize, - cols: usize, - size: usize, - ) -> (Vec>, &mut Scratch) { - let mut scratch: &mut Scratch = scratch; - let mut slice: Vec> = Vec::with_capacity(len); - for _ in 0..len { - let (znx, new_scratch) = B::take_vec_znx_impl(scratch, n, cols, size); - scratch = new_scratch; - slice.push(znx); - } - (slice, scratch) - } -} - -unsafe impl TakeVmpPMatImpl for FFT64Spqlios -where - B: VmpPMatAllocBytesImpl + ScratchFromBytesImpl, -{ - fn take_vmp_pmat_impl( - scratch: &mut Scratch, - n: usize, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - ) -> (VmpPMat<&mut [u8], B>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned( - &mut scratch.data, - B::vmp_pmat_alloc_bytes_impl(n, rows, cols_in, cols_out, size), - ); - ( - VmpPMat::from_data(take_slice, n, rows, cols_in, cols_out, size), - Scratch::from_bytes(rem_slice), - ) - } -} - -unsafe impl TakeMatZnxImpl for FFT64Spqlios -where - B: ScratchFromBytesImpl, -{ - fn take_mat_znx_impl( - scratch: &mut Scratch, - n: usize, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - ) -> (MatZnx<&mut [u8]>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned( - &mut scratch.data, - MatZnx::alloc_bytes(n, rows, cols_in, cols_out, size), - ); - ( - MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size), - Scratch::from_bytes(rem_slice), - ) - } -} - fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) { let ptr: *mut u8 = data.as_mut_ptr(); let self_len: usize = data.len(); diff --git a/poulpy-backend/src/cpu_spqlios/fft64/svp_ppol.rs b/poulpy-backend/src/cpu_spqlios/fft64/svp_ppol.rs index b917400..f46b795 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/svp_ppol.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/svp_ppol.rs @@ -27,7 +27,7 @@ unsafe impl SvpPPolAllocImpl for FFT64Spqlios { } unsafe impl SvpPPolAllocBytesImpl for FFT64Spqlios { - fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize { + fn svp_ppol_bytes_of_impl(n: usize, cols: usize) -> usize { FFT64Spqlios::layout_prep_word_count() * n * cols * size_of::() } } diff --git a/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_big.rs b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_big.rs index 8becaf6..5021f6b 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_big.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_big.rs @@ -22,7 +22,7 @@ use poulpy_hal::{ }; unsafe impl VecZnxBigAllocBytesImpl for FFT64Spqlios { - fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { + fn vec_znx_big_bytes_of_impl(n: usize, cols: usize, size: usize) -> usize { Self::layout_big_word_count() * n * cols * size * size_of::() } } diff --git a/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_dft.rs b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_dft.rs index 461d327..3b67089 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_dft.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_dft.rs @@ -30,7 +30,7 @@ unsafe impl VecZnxDftFromBytesImpl for FFT64Spqlios { } unsafe impl VecZnxDftAllocBytesImpl for FFT64Spqlios { - fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { + fn vec_znx_dft_bytes_of_impl(n: usize, cols: usize, size: usize) -> usize { Self::layout_prep_word_count() * n * cols * size * size_of::<::ScalarPrep>() } } diff --git a/poulpy-backend/src/cpu_spqlios/fft64/vmp_pmat.rs b/poulpy-backend/src/cpu_spqlios/fft64/vmp_pmat.rs index ca64992..ff1eaa2 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/vmp_pmat.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/vmp_pmat.rs @@ -16,7 +16,7 @@ use crate::cpu_spqlios::{ }; unsafe impl VmpPMatAllocBytesImpl for FFT64Spqlios { - fn vmp_pmat_alloc_bytes_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + fn vmp_pmat_bytes_of_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { Self::layout_prep_word_count() * n * rows * cols_in * cols_out * size * size_of::() } } diff --git a/poulpy-backend/src/cpu_spqlios/ntt120/svp_ppol.rs b/poulpy-backend/src/cpu_spqlios/ntt120/svp_ppol.rs index c98237a..8c7fdcc 100644 --- a/poulpy-backend/src/cpu_spqlios/ntt120/svp_ppol.rs +++ b/poulpy-backend/src/cpu_spqlios/ntt120/svp_ppol.rs @@ -18,7 +18,7 @@ unsafe impl SvpPPolAllocImpl for NTT120 { } unsafe impl SvpPPolAllocBytesImpl for NTT120 { - fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize { + fn svp_ppol_bytes_of_impl(n: usize, cols: usize) -> usize { NTT120::layout_prep_word_count() * n * cols * size_of::() } } diff --git a/poulpy-backend/src/cpu_spqlios/ntt120/vec_znx_big.rs b/poulpy-backend/src/cpu_spqlios/ntt120/vec_znx_big.rs index 715b432..58ddf78 100644 --- a/poulpy-backend/src/cpu_spqlios/ntt120/vec_znx_big.rs +++ b/poulpy-backend/src/cpu_spqlios/ntt120/vec_znx_big.rs @@ -3,7 +3,7 @@ use poulpy_hal::{layouts::Backend, oep::VecZnxBigAllocBytesImpl}; use crate::cpu_spqlios::NTT120; unsafe impl VecZnxBigAllocBytesImpl for NTT120 { - fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { + fn vec_znx_big_bytes_of_impl(n: usize, cols: usize, size: usize) -> usize { NTT120::layout_big_word_count() * n * cols * size * size_of::() } } diff --git a/poulpy-backend/src/cpu_spqlios/ntt120/vec_znx_dft.rs b/poulpy-backend/src/cpu_spqlios/ntt120/vec_znx_dft.rs index 53dd24f..9e1666b 100644 --- a/poulpy-backend/src/cpu_spqlios/ntt120/vec_znx_dft.rs +++ b/poulpy-backend/src/cpu_spqlios/ntt120/vec_znx_dft.rs @@ -6,7 +6,7 @@ use poulpy_hal::{ use crate::cpu_spqlios::NTT120; unsafe impl VecZnxDftAllocBytesImpl for NTT120 { - fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { + fn vec_znx_dft_bytes_of_impl(n: usize, cols: usize, size: usize) -> usize { NTT120::layout_prep_word_count() * n * cols * size * size_of::() } } diff --git a/poulpy-core/README.md b/poulpy-core/README.md index 07d5304..259988e 100644 --- a/poulpy-core/README.md +++ b/poulpy-core/README.md @@ -52,8 +52,8 @@ fn main() { // Scratch space let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::encrypt_sk_scratch_space(&module, n, base2k, ct.k()) - | GLWECiphertext::decrypt_scratch_space(&module, n, base2k, ct.k()), + GLWECiphertext::encrypt_sk_tmp_bytes(&module, n, base2k, ct.k()) + | GLWECiphertext::decrypt_tmp_bytes(&module, n, base2k, ct.k()), ); // Generate secret-key diff --git a/poulpy-core/benches/external_product_glwe_fft64.rs b/poulpy-core/benches/external_product_glwe_fft64.rs index 360a9f4..900eda1 100644 --- a/poulpy-core/benches/external_product_glwe_fft64.rs +++ b/poulpy-core/benches/external_product_glwe_fft64.rs @@ -1,7 +1,6 @@ use poulpy_core::layouts::{ - Base2K, Degree, Dnum, Dsize, GGSWCiphertext, GGSWCiphertextLayout, GLWECiphertext, GLWECiphertextLayout, GLWESecret, Rank, - TorusPrecision, - prepared::{GGSWCiphertextPrepared, GLWESecretPrepared, PrepareAlloc}, + Base2K, Degree, Dnum, Dsize, GGSW, GGSWLayout, GLWE, GLWELayout, GLWESecret, Rank, TorusPrecision, + prepared::{GGSWPrepared, GLWESecretPrepared}, }; use std::hint::black_box; @@ -39,7 +38,7 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { let dnum: Dnum = Dnum(1); //(p.k_ct_in.div_ceil(p.base2k); - let ggsw_layout: GGSWCiphertextLayout = GGSWCiphertextLayout { + let ggsw_layout: GGSWLayout = GGSWLayout { n, base2k, k: k_ggsw, @@ -48,38 +47,40 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { rank, }; - let glwe_out_layout: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_out_layout: GLWELayout = GLWELayout { n, base2k, k: k_ct_out, rank, }; - let glwe_in_layout: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_in_layout: GLWELayout = GLWELayout { n, base2k, k: k_ct_in, rank, }; - let mut ct_ggsw: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_layout); - let mut ct_glwe_in: GLWECiphertext> = GLWECiphertext::alloc(&glwe_in_layout); - let mut ct_glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&glwe_out_layout); + let mut ct_ggsw: GGSW> = GGSW::alloc_from_infos(&ggsw_layout); + let mut ct_glwe_in: GLWE> = GLWE::alloc_from_infos(&glwe_in_layout); + let mut ct_glwe_out: GLWE> = GLWE::alloc_from_infos(&glwe_out_layout); let pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n.into(), 1); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(&module, &ggsw_layout) - | GLWECiphertext::encrypt_sk_scratch_space(&module, &glwe_in_layout) - | GLWECiphertext::external_product_scratch_space(&module, &glwe_out_layout, &glwe_in_layout, &ggsw_layout), + GGSW::encrypt_sk_tmp_bytes(&module, &ggsw_layout) + | GLWE::encrypt_sk_tmp_bytes(&module, &glwe_in_layout) + | GLWE::external_product_tmp_bytes(&module, &glwe_out_layout, &glwe_in_layout, &ggsw_layout), ); let mut source_xs = Source::new([0u8; 32]); let mut source_xe = Source::new([0u8; 32]); let mut source_xa = Source::new([0u8; 32]); - let mut sk: GLWESecret> = GLWESecret::alloc(&glwe_in_layout); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_in_layout); sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_dft: GLWESecretPrepared, FFT64Spqlios> = sk.prepare_alloc(&module, scratch.borrow()); + + let mut sk_dft: GLWESecretPrepared, FFT64Spqlios> = GLWESecretPrepared::alloc(&module, rank); + sk_dft.prepare(&module, &sk); ct_ggsw.encrypt_sk( &module, @@ -98,7 +99,8 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { scratch.borrow(), ); - let ggsw_prepared: GGSWCiphertextPrepared, FFT64Spqlios> = ct_ggsw.prepare_alloc(&module, scratch.borrow()); + let mut ggsw_prepared: GGSWPrepared, FFT64Spqlios> = GGSWPrepared::alloc_from_infos(&module, &ct_ggsw); + ggsw_prepared.prepare(&module, &ct_ggsw, scratch.borrow()); move || { ct_glwe_out.external_product(&module, &ct_glwe_in, &ggsw_prepared, scratch.borrow()); @@ -147,7 +149,7 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { let dnum: Dnum = p.k_ct.div_ceil(p.base2k).into(); - let ggsw_layout: GGSWCiphertextLayout = GGSWCiphertextLayout { + let ggsw_layout: GGSWLayout = GGSWLayout { n, base2k, k: k_ggsw, @@ -156,30 +158,32 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { rank, }; - let glwe_layout: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_layout: GLWELayout = GLWELayout { n, base2k, k: k_glwe, rank, }; - let mut ct_ggsw: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_layout); - let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&glwe_layout); + let mut ct_ggsw: GGSW> = GGSW::alloc_from_infos(&ggsw_layout); + let mut ct_glwe: GLWE> = GLWE::alloc_from_infos(&glwe_layout); let pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n.into(), 1); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(&module, &ggsw_layout) - | GLWECiphertext::encrypt_sk_scratch_space(&module, &glwe_layout) - | GLWECiphertext::external_product_inplace_scratch_space(&module, &glwe_layout, &ggsw_layout), + GGSW::encrypt_sk_tmp_bytes(&module, &ggsw_layout) + | GLWE::encrypt_sk_tmp_bytes(&module, &glwe_layout) + | GLWE::external_product_tmp_bytes(&module, &glwe_layout, &glwe_layout, &ggsw_layout), ); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut sk: GLWESecret> = GLWESecret::alloc(&glwe_layout); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_layout); sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_dft: GLWESecretPrepared, FFT64Spqlios> = sk.prepare_alloc(&module, scratch.borrow()); + + let mut sk_dft: GLWESecretPrepared, FFT64Spqlios> = GLWESecretPrepared::alloc(&module, rank); + sk_dft.prepare(&module, &sk); ct_ggsw.encrypt_sk( &module, @@ -198,8 +202,8 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { scratch.borrow(), ); - let ggsw_prepared: GGSWCiphertextPrepared, FFT64Spqlios> = ct_ggsw.prepare_alloc(&module, scratch.borrow()); - + let mut ggsw_prepared: GGSWPrepared, FFT64Spqlios> = GGSWPrepared::alloc_from_infos(&module, &ct_ggsw); + ggsw_prepared.prepare(&module, &ct_ggsw, scratch.borrow()); move || { let scratch_borrow = scratch.borrow(); ct_glwe.external_product_inplace(&module, &ggsw_prepared, scratch_borrow); diff --git a/poulpy-core/benches/keyswitch_glwe_fft64.rs b/poulpy-core/benches/keyswitch_glwe_fft64.rs index 2da2032..e0ca001 100644 --- a/poulpy-core/benches/keyswitch_glwe_fft64.rs +++ b/poulpy-core/benches/keyswitch_glwe_fft64.rs @@ -1,7 +1,7 @@ use poulpy_core::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEAutomorphismKey, GGLWEAutomorphismKeyLayout, GGLWESwitchingKey, GGLWESwitchingKeyLayout, - GLWECiphertext, GLWECiphertextLayout, GLWESecret, Rank, TorusPrecision, - prepared::{GGLWEAutomorphismKeyPrepared, GGLWESwitchingKeyPrepared, GLWESecretPrepared, PrepareAlloc}, + Base2K, Degree, Dnum, Dsize, GLWE, GLWEAutomorphismKey, GLWEAutomorphismKeyLayout, GLWELayout, GLWESecret, GLWESwitchingKey, + GLWESwitchingKeyLayout, GLWESwitchingKeyPrepared, Rank, TorusPrecision, + prepared::{GLWEAutomorphismKeyPrepared, GLWESecretPrepared}, }; use std::{hint::black_box, time::Duration}; @@ -39,7 +39,7 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { let dnum: Dnum = p.k_ct_in.div_ceil(p.base2k.0 * dsize.0).into(); - let gglwe_atk_layout: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + let gglwe_atk_layout: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n, base2k, k: k_gglwe, @@ -48,28 +48,28 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { dsize, }; - let glwe_in_layout: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_in_layout: GLWELayout = GLWELayout { n, base2k, k: k_glwe_in, rank, }; - let glwe_out_layout: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_out_layout: GLWELayout = GLWELayout { n, base2k, k: k_glwe_out, rank, }; - let mut ksk: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&gglwe_atk_layout); - let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&glwe_in_layout); - let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&glwe_out_layout); + let mut ksk: GLWEAutomorphismKey> = GLWEAutomorphismKey::alloc_from_infos(&gglwe_atk_layout); + let mut ct_in: GLWE> = GLWE::alloc_from_infos(&glwe_in_layout); + let mut ct_out: GLWE> = GLWE::alloc_from_infos(&glwe_out_layout); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWESwitchingKey::encrypt_sk_scratch_space(&module, &gglwe_atk_layout) - | GLWECiphertext::encrypt_sk_scratch_space(&module, &glwe_in_layout) - | GLWECiphertext::keyswitch_scratch_space( + GLWESwitchingKey::encrypt_sk_tmp_bytes(&module, &gglwe_atk_layout) + | GLWE::encrypt_sk_tmp_bytes(&module, &glwe_in_layout) + | GLWE::keyswitch_tmp_bytes( &module, &glwe_out_layout, &glwe_in_layout, @@ -81,9 +81,11 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut sk_in: GLWESecret> = GLWESecret::alloc(&glwe_in_layout); + let mut sk_in: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_in_layout); sk_in.fill_ternary_prob(0.5, &mut source_xs); - let sk_in_dft: GLWESecretPrepared, FFT64Spqlios> = sk_in.prepare_alloc(&module, scratch.borrow()); + + let mut sk_in_dft: GLWESecretPrepared, FFT64Spqlios> = GLWESecretPrepared::alloc(&module, rank); + sk_in_dft.prepare(&module, &sk_in); ksk.encrypt_sk( &module, @@ -102,7 +104,9 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { scratch.borrow(), ); - let ksk_prepared: GGLWEAutomorphismKeyPrepared, _> = ksk.prepare_alloc(&module, scratch.borrow()); + let mut ksk_prepared: GLWEAutomorphismKeyPrepared, _> = + GLWEAutomorphismKeyPrepared::alloc_from_infos(&module, &ksk); + ksk_prepared.prepare(&module, &ksk, scratch.borrow()); move || { ct_out.automorphism(&module, &ct_in, &ksk_prepared, scratch.borrow()); @@ -157,7 +161,7 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { let dnum: Dnum = p.k_ct.div_ceil(p.base2k).into(); - let gglwe_layout: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + let gglwe_layout: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n, base2k, k: k_ksk, @@ -167,31 +171,33 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { rank_out: rank, }; - let glwe_layout: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_layout: GLWELayout = GLWELayout { n, base2k, k: k_ct, rank, }; - let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_layout); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&glwe_layout); + let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&gglwe_layout); + let mut ct: GLWE> = GLWE::alloc_from_infos(&glwe_layout); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWESwitchingKey::encrypt_sk_scratch_space(&module, &gglwe_layout) - | GLWECiphertext::encrypt_sk_scratch_space(&module, &glwe_layout) - | GLWECiphertext::keyswitch_inplace_scratch_space(&module, &glwe_layout, &gglwe_layout), + GLWESwitchingKey::encrypt_sk_tmp_bytes(&module, &gglwe_layout) + | GLWE::encrypt_sk_tmp_bytes(&module, &glwe_layout) + | GLWE::keyswitch_tmp_bytes(&module, &glwe_layout, &glwe_layout, &gglwe_layout), ); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut sk_in: GLWESecret> = GLWESecret::alloc(&glwe_layout); + let mut sk_in: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_layout); sk_in.fill_ternary_prob(0.5, &mut source_xs); - let sk_in_dft: GLWESecretPrepared, FFT64Spqlios> = sk_in.prepare_alloc(&module, scratch.borrow()); - let mut sk_out: GLWESecret> = GLWESecret::alloc(&glwe_layout); + let mut sk_in_dft: GLWESecretPrepared, FFT64Spqlios> = GLWESecretPrepared::alloc(&module, rank); + sk_in_dft.prepare(&module, &sk_in); + + let mut sk_out: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_layout); sk_out.fill_ternary_prob(0.5, &mut source_xs); ksk.encrypt_sk( @@ -211,7 +217,8 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { scratch.borrow(), ); - let ksk_prepared: GGLWESwitchingKeyPrepared, FFT64Spqlios> = ksk.prepare_alloc(&module, scratch.borrow()); + let mut ksk_prepared: GLWESwitchingKeyPrepared, _> = GLWESwitchingKeyPrepared::alloc_from_infos(&module, &ksk); + ksk_prepared.prepare(&module, &ksk, scratch.borrow()); move || { ct.keyswitch_inplace(&module, &ksk_prepared, scratch.borrow()); diff --git a/poulpy-core/examples/encryption.rs b/poulpy-core/examples/encryption.rs index a65b473..5ee8c8e 100644 --- a/poulpy-core/examples/encryption.rs +++ b/poulpy-core/examples/encryption.rs @@ -1,10 +1,9 @@ use poulpy_backend::cpu_spqlios::FFT64Spqlios; use poulpy_core::{ - GLWEOperations, SIGMA, + GLWESub, SIGMA, layouts::{ - Base2K, Degree, GLWECiphertext, GLWECiphertextLayout, GLWEPlaintext, GLWEPlaintextLayout, GLWESecret, LWEInfos, Rank, - TorusPrecision, - prepared::{GLWESecretPrepared, PrepareAlloc}, + Base2K, Degree, GLWE, GLWELayout, GLWEPlaintext, GLWEPlaintextLayout, GLWESecret, LWEInfos, Rank, TorusPrecision, + prepared::GLWESecretPrepared, }, }; use poulpy_hal::{ @@ -34,7 +33,7 @@ fn main() { // Instantiate Module (DFT Tables) let module: Module = Module::::new(n.0 as u64); - let glwe_ct_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_ct_infos: GLWELayout = GLWELayout { n, base2k, k: k_ct, @@ -44,9 +43,9 @@ fn main() { let glwe_pt_infos: GLWEPlaintextLayout = GLWEPlaintextLayout { n, base2k, k: k_pt }; // Allocates ciphertext & plaintexts - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&glwe_ct_infos); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_pt_infos); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_pt_infos); + let mut ct: GLWE> = GLWE::alloc_from_infos(&glwe_ct_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_pt_infos); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_pt_infos); // CPRNG let mut source_xs: Source = Source::new([0u8; 32]); @@ -55,16 +54,16 @@ fn main() { // Scratch space let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::encrypt_sk_scratch_space(&module, &glwe_ct_infos) - | GLWECiphertext::decrypt_scratch_space(&module, &glwe_ct_infos), + GLWE::encrypt_sk_tmp_bytes(&module, &glwe_ct_infos) | GLWE::decrypt_tmp_bytes(&module, &glwe_ct_infos), ); // Generate secret-key - let mut sk: GLWESecret> = GLWESecret::alloc(&glwe_ct_infos); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_ct_infos); sk.fill_ternary_prob(0.5, &mut source_xs); // Backend-prepared secret - let sk_prepared: GLWESecretPrepared, FFT64Spqlios> = sk.prepare_alloc(&module, scratch.borrow()); + let mut sk_prepared: GLWESecretPrepared, FFT64Spqlios> = GLWESecretPrepared::alloc(&module, rank); + sk_prepared.prepare(&module, &sk); // Uniform plaintext module.vec_znx_fill_uniform(base2k.into(), &mut pt_want.data, 0, &mut source_xa); @@ -83,7 +82,7 @@ fn main() { ct.decrypt(&module, &mut pt_have, &sk_prepared, scratch.borrow()); // Diff between pt - Dec(Enc(pt)) - pt_want.sub_inplace_ab(&module, &pt_have); + module.glwe_sub_inplace(&mut pt_want, &pt_have); // Ideal vs. actual noise let noise_have: f64 = pt_want.data.std(base2k.into(), 0) * (ct.k().as_u32() as f64).exp2(); diff --git a/poulpy-core/src/automorphism/gglwe_atk.rs b/poulpy-core/src/automorphism/gglwe_atk.rs index 9b08e68..ffc35cd 100644 --- a/poulpy-core/src/automorphism/gglwe_atk.rs +++ b/poulpy-core/src/automorphism/gglwe_atk.rs @@ -1,198 +1,165 @@ use poulpy_hal::{ - api::{ - ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, - VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, - VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, - }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero}, + api::VecZnxAutomorphism, + layouts::{Backend, DataMut, GaloisElement, Module, Scratch}, }; -use crate::layouts::{GGLWEAutomorphismKey, GGLWEInfos, GLWECiphertext, prepared::GGLWEAutomorphismKeyPrepared}; +use crate::{ + ScratchTakeCore, + automorphism::glwe_ct::GLWEAutomorphism, + layouts::{ + GGLWE, GGLWEInfos, GGLWEPreparedToRef, GGLWEToMut, GGLWEToRef, GLWE, GLWEAutomorphismKey, GetGaloisElement, + SetGaloisElement, + }, +}; -impl GGLWEAutomorphismKey> { - pub fn automorphism_scratch_space( - module: &Module, - out_infos: &OUT, - in_infos: &IN, - key_infos: &KEY, - ) -> usize +impl GLWEAutomorphismKey> { + pub fn automorphism_tmp_bytes(module: &M, res_infos: &R, a_infos: &A, key_infos: &K) -> usize where - OUT: GGLWEInfos, - IN: GGLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + R: GGLWEInfos, + A: GGLWEInfos, + K: GGLWEInfos, + M: GLWEAutomorphismKeyAutomorphism, { - GLWECiphertext::keyswitch_scratch_space( - module, - &out_infos.glwe_layout(), - &in_infos.glwe_layout(), - key_infos, - ) - } - - pub fn automorphism_inplace_scratch_space(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize - where - OUT: GGLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, - { - GGLWEAutomorphismKey::automorphism_scratch_space(module, out_infos, out_infos, key_infos) + module.glwe_automorphism_key_automorphism_tmp_bytes(res_infos, a_infos, key_infos) } } -impl GGLWEAutomorphismKey { - pub fn automorphism( - &mut self, - module: &Module, - lhs: &GGLWEAutomorphismKey, - rhs: &GGLWEAutomorphismKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxAutomorphism - + VecZnxAutomorphismInplace - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnx, +impl GLWEAutomorphismKey { + pub fn automorphism(&mut self, module: &M, a: &A, key: &K, scratch: &mut Scratch) + where + A: GGLWEToRef + GetGaloisElement + GGLWEInfos, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + Scratch: ScratchTakeCore, + M: GLWEAutomorphismKeyAutomorphism, { - #[cfg(debug_assertions)] - { - use crate::layouts::LWEInfos; - - assert_eq!( - self.rank_in(), - lhs.rank_in(), - "ksk_out input rank: {} != ksk_in input rank: {}", - self.rank_in(), - lhs.rank_in() - ); - assert_eq!( - self.rank_out(), - rhs.rank_in(), - "ksk_in output rank: {} != ksk_apply input rank: {}", - self.rank_out(), - rhs.rank_in() - ); - assert_eq!( - self.rank_out(), - rhs.rank_out(), - "ksk_out output rank: {} != ksk_apply output rank: {}", - self.rank_out(), - rhs.rank_out() - ); - assert!( - self.k() <= lhs.k(), - "output k={} cannot be greater than input k={}", - self.k(), - lhs.k() - ) - } - - let cols_out: usize = (rhs.rank_out() + 1).into(); - - let p: i64 = lhs.p(); - let p_inv: i64 = module.galois_element_inv(p); - - (0..self.rank_in().into()).for_each(|col_i| { - (0..self.dnum().into()).for_each(|row_j| { - let mut res_ct: GLWECiphertext<&mut [u8]> = self.at_mut(row_j, col_i); - let lhs_ct: GLWECiphertext<&[u8]> = lhs.at(row_j, col_i); - - // Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) - (0..cols_out).for_each(|i| { - module.vec_znx_automorphism(lhs.p(), &mut res_ct.data, i, &lhs_ct.data, i); - }); - - // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) - res_ct.keyswitch_inplace(module, &rhs.key, scratch); - - // Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a) - (0..cols_out).for_each(|i| { - module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i, scratch); - }); - }); - }); - - (self.dnum().min(lhs.dnum()).into()..self.dnum().into()).for_each(|row_i| { - (0..self.rank_in().into()).for_each(|col_j| { - self.at_mut(row_i, col_j).data.zero(); - }); - }); - - self.p = (lhs.p * rhs.p) % (module.cyclotomic_order() as i64); + module.glwe_automorphism_key_automorphism(self, a, key, scratch); } - pub fn automorphism_inplace( - &mut self, - module: &Module, - rhs: &GGLWEAutomorphismKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxAutomorphism - + VecZnxAutomorphismInplace - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnx, + pub fn automorphism_inplace(&mut self, module: &M, key: &K, scratch: &mut Scratch) + where + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + Scratch: ScratchTakeCore, + M: GLWEAutomorphismKeyAutomorphism, { - #[cfg(debug_assertions)] - { - assert_eq!( - self.rank_out(), - rhs.rank_in(), - "ksk_in output rank: {} != ksk_apply input rank: {}", - self.rank_out(), - rhs.rank_in() - ); - assert_eq!( - self.rank_out(), - rhs.rank_out(), - "ksk_out output rank: {} != ksk_apply output rank: {}", - self.rank_out(), - rhs.rank_out() - ); - } - - let cols_out: usize = (rhs.rank_out() + 1).into(); - - let p: i64 = self.p(); - let p_inv = module.galois_element_inv(p); - - (0..self.rank_in().into()).for_each(|col_i| { - (0..self.dnum().into()).for_each(|row_j| { - let mut res_ct: GLWECiphertext<&mut [u8]> = self.at_mut(row_j, col_i); - - // Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) - (0..cols_out).for_each(|i| { - module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i, scratch); - }); - - // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) - res_ct.keyswitch_inplace(module, &rhs.key, scratch); - - // Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a) - (0..cols_out).for_each(|i| { - module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i, scratch); - }); - }); - }); - - self.p = (self.p * rhs.p) % (module.cyclotomic_order() as i64); + module.glwe_automorphism_key_automorphism_inplace(self, key, scratch); + } +} + +impl GLWEAutomorphismKeyAutomorphism for Module where + Self: GaloisElement + GLWEAutomorphism + VecZnxAutomorphism +{ +} + +pub trait GLWEAutomorphismKeyAutomorphism +where + Self: GaloisElement + GLWEAutomorphism + VecZnxAutomorphism, +{ + fn glwe_automorphism_key_automorphism_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize + where + R: GGLWEInfos, + A: GGLWEInfos, + K: GGLWEInfos, + { + self.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) + } + + fn glwe_automorphism_key_automorphism(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) + where + R: GGLWEToMut + SetGaloisElement + GGLWEInfos, + A: GGLWEToRef + GetGaloisElement + GGLWEInfos, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + Scratch: ScratchTakeCore, + { + assert!( + res.dnum().as_u32() <= a.dnum().as_u32(), + "res dnum: {} > a dnum: {}", + res.dnum(), + a.dnum() + ); + + assert_eq!( + res.dsize(), + a.dsize(), + "res dnum: {} != a dnum: {}", + res.dsize(), + a.dsize() + ); + + let cols_out: usize = (key.rank_out() + 1).into(); + let cols_in: usize = key.rank_in().into(); + + let p: i64 = a.p(); + let p_inv: i64 = self.galois_element_inv(p); + + { + let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GGLWE<&[u8]> = &a.to_ref(); + + for row in 0..res.dnum().as_usize() { + for col in 0..cols_in { + let mut res_tmp: GLWE<&mut [u8]> = res.at_mut(row, col); + let a_ct: GLWE<&[u8]> = a.at(row, col); + + // Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) + for i in 0..cols_out { + self.vec_znx_automorphism(p, res_tmp.data_mut(), i, &a_ct.data, i); + } + + // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) + self.glwe_keyswitch_inplace(&mut res_tmp, key, scratch); + + // Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a) + (0..cols_out).for_each(|i| { + self.vec_znx_automorphism_inplace(p_inv, res_tmp.data_mut(), i, scratch); + }); + } + } + } + + res.set_p((p * key.p()) % self.cyclotomic_order()); + } + + fn glwe_automorphism_key_automorphism_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) + where + R: GGLWEToMut + SetGaloisElement + GetGaloisElement + GGLWEInfos, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + Scratch: ScratchTakeCore, + { + assert_eq!( + res.rank(), + key.rank(), + "key rank: {} != key rank: {}", + res.rank(), + key.rank() + ); + + let cols_out: usize = (key.rank_out() + 1).into(); + let cols_in: usize = key.rank_in().into(); + let p: i64 = res.p(); + let p_inv: i64 = self.galois_element_inv(p); + + { + let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut(); + for row in 0..res.dnum().as_usize() { + for col in 0..cols_in { + let mut res_tmp: GLWE<&mut [u8]> = res.at_mut(row, col); + + // Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) + for i in 0..cols_out { + self.vec_znx_automorphism_inplace(p, res_tmp.data_mut(), i, scratch); + } + + // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) + self.glwe_keyswitch_inplace(&mut res_tmp, key, scratch); + + // Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a) + for i in 0..cols_out { + self.vec_znx_automorphism_inplace(p_inv, res_tmp.data_mut(), i, scratch); + } + } + } + } + + res.set_p((res.p() * key.p()) % self.cyclotomic_order()); } } diff --git a/poulpy-core/src/automorphism/ggsw_ct.rs b/poulpy-core/src/automorphism/ggsw_ct.rs index eef3082..fb54f6d 100644 --- a/poulpy-core/src/automorphism/ggsw_ct.rs +++ b/poulpy-core/src/automorphism/ggsw_ct.rs @@ -1,171 +1,124 @@ use poulpy_hal::{ - api::{ - ScratchAvailable, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAddInplace, VecZnxDftAllocBytes, - VecZnxDftApply, VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeTmpBytes, - VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + api::ScratchAvailable, + layouts::{Backend, DataMut, Module, Scratch}, +}; + +use crate::{ + GGSWExpandRows, ScratchTakeCore, + automorphism::glwe_ct::GLWEAutomorphism, + layouts::{ + GGLWEInfos, GGLWEPreparedToRef, GGSW, GGSWInfos, GGSWToMut, GGSWToRef, GetGaloisElement, + prepared::{GLWETensorKeyPrepared, GLWETensorKeyPreparedToRef}, }, - layouts::{Backend, DataMut, DataRef, Module, Scratch}, }; -use crate::layouts::{ - GGLWEInfos, GGSWCiphertext, GGSWInfos, GLWECiphertext, - prepared::{GGLWEAutomorphismKeyPrepared, GGLWETensorKeyPrepared}, -}; - -impl GGSWCiphertext> { - pub fn automorphism_scratch_space( - module: &Module, - out_infos: &OUT, - in_infos: &IN, - key_infos: &KEY, - tsk_infos: &TSK, +impl GGSW> { + pub fn automorphism_tmp_bytes( + module: &M, + res_infos: &R, + a_infos: &A, + key_infos: &K, + tsk_infos: &T, ) -> usize where - OUT: GGSWInfos, - IN: GGSWInfos, - KEY: GGLWEInfos, - TSK: GGLWEInfos, - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigAllocBytes - + VecZnxNormalizeTmpBytes - + VecZnxBigNormalizeTmpBytes, + R: GGSWInfos, + A: GGSWInfos, + K: GGLWEInfos, + T: GGLWEInfos, + M: GGSWAutomorphism, { - let out_size: usize = out_infos.size(); - let ci_dft: usize = module.vec_znx_dft_alloc_bytes((key_infos.rank_out() + 1).into(), out_size); - let ks_internal: usize = GLWECiphertext::keyswitch_scratch_space( - module, - &out_infos.glwe_layout(), - &in_infos.glwe_layout(), - key_infos, - ); - let expand: usize = GGSWCiphertext::expand_row_scratch_space(module, out_infos, tsk_infos); - ci_dft + (ks_internal | expand) - } - - pub fn automorphism_inplace_scratch_space( - module: &Module, - out_infos: &OUT, - key_infos: &KEY, - tsk_infos: &TSK, - ) -> usize - where - OUT: GGSWInfos, - KEY: GGLWEInfos, - TSK: GGLWEInfos, - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigAllocBytes - + VecZnxNormalizeTmpBytes - + VecZnxBigNormalizeTmpBytes, - { - GGSWCiphertext::automorphism_scratch_space(module, out_infos, out_infos, key_infos, tsk_infos) + module.ggsw_automorphism_tmp_bytes(res_infos, a_infos, key_infos, tsk_infos) } } -impl GGSWCiphertext { - pub fn automorphism( - &mut self, - module: &Module, - lhs: &GGSWCiphertext, - auto_key: &GGLWEAutomorphismKeyPrepared, - tensor_key: &GGLWETensorKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxAutomorphismInplace - + VecZnxBigAllocBytes - + VecZnxNormalizeTmpBytes - + VecZnxDftCopy - + VecZnxDftAddInplace - + VecZnxIdftApplyTmpA - + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnxBig + TakeVecZnx, +impl GGSW { + pub fn automorphism(&mut self, module: &M, a: &A, key: &K, tsk: &T, scratch: &mut Scratch) + where + A: GGSWToRef, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, + T: GLWETensorKeyPreparedToRef, + Scratch: ScratchTakeCore, + M: GGSWAutomorphism, { - #[cfg(debug_assertions)] - { - use crate::layouts::{GLWEInfos, LWEInfos}; - - assert_eq!(self.n(), module.n() as u32); - assert_eq!(lhs.n(), module.n() as u32); - assert_eq!(auto_key.n(), module.n() as u32); - assert_eq!(tensor_key.n(), module.n() as u32); - - assert_eq!( - self.rank(), - lhs.rank(), - "ggsw_out rank: {} != ggsw_in rank: {}", - self.rank(), - lhs.rank() - ); - assert_eq!( - self.rank(), - auto_key.rank_out(), - "ggsw_in rank: {} != auto_key rank: {}", - self.rank(), - auto_key.rank_out() - ); - assert_eq!( - self.rank(), - tensor_key.rank_out(), - "ggsw_in rank: {} != tensor_key rank: {}", - self.rank(), - tensor_key.rank_out() - ); - assert!(scratch.available() >= GGSWCiphertext::automorphism_scratch_space(module, self, lhs, auto_key, tensor_key)) - }; - - // Keyswitch the j-th row of the col 0 - (0..lhs.dnum().into()).for_each(|row_i| { - // Key-switch column 0, i.e. - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) - self.at_mut(row_i, 0) - .automorphism(module, &lhs.at(row_i, 0), auto_key, scratch); - }); - self.expand_row(module, tensor_key, scratch); + module.ggsw_automorphism(self, a, key, tsk, scratch); } - pub fn automorphism_inplace( - &mut self, - module: &Module, - auto_key: &GGLWEAutomorphismKeyPrepared, - tensor_key: &GGLWETensorKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxAutomorphismInplace - + VecZnxBigAllocBytes - + VecZnxNormalizeTmpBytes - + VecZnxDftCopy - + VecZnxDftAddInplace - + VecZnxIdftApplyTmpA - + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnxBig + TakeVecZnx, + pub fn automorphism_inplace(&mut self, module: &M, key: &K, tsk: &T, scratch: &mut Scratch) + where + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, + T: GLWETensorKeyPreparedToRef, + Scratch: ScratchTakeCore, + M: GGSWAutomorphism, { - // Keyswitch the j-th row of the col 0 - (0..self.dnum().into()).for_each(|row_i| { - // Key-switch column 0, i.e. - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) - self.at_mut(row_i, 0) - .automorphism_inplace(module, auto_key, scratch); - }); - self.expand_row(module, tensor_key, scratch); + module.ggsw_automorphism_inplace(self, key, tsk, scratch); } } + +impl GGSWAutomorphism for Module where Self: GLWEAutomorphism + GGSWExpandRows {} + +pub trait GGSWAutomorphism +where + Self: GLWEAutomorphism + GGSWExpandRows, +{ + fn ggsw_automorphism_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K, tsk_infos: &T) -> usize + where + R: GGSWInfos, + A: GGSWInfos, + K: GGLWEInfos, + T: GGLWEInfos, + { + let out_size: usize = res_infos.size(); + let ci_dft: usize = self.bytes_of_vec_znx_dft((key_infos.rank_out() + 1).into(), out_size); + let ks_internal: usize = self.glwe_automorphism_tmp_bytes(res_infos, a_infos, key_infos); + let expand: usize = self.ggsw_expand_rows_tmp_bytes(res_infos, tsk_infos); + ci_dft + (ks_internal.max(expand)) + } + + fn ggsw_automorphism(&self, res: &mut R, a: &A, key: &K, tsk: &T, scratch: &mut Scratch) + where + R: GGSWToMut, + A: GGSWToRef, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, + T: GLWETensorKeyPreparedToRef, + Scratch: ScratchTakeCore, + { + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + let a: &GGSW<&[u8]> = &a.to_ref(); + let tsk: &GLWETensorKeyPrepared<&[u8], BE> = &tsk.to_ref(); + + assert_eq!(res.dsize(), a.dsize()); + assert!(res.dnum() <= a.dnum()); + assert!(scratch.available() >= self.ggsw_automorphism_tmp_bytes(res, a, key, tsk)); + + // Keyswitch the j-th row of the col 0 + for row in 0..res.dnum().as_usize() { + // Key-switch column 0, i.e. + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) + self.glwe_automorphism(&mut res.at_mut(row, 0), &a.at(row, 0), key, scratch); + } + + self.ggsw_expand_row(res, tsk, scratch); + } + + fn ggsw_automorphism_inplace(&self, res: &mut R, key: &K, tsk: &T, scratch: &mut Scratch) + where + R: GGSWToMut, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, + T: GLWETensorKeyPreparedToRef, + Scratch: ScratchTakeCore, + { + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + let tsk: &GLWETensorKeyPrepared<&[u8], BE> = &tsk.to_ref(); + + // Keyswitch the j-th row of the col 0 + for row in 0..res.dnum().as_usize() { + // Key-switch column 0, i.e. + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) + self.glwe_automorphism_inplace(&mut res.at_mut(row, 0), key, scratch); + } + + self.ggsw_expand_row(res, tsk, scratch); + } +} + +impl GGSW {} diff --git a/poulpy-core/src/automorphism/glwe_ct.rs b/poulpy-core/src/automorphism/glwe_ct.rs index 79fcb12..7161239 100644 --- a/poulpy-core/src/automorphism/glwe_ct.rs +++ b/poulpy-core/src/automorphism/glwe_ct.rs @@ -1,345 +1,322 @@ use poulpy_hal::{ api::{ - ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, - VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallInplace, - VecZnxBigSubSmallNegateInplace, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize, - VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + ScratchTakeBasic, VecZnxAutomorphismInplace, VecZnxBigAutomorphismInplace, VecZnxBigSubSmallInplace, + VecZnxBigSubSmallNegateInplace, }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnxBig}, + layouts::{Backend, DataMut, Module, Scratch, VecZnxBig}, }; -use crate::layouts::{GGLWEInfos, GLWECiphertext, GLWEInfos, LWEInfos, prepared::GGLWEAutomorphismKeyPrepared}; +use crate::{ + GLWEKeyswitch, ScratchTakeCore, keyswitch_internal, + layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos}, +}; -impl GLWECiphertext> { - pub fn automorphism_scratch_space( - module: &Module, - out_infos: &OUT, - in_infos: &IN, - key_infos: &KEY, - ) -> usize +impl GLWE> { + pub fn automorphism_tmp_bytes(module: &M, res_infos: &R, a_infos: &A, key_infos: &K) -> usize where - OUT: GLWEInfos, - IN: GLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + R: GLWEInfos, + A: GLWEInfos, + K: GGLWEInfos, + M: GLWEAutomorphism, { - Self::keyswitch_scratch_space(module, out_infos, in_infos, key_infos) - } - - pub fn automorphism_inplace_scratch_space(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize - where - OUT: GLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, - { - Self::keyswitch_inplace_scratch_space(module, out_infos, key_infos) + module.glwe_automorphism_tmp_bytes(res_infos, a_infos, key_infos) } } -impl GLWECiphertext { - pub fn automorphism( - &mut self, - module: &Module, - lhs: &GLWECiphertext, - rhs: &GGLWEAutomorphismKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxAutomorphismInplace - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, +impl GLWE { + pub fn automorphism(&mut self, module: &M, a: &A, key: &K, scratch: &mut Scratch) + where + M: GLWEAutomorphism, + A: GLWEToRef, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, + Scratch: ScratchTakeCore, { - self.keyswitch(module, lhs, &rhs.key, scratch); - (0..(self.rank() + 1).into()).for_each(|i| { - module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i, scratch); - }) + module.glwe_automorphism(self, a, key, scratch); } - pub fn automorphism_inplace( - &mut self, - module: &Module, - rhs: &GGLWEAutomorphismKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxAutomorphismInplace - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + pub fn automorphism_add(&mut self, module: &M, a: &A, key: &K, scratch: &mut Scratch) + where + M: GLWEAutomorphism, + A: GLWEToRef, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, + Scratch: ScratchTakeCore, { - self.keyswitch_inplace(module, &rhs.key, scratch); - (0..(self.rank() + 1).into()).for_each(|i| { - module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i, scratch); - }) + module.glwe_automorphism_add(self, a, key, scratch); } - pub fn automorphism_add( - &mut self, - module: &Module, - lhs: &GLWECiphertext, - rhs: &GGLWEAutomorphismKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxBigAutomorphismInplace - + VecZnxNormalizeTmpBytes - + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + pub fn automorphism_sub(&mut self, module: &M, a: &A, key: &K, scratch: &mut Scratch) + where + M: GLWEAutomorphism, + A: GLWEToRef, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, + Scratch: ScratchTakeCore, { - #[cfg(debug_assertions)] - { - self.assert_keyswitch(module, lhs, &rhs.key, scratch); - } - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch_1); - (0..(self.rank() + 1).into()).for_each(|i| { - module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1); - module.vec_znx_big_add_small_inplace(&mut res_big, i, &lhs.data, i); - module.vec_znx_big_normalize( - self.base2k().into(), - &mut self.data, - i, - rhs.base2k().into(), - &res_big, - i, - scratch_1, - ); - }) + module.glwe_automorphism_sub(self, a, key, scratch); } - pub fn automorphism_add_inplace( - &mut self, - module: &Module, - rhs: &GGLWEAutomorphismKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxBigAutomorphismInplace - + VecZnxNormalizeTmpBytes - + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + pub fn automorphism_sub_negate(&mut self, module: &M, a: &A, key: &K, scratch: &mut Scratch) + where + M: GLWEAutomorphism, + A: GLWEToRef, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, + Scratch: ScratchTakeCore, { - #[cfg(debug_assertions)] - { - self.assert_keyswitch_inplace(module, &rhs.key, scratch); - } - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, &rhs.key, scratch_1); - (0..(self.rank() + 1).into()).for_each(|i| { - module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1); - module.vec_znx_big_add_small_inplace(&mut res_big, i, &self.data, i); - module.vec_znx_big_normalize( - self.base2k().into(), - &mut self.data, - i, - rhs.base2k().into(), - &res_big, - i, - scratch_1, - ); - }) + module.glwe_automorphism_sub_negate(self, a, key, scratch); } - pub fn automorphism_sub_ab( - &mut self, - module: &Module, - lhs: &GLWECiphertext, - rhs: &GGLWEAutomorphismKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxBigAutomorphismInplace - + VecZnxBigSubSmallInplace - + VecZnxNormalizeTmpBytes - + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + pub fn automorphism_inplace(&mut self, module: &M, key: &K, scratch: &mut Scratch) + where + M: GLWEAutomorphism, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, + Scratch: ScratchTakeCore, { - #[cfg(debug_assertions)] - { - self.assert_keyswitch(module, lhs, &rhs.key, scratch); - } - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch_1); - (0..(self.rank() + 1).into()).for_each(|i| { - module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1); - module.vec_znx_big_sub_small_inplace(&mut res_big, i, &lhs.data, i); - module.vec_znx_big_normalize( - self.base2k().into(), - &mut self.data, - i, - rhs.base2k().into(), - &res_big, - i, - scratch_1, - ); - }) + module.glwe_automorphism_inplace(self, key, scratch); } - pub fn automorphism_sub_inplace( - &mut self, - module: &Module, - rhs: &GGLWEAutomorphismKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxBigAutomorphismInplace - + VecZnxBigSubSmallInplace - + VecZnxNormalizeTmpBytes - + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + pub fn automorphism_add_inplace(&mut self, module: &M, key: &K, scratch: &mut Scratch) + where + M: GLWEAutomorphism, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, + Scratch: ScratchTakeCore, { - #[cfg(debug_assertions)] - { - self.assert_keyswitch_inplace(module, &rhs.key, scratch); - } - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, &rhs.key, scratch_1); - (0..(self.rank() + 1).into()).for_each(|i| { - module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1); - module.vec_znx_big_sub_small_inplace(&mut res_big, i, &self.data, i); - module.vec_znx_big_normalize( - self.base2k().into(), - &mut self.data, - i, - rhs.base2k().into(), - &res_big, - i, - scratch_1, - ); - }) + module.glwe_automorphism_add_inplace(self, key, scratch); } - pub fn automorphism_sub_negate( - &mut self, - module: &Module, - lhs: &GLWECiphertext, - rhs: &GGLWEAutomorphismKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxBigAutomorphismInplace - + VecZnxBigSubSmallNegateInplace - + VecZnxNormalizeTmpBytes - + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + pub fn automorphism_sub_inplace(&mut self, module: &M, key: &K, scratch: &mut Scratch) + where + M: GLWEAutomorphism, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, + Scratch: ScratchTakeCore, { - #[cfg(debug_assertions)] - { - self.assert_keyswitch(module, lhs, &rhs.key, scratch); - } - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch_1); - (0..(self.rank() + 1).into()).for_each(|i| { - module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1); - module.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, &lhs.data, i); - module.vec_znx_big_normalize( - self.base2k().into(), - &mut self.data, - i, - rhs.base2k().into(), - &res_big, - i, - scratch_1, - ); - }) + module.glwe_automorphism_sub_inplace(self, key, scratch); } - pub fn automorphism_sub_negate_inplace( - &mut self, - module: &Module, - rhs: &GGLWEAutomorphismKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxBigAutomorphismInplace - + VecZnxBigSubSmallNegateInplace - + VecZnxNormalizeTmpBytes - + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + pub fn automorphism_sub_negate_inplace(&mut self, module: &M, key: &K, scratch: &mut Scratch) + where + M: GLWEAutomorphism, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, + Scratch: ScratchTakeCore, { - #[cfg(debug_assertions)] - { - self.assert_keyswitch_inplace(module, &rhs.key, scratch); - } - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, &rhs.key, scratch_1); - (0..(self.rank() + 1).into()).for_each(|i| { - module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1); - module.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, &self.data, i); - module.vec_znx_big_normalize( - self.base2k().into(), - &mut self.data, - i, - rhs.base2k().into(), - &res_big, - i, - scratch_1, - ); - }) + module.glwe_automorphism_sub_negate_inplace(self, key, scratch); } } + +pub trait GLWEAutomorphism +where + Self: GLWEKeyswitch + + VecZnxAutomorphismInplace + + VecZnxBigAutomorphismInplace + + VecZnxBigSubSmallInplace + + VecZnxBigSubSmallNegateInplace, +{ + fn glwe_automorphism_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize + where + R: GLWEInfos, + A: GLWEInfos, + K: GGLWEInfos, + { + self.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) + } + + fn glwe_automorphism(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + A: GLWEToRef, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, + Scratch: ScratchTakeCore, + { + self.glwe_keyswitch(res, a, key, scratch); + + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_automorphism_inplace(key.p(), res.data_mut(), i, scratch); + } + } + + fn glwe_automorphism_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, + Scratch: ScratchTakeCore, + { + self.glwe_keyswitch_inplace(res, key, scratch); + + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_automorphism_inplace(key.p(), res.data_mut(), i, scratch); + } + } + + fn glwe_automorphism_add(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + A: GLWEToRef, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, + Scratch: ScratchTakeCore, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); + + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size + let mut res_big: VecZnxBig<_, BE> = keyswitch_internal(self, res_dft, a, key, scratch_1); + + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); + self.vec_znx_big_add_small_inplace(&mut res_big, i, a.data(), i); + self.vec_znx_big_normalize( + res.base2k().into(), + res.data_mut(), + i, + key.base2k().into(), + &res_big, + i, + scratch_1, + ); + } + } + + fn glwe_automorphism_add_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, + Scratch: ScratchTakeCore, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size + let mut res_big: VecZnxBig<_, BE> = keyswitch_internal(self, res_dft, res, key, scratch_1); + + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); + self.vec_znx_big_add_small_inplace(&mut res_big, i, res.data(), i); + self.vec_znx_big_normalize( + res.base2k().into(), + res.data_mut(), + i, + key.base2k().into(), + &res_big, + i, + scratch_1, + ); + } + } + + fn glwe_automorphism_sub(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + A: GLWEToRef, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, + Scratch: ScratchTakeCore, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); + + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size + let mut res_big: VecZnxBig<_, BE> = keyswitch_internal(self, res_dft, a, key, scratch_1); + + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); + self.vec_znx_big_sub_small_inplace(&mut res_big, i, a.data(), i); + self.vec_znx_big_normalize( + res.base2k().into(), + res.data_mut(), + i, + key.base2k().into(), + &res_big, + i, + scratch_1, + ); + } + } + + fn glwe_automorphism_sub_negate(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + A: GLWEToRef, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, + Scratch: ScratchTakeCore, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); + + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size + let mut res_big: VecZnxBig<_, BE> = keyswitch_internal(self, res_dft, a, key, scratch_1); + + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); + self.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, a.data(), i); + self.vec_znx_big_normalize( + res.base2k().into(), + res.data_mut(), + i, + key.base2k().into(), + &res_big, + i, + scratch_1, + ); + } + } + + fn glwe_automorphism_sub_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, + Scratch: ScratchTakeCore, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size + let mut res_big: VecZnxBig<_, BE> = keyswitch_internal(self, res_dft, res, key, scratch_1); + + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); + self.vec_znx_big_sub_small_inplace(&mut res_big, i, res.data(), i); + self.vec_znx_big_normalize( + res.base2k().into(), + res.data_mut(), + i, + key.base2k().into(), + &res_big, + i, + scratch_1, + ); + } + } + + fn glwe_automorphism_sub_negate_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, + Scratch: ScratchTakeCore, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size + let mut res_big: VecZnxBig<_, BE> = keyswitch_internal(self, res_dft, res, key, scratch_1); + + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); + self.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, res.data(), i); + self.vec_znx_big_normalize( + res.base2k().into(), + res.data_mut(), + i, + key.base2k().into(), + &res_big, + i, + scratch_1, + ); + } + } +} + +impl GLWEAutomorphism for Module where + Self: GLWEKeyswitch + + VecZnxAutomorphismInplace + + VecZnxBigAutomorphismInplace + + VecZnxBigSubSmallInplace + + VecZnxBigSubSmallNegateInplace +{ +} diff --git a/poulpy-core/src/automorphism/mod.rs b/poulpy-core/src/automorphism/mod.rs index f985c5e..fd10f33 100644 --- a/poulpy-core/src/automorphism/mod.rs +++ b/poulpy-core/src/automorphism/mod.rs @@ -1,3 +1,7 @@ mod gglwe_atk; mod ggsw_ct; mod glwe_ct; + +pub use gglwe_atk::*; +pub use ggsw_ct::*; +pub use glwe_ct::*; diff --git a/poulpy-core/src/conversion/gglwe_to_ggsw.rs b/poulpy-core/src/conversion/gglwe_to_ggsw.rs new file mode 100644 index 0000000..b33759e --- /dev/null +++ b/poulpy-core/src/conversion/gglwe_to_ggsw.rs @@ -0,0 +1,286 @@ +use poulpy_hal::{ + api::{ + ModuleN, ScratchAvailable, ScratchTakeBasic, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxDftAddInplace, VecZnxDftApply, + VecZnxDftBytesOf, VecZnxDftCopy, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, + VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + }, + layouts::{Backend, DataMut, Module, Scratch, VmpPMat, ZnxInfos}, +}; + +use crate::{ + GLWECopy, ScratchTakeCore, + layouts::{ + GGLWE, GGLWEInfos, GGLWEToRef, GGSW, GGSWInfos, GGSWToMut, GLWEInfos, LWEInfos, + prepared::{GLWETensorKeyPrepared, GLWETensorKeyPreparedToRef}, + }, +}; + +impl GGLWE> { + pub fn from_gglw_tmp_bytes(module: &M, res_infos: &R, tsk_infos: &A) -> usize + where + M: GGSWFromGGLWE, + R: GGSWInfos, + A: GGLWEInfos, + { + module.ggsw_from_gglwe_tmp_bytes(res_infos, tsk_infos) + } +} + +impl GGSW { + pub fn from_gglwe(&mut self, module: &M, gglwe: &G, tsk: &T, scratch: &mut Scratch) + where + M: GGSWFromGGLWE, + G: GGLWEToRef, + T: GLWETensorKeyPreparedToRef, + Scratch: ScratchTakeCore, + { + module.ggsw_from_gglwe(self, gglwe, tsk, scratch); + } +} + +impl GGSWFromGGLWE for Module +where + Self: GGSWExpandRows + GLWECopy, +{ + fn ggsw_from_gglwe_tmp_bytes(&self, res_infos: &R, tsk_infos: &A) -> usize + where + R: GGSWInfos, + A: GGLWEInfos, + { + self.ggsw_expand_rows_tmp_bytes(res_infos, tsk_infos) + } + + fn ggsw_from_gglwe(&self, res: &mut R, a: &A, tsk: &T, scratch: &mut Scratch) + where + R: GGSWToMut, + A: GGLWEToRef, + T: GLWETensorKeyPreparedToRef, + Scratch: ScratchTakeCore, + { + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + let a: &GGLWE<&[u8]> = &a.to_ref(); + let tsk: &GLWETensorKeyPrepared<&[u8], BE> = &tsk.to_ref(); + + assert_eq!(res.rank(), a.rank_out()); + assert_eq!(res.dnum(), a.dnum()); + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.n(), self.n() as u32); + assert_eq!(tsk.n(), self.n() as u32); + + for row in 0..res.dnum().into() { + self.glwe_copy(&mut res.at_mut(row, 0), &a.at(row, 0)); + } + + self.ggsw_expand_row(res, tsk, scratch); + } +} + +pub trait GGSWFromGGLWE { + fn ggsw_from_gglwe_tmp_bytes(&self, res_infos: &R, tsk_infos: &A) -> usize + where + R: GGSWInfos, + A: GGLWEInfos; + + fn ggsw_from_gglwe(&self, res: &mut R, a: &A, tsk: &T, scratch: &mut Scratch) + where + R: GGSWToMut, + A: GGLWEToRef, + T: GLWETensorKeyPreparedToRef, + Scratch: ScratchTakeCore; +} + +impl GGSWExpandRows for Module where + Self: Sized + + ModuleN + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VecZnxBigBytesOf + + VecZnxNormalizeTmpBytes + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VecZnxBigBytesOf + + VecZnxNormalizeTmpBytes + + VecZnxDftApply + + VecZnxDftCopy + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxDftAddInplace + + VecZnxBigNormalize + + VecZnxIdftApplyTmpA + + VecZnxNormalize +{ +} + +pub trait GGSWExpandRows +where + Self: Sized + + ModuleN + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VecZnxBigBytesOf + + VecZnxNormalizeTmpBytes + + VecZnxDftApply + + VecZnxDftCopy + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxDftAddInplace + + VecZnxBigNormalize + + VecZnxIdftApplyTmpA + + VecZnxNormalize, +{ + fn ggsw_expand_rows_tmp_bytes(&self, res_infos: &R, tsk_infos: &A) -> usize + where + R: GGSWInfos, + A: GGLWEInfos, + { + let tsk_size: usize = tsk_infos.k().div_ceil(tsk_infos.base2k()) as usize; + let size_in: usize = res_infos + .k() + .div_ceil(tsk_infos.base2k()) + .div_ceil(tsk_infos.dsize().into()) as usize; + + let tmp_dft_i: usize = self.bytes_of_vec_znx_dft((tsk_infos.rank_out() + 1).into(), tsk_size); + let tmp_a: usize = self.bytes_of_vec_znx_dft(1, size_in); + let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes( + tsk_size, + size_in, + size_in, + (tsk_infos.rank_in()).into(), // Verify if rank+1 + (tsk_infos.rank_out()).into(), // Verify if rank+1 + tsk_size, + ); + let tmp_idft: usize = self.bytes_of_vec_znx_big(1, tsk_size); + let norm: usize = self.vec_znx_normalize_tmp_bytes(); + + tmp_dft_i + ((tmp_a + vmp) | (tmp_idft + norm)) + } + + fn ggsw_expand_row(&self, res: &mut R, tsk: &T, scratch: &mut Scratch) + where + R: GGSWToMut, + T: GLWETensorKeyPreparedToRef, + Scratch: ScratchTakeCore, + { + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + let tsk: &GLWETensorKeyPrepared<&[u8], BE> = &tsk.to_ref(); + + let basek_in: usize = res.base2k().into(); + let basek_tsk: usize = tsk.base2k().into(); + + assert!(scratch.available() >= self.ggsw_expand_rows_tmp_bytes(res, tsk)); + + let rank: usize = res.rank().into(); + let cols: usize = rank + 1; + + let a_size: usize = (res.size() * basek_in).div_ceil(basek_tsk); + + // Keyswitch the j-th row of the col 0 + for row_i in 0..res.dnum().into() { + let a = &res.at(row_i, 0).data; + + // Pre-compute DFT of (a0, a1, a2) + let (mut ci_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, a_size); + + if basek_in == basek_tsk { + for i in 0..cols { + self.vec_znx_dft_apply(1, 0, &mut ci_dft, i, a, i); + } + } else { + let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, a_size); + for i in 0..cols { + self.vec_znx_normalize(basek_tsk, &mut a_conv, 0, basek_in, a, i, scratch_2); + self.vec_znx_dft_apply(1, 0, &mut ci_dft, i, &a_conv, 0); + } + } + + for col_j in 1..cols { + // Example for rank 3: + // + // Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is + // actually composed of that many dnum and we focus on a specific row here + // implicitely given ci_dft. + // + // # Input + // + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) + // col 1: (0, 0, 0, 0) + // col 2: (0, 0, 0, 0) + // col 3: (0, 0, 0, 0) + // + // # Output + // + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) + // col 1: (-(b0s0 + b1s1 + b2s2) , b0 + M[i], b1 , b2 ) + // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M[i], c2 ) + // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M[i]) + + let dsize: usize = tsk.dsize().into(); + + let (mut tmp_dft_i, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, tsk.size()); + let (mut tmp_a, scratch_3) = scratch_2.take_vec_znx_dft(self, 1, ci_dft.size().div_ceil(dsize)); + + { + // Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2 + // + // # Example for col=1 + // + // a0 * (-(f0s0 + f1s1 + f1s2) + s0^2, f0, f1, f2) = (-(a0f0s0 + a0f1s1 + a0f1s2) + a0s0^2, a0f0, a0f1, a0f2) + // + + // a1 * (-(g0s0 + g1s1 + g1s2) + s0s1, g0, g1, g2) = (-(a1g0s0 + a1g1s1 + a1g1s2) + a1s0s1, a1g0, a1g1, a1g2) + // + + // a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2) + // = + // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2) + for col_i in 1..cols { + let pmat: &VmpPMat<&[u8], BE> = &tsk.at(col_i - 1, col_j - 1).data; // Selects Enc(s[i]s[j]) + + // Extracts a[i] and multipies with Enc(s[i]s[j]) + for di in 0..dsize { + tmp_a.set_size((ci_dft.size() + di) / dsize); + + // Small optimization for dsize > 2 + // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then + // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}. + // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. + // It is possible to further ignore the last dsize-1 limbs, but this introduce + // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same + // noise is kept with respect to the ideal functionality. + tmp_dft_i.set_size(tsk.size() - ((dsize - di) as isize - 2).max(0) as usize); + + self.vec_znx_dft_copy(dsize, dsize - 1 - di, &mut tmp_a, 0, &ci_dft, col_i); + if di == 0 && col_i == 1 { + self.vmp_apply_dft_to_dft(&mut tmp_dft_i, &tmp_a, pmat, scratch_3); + } else { + self.vmp_apply_dft_to_dft_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch_3); + } + } + } + } + + // Adds -(sum a[i] * s[i]) + m) on the i-th column of tmp_idft_i + // + // (-(x0s0 + x1s1 + x2s2) + a0s0s0 + a1s0s1 + a2s0s2, x0, x1, x2) + // + + // (0, -(a0s0 + a1s1 + a2s2) + M[i], 0, 0) + // = + // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0 -(a0s0 + a1s1 + a2s2) + M[i], x1, x2) + // = + // (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2) + self.vec_znx_dft_add_inplace(&mut tmp_dft_i, col_j, &ci_dft, 0); + let (mut tmp_idft, scratch_3) = scratch_2.take_vec_znx_big(self, 1, tsk.size()); + for i in 0..cols { + self.vec_znx_idft_apply_tmpa(&mut tmp_idft, 0, &mut tmp_dft_i, i); + self.vec_znx_big_normalize( + basek_in, + &mut res.at_mut(row_i, col_j).data, + i, + basek_tsk, + &tmp_idft, + 0, + scratch_3, + ); + } + } + } + } +} diff --git a/poulpy-core/src/conversion/glwe_to_lwe.rs b/poulpy-core/src/conversion/glwe_to_lwe.rs index 517da90..fbf5912 100644 --- a/poulpy-core/src/conversion/glwe_to_lwe.rs +++ b/poulpy-core/src/conversion/glwe_to_lwe.rs @@ -1,106 +1,124 @@ use poulpy_hal::{ - api::{ - ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, - VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, - }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero}, + api::ModuleN, + layouts::{Backend, DataMut, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero}, }; use crate::{ - TakeGLWECt, - layouts::{ - GGLWEInfos, GLWECiphertext, GLWECiphertextLayout, GLWEInfos, LWECiphertext, LWEInfos, Rank, - prepared::GLWEToLWESwitchingKeyPrepared, - }, + GLWEKeyswitch, ScratchTakeCore, + layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWELayout, GLWEToRef, LWE, LWEInfos, LWEToMut, Rank}, }; -impl LWECiphertext> { - pub fn from_glwe_scratch_space( - module: &Module, - lwe_infos: &OUT, - glwe_infos: &IN, - key_infos: &KEY, - ) -> usize +pub trait LWESampleExtract +where + Self: ModuleN, +{ + fn lwe_sample_extract(&self, res: &mut R, a: &A) where - OUT: LWEInfos, - IN: GLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + R: LWEToMut, + A: GLWEToRef, { - let glwe_layout: GLWECiphertextLayout = GLWECiphertextLayout { - n: module.n().into(), + let res: &mut LWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); + + assert!(res.n() <= a.n()); + assert_eq!(a.n(), self.n() as u32); + assert!(res.base2k() == a.base2k()); + + let min_size: usize = res.size().min(a.size()); + let n: usize = res.n().into(); + + res.data.zero(); + (0..min_size).for_each(|i| { + let data_lwe: &mut [i64] = res.data.at_mut(0, i); + data_lwe[0] = a.data.at(0, i)[0]; + data_lwe[1..].copy_from_slice(&a.data.at(1, i)[..n]); + }); + } +} + +impl LWESampleExtract for Module where Self: ModuleN {} +impl LWEFromGLWE for Module where Self: GLWEKeyswitch + LWESampleExtract {} + +pub trait LWEFromGLWE +where + Self: GLWEKeyswitch + LWESampleExtract, +{ + fn lwe_from_glwe_tmp_bytes(&self, lwe_infos: &R, glwe_infos: &A, key_infos: &K) -> usize + where + R: LWEInfos, + A: GLWEInfos, + K: GGLWEInfos, + { + let res_infos: GLWELayout = GLWELayout { + n: self.n().into(), base2k: lwe_infos.base2k(), k: lwe_infos.k(), rank: Rank(1), }; - GLWECiphertext::alloc_bytes_with( - module.n().into(), + GLWE::bytes_of( + self.n().into(), lwe_infos.base2k(), lwe_infos.k(), 1u32.into(), - ) + GLWECiphertext::keyswitch_scratch_space(module, &glwe_layout, glwe_infos, key_infos) - } -} - -impl LWECiphertext { - pub fn sample_extract(&mut self, a: &GLWECiphertext) { - #[cfg(debug_assertions)] - { - assert!(self.n() <= a.n()); - assert!(self.base2k() == a.base2k()); - } - - let min_size: usize = self.size().min(a.size()); - let n: usize = self.n().into(); - - self.data.zero(); - (0..min_size).for_each(|i| { - let data_lwe: &mut [i64] = self.data.at_mut(0, i); - data_lwe[0] = a.data.at(0, i)[0]; - data_lwe[1..].copy_from_slice(&a.data.at(1, i)[..n]); - }); + ) + self.glwe_keyswitch_tmp_bytes(&res_infos, glwe_infos, key_infos) } - pub fn from_glwe( - &mut self, - module: &Module, - a: &GLWECiphertext, - ks: &GLWEToLWESwitchingKeyPrepared, - scratch: &mut Scratch, - ) where - DGlwe: DataRef, - DKs: DataRef, - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeGLWECt + TakeVecZnx, + fn lwe_from_glwe(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) + where + R: LWEToMut, + A: GLWEToRef, + K: GGLWEPreparedToRef + GGLWEInfos, + Scratch: ScratchTakeCore, { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), module.n() as u32); - assert_eq!(ks.n(), module.n() as u32); - assert!(self.n() <= module.n() as u32); - } + let res: &mut LWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); - let glwe_layout: GLWECiphertextLayout = GLWECiphertextLayout { - n: module.n().into(), - base2k: self.base2k(), - k: self.k(), + assert_eq!(a.n(), self.n() as u32); + assert_eq!(key.n(), self.n() as u32); + assert!(res.n() <= self.n() as u32); + + let glwe_layout: GLWELayout = GLWELayout { + n: self.n().into(), + base2k: res.base2k(), + k: res.k(), rank: Rank(1), }; - let (mut tmp_glwe, scratch_1) = scratch.take_glwe_ct(&glwe_layout); - tmp_glwe.keyswitch(module, a, &ks.0, scratch_1); - self.sample_extract(&tmp_glwe); + let (mut tmp_glwe, scratch_1) = scratch.take_glwe(&glwe_layout); + self.glwe_keyswitch(&mut tmp_glwe, a, key, scratch_1); + self.lwe_sample_extract(res, &tmp_glwe); + } +} + +impl LWE> { + pub fn from_glwe_tmp_bytes(module: &M, lwe_infos: &R, glwe_infos: &A, key_infos: &K) -> usize + where + R: LWEInfos, + A: GLWEInfos, + K: GGLWEInfos, + M: LWEFromGLWE, + { + module.lwe_from_glwe_tmp_bytes(lwe_infos, glwe_infos, key_infos) + } +} + +impl LWE { + pub fn sample_extract(&mut self, module: &M, a: &A) + where + A: GLWEToRef, + M: LWESampleExtract, + { + module.lwe_sample_extract(self, a); + } + + pub fn from_glwe(&mut self, module: &M, a: &A, key: &K, scratch: &mut Scratch) + where + A: GLWEToRef, + K: GGLWEPreparedToRef + GGLWEInfos, + M: LWEFromGLWE, + Scratch: ScratchTakeCore, + { + module.lwe_from_glwe(self, a, key, scratch); } } diff --git a/poulpy-core/src/conversion/lwe_to_glwe.rs b/poulpy-core/src/conversion/lwe_to_glwe.rs index d3ae616..c759ee5 100644 --- a/poulpy-core/src/conversion/lwe_to_glwe.rs +++ b/poulpy-core/src/conversion/lwe_to_glwe.rs @@ -1,80 +1,56 @@ use poulpy_hal::{ - api::{ - ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, - VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, - }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx, ZnxView, ZnxViewMut, ZnxZero}, + api::ScratchTakeBasic, + layouts::{Backend, DataMut, Module, Scratch, VecZnx, ZnxView, ZnxViewMut, ZnxZero}, }; use crate::{ - TakeGLWECt, - layouts::{ - GGLWEInfos, GLWECiphertext, GLWECiphertextLayout, GLWEInfos, LWECiphertext, LWEInfos, - prepared::LWEToGLWESwitchingKeyPrepared, - }, + GLWEKeyswitch, ScratchTakeCore, + layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWELayout, GLWEToMut, LWE, LWEInfos, LWEToRef}, }; -impl GLWECiphertext> { - pub fn from_lwe_scratch_space( - module: &Module, - glwe_infos: &OUT, - lwe_infos: &IN, - key_infos: &KEY, - ) -> usize +impl GLWEFromLWE for Module where Self: GLWEKeyswitch {} + +pub trait GLWEFromLWE +where + Self: GLWEKeyswitch, +{ + fn glwe_from_lwe_tmp_bytes(&self, glwe_infos: &R, lwe_infos: &A, key_infos: &K) -> usize where - OUT: GLWEInfos, - IN: LWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + R: GLWEInfos, + A: LWEInfos, + K: GGLWEInfos, { - let ct: usize = GLWECiphertext::alloc_bytes_with( - module.n().into(), + let ct: usize = GLWE::bytes_of( + self.n().into(), key_infos.base2k(), lwe_infos.k().max(glwe_infos.k()), 1u32.into(), ); - let ks: usize = GLWECiphertext::keyswitch_inplace_scratch_space(module, glwe_infos, key_infos); + + let ks: usize = self.glwe_keyswitch_tmp_bytes(glwe_infos, glwe_infos, key_infos); if lwe_infos.base2k() == key_infos.base2k() { ct + ks } else { - let a_conv = VecZnx::alloc_bytes(module.n(), 1, lwe_infos.size()) + module.vec_znx_normalize_tmp_bytes(); + let a_conv = VecZnx::bytes_of(self.n(), 1, lwe_infos.size()) + self.vec_znx_normalize_tmp_bytes(); ct + a_conv + ks } } -} -impl GLWECiphertext { - pub fn from_lwe( - &mut self, - module: &Module, - lwe: &LWECiphertext, - ksk: &LWEToGLWESwitchingKeyPrepared, - scratch: &mut Scratch, - ) where - DLwe: DataRef, - DKsk: DataRef, - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeGLWECt + TakeVecZnx, + fn glwe_from_lwe(&self, res: &mut R, lwe: &A, ksk: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + A: LWEToRef, + K: GGLWEPreparedToRef + GGLWEInfos, + Scratch: ScratchTakeCore, { - #[cfg(debug_assertions)] - { - assert_eq!(self.n(), module.n() as u32); - assert_eq!(ksk.n(), module.n() as u32); - assert!(lwe.n() <= module.n() as u32); - } + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let lwe: &LWE<&[u8]> = &lwe.to_ref(); - let (mut glwe, scratch_1) = scratch.take_glwe_ct(&GLWECiphertextLayout { + assert_eq!(res.n(), self.n() as u32); + assert_eq!(ksk.n(), self.n() as u32); + assert!(lwe.n() <= self.n() as u32); + + let (mut glwe, scratch_1) = scratch.take_glwe(&GLWELayout { n: ksk.n(), base2k: ksk.base2k(), k: lwe.k(), @@ -91,14 +67,14 @@ impl GLWECiphertext { glwe.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]); } } else { - let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(module.n(), 1, lwe.size()); + let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, lwe.size()); a_conv.zero(); for j in 0..lwe.size() { let data_lwe: &[i64] = lwe.data.at(0, j); a_conv.at_mut(0, j)[0] = data_lwe[0] } - module.vec_znx_normalize( + self.vec_znx_normalize( ksk.base2k().into(), &mut glwe.data, 0, @@ -114,7 +90,7 @@ impl GLWECiphertext { a_conv.at_mut(0, j)[..n_lwe].copy_from_slice(&data_lwe[1..]); } - module.vec_znx_normalize( + self.vec_znx_normalize( ksk.base2k().into(), &mut glwe.data, 1, @@ -125,6 +101,30 @@ impl GLWECiphertext { ); } - self.keyswitch(module, &glwe, &ksk.0, scratch_1); + self.glwe_keyswitch(res, &glwe, ksk, scratch_1); + } +} + +impl GLWE> { + pub fn from_lwe_tmp_bytes(module: &M, glwe_infos: &R, lwe_infos: &A, key_infos: &K) -> usize + where + R: GLWEInfos, + A: LWEInfos, + K: GGLWEInfos, + M: GLWEFromLWE, + { + module.glwe_from_lwe_tmp_bytes(glwe_infos, lwe_infos, key_infos) + } +} + +impl GLWE { + pub fn from_lwe(&mut self, module: &M, lwe: &A, ksk: &K, scratch: &mut Scratch) + where + M: GLWEFromLWE, + A: LWEToRef, + K: GGLWEPreparedToRef + GGLWEInfos, + Scratch: ScratchTakeCore, + { + module.glwe_from_lwe(self, lwe, ksk, scratch); } } diff --git a/poulpy-core/src/conversion/mod.rs b/poulpy-core/src/conversion/mod.rs index 090208b..8be980a 100644 --- a/poulpy-core/src/conversion/mod.rs +++ b/poulpy-core/src/conversion/mod.rs @@ -1,2 +1,7 @@ +mod gglwe_to_ggsw; mod glwe_to_lwe; mod lwe_to_glwe; + +pub use gglwe_to_ggsw::*; +pub use glwe_to_lwe::*; +pub use lwe_to_glwe::*; diff --git a/poulpy-core/src/decryption/glwe.rs b/poulpy-core/src/decryption/glwe.rs new file mode 100644 index 0000000..6dc7f5a --- /dev/null +++ b/poulpy-core/src/decryption/glwe.rs @@ -0,0 +1,125 @@ +use poulpy_hal::{ + api::{ + ModuleN, ScratchTakeBasic, SvpApplyDftToDftInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigBytesOf, + VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalizeTmpBytes, + }, + layouts::{Backend, DataRef, DataViewMut, Module, Scratch}, +}; + +use crate::layouts::{ + GLWE, GLWEInfos, GLWEPlaintext, GLWEPlaintextToMut, GLWEToRef, LWEInfos, + prepared::{GLWESecretPrepared, GLWESecretPreparedToRef}, +}; + +impl GLWE> { + pub fn decrypt_tmp_bytes(module: &M, a_infos: &A) -> usize + where + A: GLWEInfos, + M: GLWEDecrypt, + { + module.glwe_decrypt_tmp_bytes(a_infos) + } +} + +impl GLWE { + pub fn decrypt(&self, module: &M, pt: &mut P, sk: &S, scratch: &mut Scratch) + where + P: GLWEPlaintextToMut, + S: GLWESecretPreparedToRef, + M: GLWEDecrypt, + Scratch: ScratchTakeBasic, + { + module.glwe_decrypt(self, pt, sk, scratch); + } +} + +pub trait GLWEDecrypt +where + Self: Sized + + ModuleN + + VecZnxDftBytesOf + + VecZnxNormalizeTmpBytes + + VecZnxBigBytesOf + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxBigAddInplace + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize, +{ + fn glwe_decrypt_tmp_bytes(&self, infos: &A) -> usize + where + A: GLWEInfos, + { + let size: usize = infos.size(); + (self.vec_znx_normalize_tmp_bytes() | self.bytes_of_vec_znx_dft(1, size)) + self.bytes_of_vec_znx_dft(1, size) + } + + fn glwe_decrypt(&self, res: &R, pt: &mut P, sk: &S, scratch: &mut Scratch) + where + R: GLWEToRef, + P: GLWEPlaintextToMut, + S: GLWESecretPreparedToRef, + Scratch: ScratchTakeBasic, + { + let res: &GLWE<&[u8]> = &res.to_ref(); + let pt: &mut GLWEPlaintext<&mut [u8]> = &mut pt.to_ref(); + let sk: &GLWESecretPrepared<&[u8], BE> = &sk.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.rank(), sk.rank()); + assert_eq!(res.n(), sk.n()); + assert_eq!(pt.n(), sk.n()); + } + + let cols: usize = (res.rank() + 1).into(); + + let (mut c0_big, scratch_1) = scratch.take_vec_znx_big(self, 1, res.size()); // TODO optimize size when pt << ct + c0_big.data_mut().fill(0); + + { + (1..cols).for_each(|i| { + // ci_dft = DFT(a[i]) * DFT(s[i]) + let (mut ci_dft, _) = scratch_1.take_vec_znx_dft(self, 1, res.size()); // TODO optimize size when pt << ct + self.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, &res.data, i); + self.svp_apply_dft_to_dft_inplace(&mut ci_dft, 0, &sk.data, i - 1); + let ci_big = self.vec_znx_idft_apply_consume(ci_dft); + + // c0_big += a[i] * s[i] + self.vec_znx_big_add_inplace(&mut c0_big, 0, &ci_big, 0); + }); + } + + // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) + self.vec_znx_big_add_small_inplace(&mut c0_big, 0, &res.data, 0); + + // pt = norm(BIG(m + e)) + self.vec_znx_big_normalize( + res.base2k().into(), + &mut pt.data, + 0, + res.base2k().into(), + &c0_big, + 0, + scratch_1, + ); + + pt.base2k = res.base2k(); + pt.k = pt.k().min(res.k()); + } +} + +impl GLWEDecrypt for Module where + Self: ModuleN + + VecZnxDftBytesOf + + VecZnxNormalizeTmpBytes + + VecZnxBigBytesOf + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxBigAddInplace + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize +{ +} diff --git a/poulpy-core/src/decryption/glwe_ct.rs b/poulpy-core/src/decryption/glwe_ct.rs deleted file mode 100644 index 19b4a82..0000000 --- a/poulpy-core/src/decryption/glwe_ct.rs +++ /dev/null @@ -1,80 +0,0 @@ -use poulpy_hal::{ - api::{ - SvpApplyDftToDftInplace, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalizeTmpBytes, - }, - layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch}, -}; - -use crate::layouts::{GLWECiphertext, GLWEInfos, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared}; - -impl GLWECiphertext> { - pub fn decrypt_scratch_space(module: &Module, infos: &A) -> usize - where - A: GLWEInfos, - Module: VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes, - { - let size: usize = infos.size(); - (module.vec_znx_normalize_tmp_bytes() | module.vec_znx_dft_alloc_bytes(1, size)) + module.vec_znx_dft_alloc_bytes(1, size) - } -} - -impl GLWECiphertext { - pub fn decrypt( - &self, - module: &Module, - pt: &mut GLWEPlaintext, - sk: &GLWESecretPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize, - Scratch: TakeVecZnxDft + TakeVecZnxBig, - { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank(), sk.rank()); - assert_eq!(self.n(), sk.n()); - assert_eq!(pt.n(), sk.n()); - } - - let cols: usize = (self.rank() + 1).into(); - - let (mut c0_big, scratch_1) = scratch.take_vec_znx_big(self.n().into(), 1, self.size()); // TODO optimize size when pt << ct - c0_big.data_mut().fill(0); - - { - (1..cols).for_each(|i| { - // ci_dft = DFT(a[i]) * DFT(s[i]) - let (mut ci_dft, _) = scratch_1.take_vec_znx_dft(self.n().into(), 1, self.size()); // TODO optimize size when pt << ct - module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, &self.data, i); - module.svp_apply_dft_to_dft_inplace(&mut ci_dft, 0, &sk.data, i - 1); - let ci_big = module.vec_znx_idft_apply_consume(ci_dft); - - // c0_big += a[i] * s[i] - module.vec_znx_big_add_inplace(&mut c0_big, 0, &ci_big, 0); - }); - } - - // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) - module.vec_znx_big_add_small_inplace(&mut c0_big, 0, &self.data, 0); - - // pt = norm(BIG(m + e)) - module.vec_znx_big_normalize( - self.base2k().into(), - &mut pt.data, - 0, - self.base2k().into(), - &c0_big, - 0, - scratch_1, - ); - - pt.base2k = self.base2k(); - pt.k = pt.k().min(self.k()); - } -} diff --git a/poulpy-core/src/decryption/lwe.rs b/poulpy-core/src/decryption/lwe.rs new file mode 100644 index 0000000..edd727b --- /dev/null +++ b/poulpy-core/src/decryption/lwe.rs @@ -0,0 +1,65 @@ +use poulpy_hal::{ + api::{ScratchOwnedAlloc, ScratchOwnedBorrow, ZnNormalizeInplace}, + layouts::{Backend, DataMut, DataRef, Module, ScratchOwned, ZnxView, ZnxViewMut}, +}; + +use crate::layouts::{LWE, LWEInfos, LWEPlaintext, LWEPlaintextToMut, LWESecret, LWESecretToRef, LWEToMut}; + +impl LWE { + pub fn decrypt(&mut self, module: &M, pt: &mut P, sk: &S) + where + P: LWEPlaintextToMut, + S: LWESecretToRef, + M: LWEDecrypt, + { + module.lwe_decrypt(self, pt, sk); + } +} + +pub trait LWEDecrypt { + fn lwe_decrypt(&self, res: &mut R, pt: &mut P, sk: &S) + where + R: LWEToMut, + P: LWEPlaintextToMut, + S: LWESecretToRef; +} + +impl LWEDecrypt for Module +where + Self: Sized + ZnNormalizeInplace, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + fn lwe_decrypt(&self, res: &mut R, pt: &mut P, sk: &S) + where + R: LWEToMut, + P: LWEPlaintextToMut, + S: LWESecretToRef, + { + let res: &mut LWE<&mut [u8]> = &mut res.to_mut(); + let pt: &mut LWEPlaintext<&mut [u8]> = &mut pt.to_mut(); + let sk: LWESecret<&[u8]> = sk.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.n(), sk.n()); + } + + (0..pt.size().min(res.size())).for_each(|i| { + pt.data.at_mut(0, i)[0] = res.data.at(0, i)[0] + + res.data.at(0, i)[1..] + .iter() + .zip(sk.data.at(0, 0)) + .map(|(x, y)| x * y) + .sum::(); + }); + self.zn_normalize_inplace( + 1, + res.base2k().into(), + &mut pt.data, + 0, + ScratchOwned::alloc(size_of::()).borrow(), + ); + pt.base2k = res.base2k(); + pt.k = crate::layouts::TorusPrecision(res.k().0.min(pt.size() as u32 * res.base2k().0)); + } +} diff --git a/poulpy-core/src/decryption/lwe_ct.rs b/poulpy-core/src/decryption/lwe_ct.rs deleted file mode 100644 index 57abdc6..0000000 --- a/poulpy-core/src/decryption/lwe_ct.rs +++ /dev/null @@ -1,43 +0,0 @@ -use poulpy_hal::{ - api::{ScratchOwnedAlloc, ScratchOwnedBorrow, ZnNormalizeInplace}, - layouts::{Backend, DataMut, DataRef, Module, ScratchOwned, ZnxView, ZnxViewMut}, - oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl}, -}; - -use crate::layouts::{LWECiphertext, LWEInfos, LWEPlaintext, LWESecret}; - -impl LWECiphertext -where - DataSelf: DataRef, -{ - pub fn decrypt(&self, module: &Module, pt: &mut LWEPlaintext, sk: &LWESecret) - where - DataPt: DataMut, - DataSk: DataRef, - Module: ZnNormalizeInplace, - B: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, - { - #[cfg(debug_assertions)] - { - assert_eq!(self.n(), sk.n()); - } - - (0..pt.size().min(self.size())).for_each(|i| { - pt.data.at_mut(0, i)[0] = self.data.at(0, i)[0] - + self.data.at(0, i)[1..] - .iter() - .zip(sk.data.at(0, 0)) - .map(|(x, y)| x * y) - .sum::(); - }); - module.zn_normalize_inplace( - 1, - self.base2k().into(), - &mut pt.data, - 0, - ScratchOwned::alloc(size_of::()).borrow(), - ); - pt.base2k = self.base2k(); - pt.k = crate::layouts::TorusPrecision(self.k().0.min(pt.size() as u32 * self.base2k().0)); - } -} diff --git a/poulpy-core/src/decryption/mod.rs b/poulpy-core/src/decryption/mod.rs index 8165d78..4266117 100644 --- a/poulpy-core/src/decryption/mod.rs +++ b/poulpy-core/src/decryption/mod.rs @@ -1,2 +1,5 @@ -mod glwe_ct; -mod lwe_ct; +mod glwe; +mod lwe; + +pub use glwe::*; +pub use lwe::*; diff --git a/poulpy-core/src/dist.rs b/poulpy-core/src/dist.rs index 415b0b5..c754278 100644 --- a/poulpy-core/src/dist.rs +++ b/poulpy-core/src/dist.rs @@ -1,5 +1,13 @@ use std::io::{Read, Result, Write}; +pub trait GetDistribution { + fn dist(&self) -> &Distribution; +} + +pub trait GetDistributionMut { + fn dist_mut(&mut self) -> &mut Distribution; +} + #[derive(Clone, Copy, Debug)] pub enum Distribution { TernaryFixed(usize), // Ternary with fixed Hamming weight diff --git a/poulpy-core/src/encryption/compressed/gglwe.rs b/poulpy-core/src/encryption/compressed/gglwe.rs new file mode 100644 index 0000000..1dfbf58 --- /dev/null +++ b/poulpy-core/src/encryption/compressed/gglwe.rs @@ -0,0 +1,175 @@ +use poulpy_hal::{ + api::{ModuleN, ScratchAvailable, VecZnxAddScalarInplace, VecZnxDftBytesOf, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes}, + layouts::{Backend, DataMut, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxInfos, ZnxZero}, + source::Source, +}; + +use crate::{ + ScratchTakeCore, + encryption::{GLWEEncryptSk, GLWEEncryptSkInternal, SIGMA}, + layouts::{ + GGLWECompressedSeedMut, GGLWEInfos, GLWEPlaintext, GLWESecretPrepared, LWEInfos, + compressed::{GGLWECompressed, GGLWECompressedToMut}, + prepared::GLWESecretPreparedToRef, + }, +}; + +impl GGLWECompressed { + #[allow(clippy::too_many_arguments)] + pub fn encrypt_sk( + &mut self, + module: &M, + pt: &P, + sk: &S, + seed: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + P: ScalarZnxToRef, + S: GLWESecretPreparedToRef, + M: GGLWECompressedEncryptSk, + { + module.gglwe_compressed_encrypt_sk(self, pt, sk, seed, source_xe, scratch); + } +} + +impl GGLWECompressed> { + pub fn encrypt_sk_tmp_bytes(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: GGLWECompressedEncryptSk, + { + module.gglwe_compressed_encrypt_sk_tmp_bytes(infos) + } +} + +pub trait GGLWECompressedEncryptSk { + fn gglwe_compressed_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos; + + fn gglwe_compressed_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + seed: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWECompressedToMut + GGLWECompressedSeedMut, + P: ScalarZnxToRef, + S: GLWESecretPreparedToRef; +} + +impl GGLWECompressedEncryptSk for Module +where + Self: ModuleN + + GLWEEncryptSkInternal + + GLWEEncryptSk + + VecZnxDftBytesOf + + VecZnxNormalizeInplace + + VecZnxAddScalarInplace + + VecZnxNormalizeTmpBytes, + Scratch: ScratchTakeCore, +{ + fn gglwe_compressed_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + self.glwe_encrypt_sk_tmp_bytes(infos) + .max(self.vec_znx_normalize_tmp_bytes()) + + GLWEPlaintext::bytes_of_from_infos(infos) + } + + fn gglwe_compressed_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + seed: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWECompressedToMut + GGLWECompressedSeedMut, + P: ScalarZnxToRef, + S: GLWESecretPreparedToRef, + { + let mut seeds: Vec<[u8; 32]> = vec![[0u8; 32]; res.seed_mut().len()]; + + { + let res: &mut GGLWECompressed<&mut [u8]> = &mut res.to_mut(); + let pt: &ScalarZnx<&[u8]> = &pt.to_ref(); + let sk: &GLWESecretPrepared<&[u8], BE> = &sk.to_ref(); + + assert_eq!( + res.rank_in(), + pt.cols() as u32, + "res.rank_in(): {} != pt.cols(): {}", + res.rank_in(), + pt.cols() + ); + assert_eq!( + res.rank_out(), + sk.rank(), + "res.rank_out(): {} != sk.rank(): {}", + res.rank_out(), + sk.rank() + ); + assert_eq!(res.n(), sk.n()); + assert_eq!(pt.n() as u32, sk.n()); + assert!( + scratch.available() >= GGLWECompressed::encrypt_sk_tmp_bytes(self, res), + "scratch.available: {} < GGLWECiphertext::encrypt_sk_tmp_bytes: {}", + scratch.available(), + GGLWECompressed::encrypt_sk_tmp_bytes(self, res) + ); + assert!( + res.dnum().0 * res.dsize().0 * res.base2k().0 <= res.k().0, + "res.dnum() : {} * res.dsize() : {} * res.base2k() : {} = {} >= res.k() = {}", + res.dnum(), + res.dsize(), + res.base2k(), + res.dnum().0 * res.dsize().0 * res.base2k().0, + res.k() + ); + + let dnum: usize = res.dnum().into(); + let dsize: usize = res.dsize().into(); + let base2k: usize = res.base2k().into(); + let rank_in: usize = res.rank_in().into(); + let cols: usize = (res.rank_out() + 1).into(); + + let mut source_xa = Source::new(seed); + + let (mut tmp_pt, scrach_1) = scratch.take_glwe_plaintext(res); + for col_i in 0..rank_in { + for d_i in 0..dnum { + // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt + tmp_pt.data.zero(); // zeroes for next iteration + self.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + d_i * dsize, pt, col_i); + self.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scrach_1); + + let (seed, mut source_xa_tmp) = source_xa.branch(); + seeds[col_i * dnum + d_i] = seed; + + self.glwe_encrypt_sk_internal( + res.base2k().into(), + res.k().into(), + &mut res.at_mut(d_i, col_i).data, + cols, + true, + Some((&tmp_pt, 0)), + sk, + &mut source_xa_tmp, + source_xe, + SIGMA, + scrach_1, + ); + } + } + } + + res.seed_mut().copy_from_slice(&seeds); + } +} diff --git a/poulpy-core/src/encryption/compressed/gglwe_atk.rs b/poulpy-core/src/encryption/compressed/gglwe_atk.rs deleted file mode 100644 index 95dcf20..0000000 --- a/poulpy-core/src/encryption/compressed/gglwe_atk.rs +++ /dev/null @@ -1,95 +0,0 @@ -use poulpy_hal::{ - api::{ - ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxBigNormalize, VecZnxDftAllocBytes, - VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, - }, - layouts::{Backend, DataMut, DataRef, Module, Scratch}, - source::Source, -}; - -use crate::{ - TakeGLWESecret, TakeGLWESecretPrepared, - layouts::{ - GGLWEInfos, GLWEInfos, GLWESecret, LWEInfos, - compressed::{GGLWEAutomorphismKeyCompressed, GGLWESwitchingKeyCompressed}, - }, -}; - -impl GGLWEAutomorphismKeyCompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize - where - A: GGLWEInfos, - Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + SvpPPolAllocBytes, - { - assert_eq!(module.n() as u32, infos.n()); - GGLWESwitchingKeyCompressed::encrypt_sk_scratch_space(module, infos) - + GLWESecret::alloc_bytes_with(infos.n(), infos.rank_out()) - } -} - -impl GGLWEAutomorphismKeyCompressed { - #[allow(clippy::too_many_arguments)] - pub fn encrypt_sk( - &mut self, - module: &Module, - p: i64, - sk: &GLWESecret, - seed_xa: [u8; 32], - source_xe: &mut Source, - scratch: &mut Scratch, - ) where - Module: VecZnxAutomorphism - + SvpPrepare - + SvpPPolAllocBytes - + VecZnxSwitchRing - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + VecZnxAddScalarInplace, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared, - { - #[cfg(debug_assertions)] - { - assert_eq!(self.n(), sk.n()); - assert_eq!(self.rank_out(), self.rank_in()); - assert_eq!(sk.rank(), self.rank_out()); - assert!( - scratch.available() >= GGLWEAutomorphismKeyCompressed::encrypt_sk_scratch_space(module, self), - "scratch.available(): {} < AutomorphismKey::encrypt_sk_scratch_space: {}", - scratch.available(), - GGLWEAutomorphismKeyCompressed::encrypt_sk_scratch_space(module, self) - ) - } - - let (mut sk_out, scratch_1) = scratch.take_glwe_secret(sk.n(), sk.rank()); - - { - (0..self.rank_out().into()).for_each(|i| { - module.vec_znx_automorphism( - module.galois_element_inv(p), - &mut sk_out.data.as_vec_znx_mut(), - i, - &sk.data.as_vec_znx(), - i, - ); - }); - } - - self.key - .encrypt_sk(module, sk, &sk_out, seed_xa, source_xe, scratch_1); - - self.p = p; - } -} diff --git a/poulpy-core/src/encryption/compressed/gglwe_ct.rs b/poulpy-core/src/encryption/compressed/gglwe_ct.rs deleted file mode 100644 index 76871da..0000000 --- a/poulpy-core/src/encryption/compressed/gglwe_ct.rs +++ /dev/null @@ -1,127 +0,0 @@ -use poulpy_hal::{ - api::{ - ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, - VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, - }, - layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, ZnxZero}, - source::Source, -}; - -use crate::{ - TakeGLWEPt, - encryption::{SIGMA, glwe_encrypt_sk_internal}, - layouts::{GGLWECiphertext, GGLWEInfos, LWEInfos, compressed::GGLWECiphertextCompressed, prepared::GLWESecretPrepared}, -}; - -impl GGLWECiphertextCompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize - where - A: GGLWEInfos, - Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes, - { - GGLWECiphertext::encrypt_sk_scratch_space(module, infos) - } -} - -impl GGLWECiphertextCompressed { - #[allow(clippy::too_many_arguments)] - pub fn encrypt_sk( - &mut self, - module: &Module, - pt: &ScalarZnx, - sk: &GLWESecretPrepared, - seed: [u8; 32], - source_xe: &mut Source, - scratch: &mut Scratch, - ) where - Module: VecZnxAddScalarInplace - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, - { - #[cfg(debug_assertions)] - { - use poulpy_hal::layouts::ZnxInfos; - - assert_eq!( - self.rank_in(), - pt.cols() as u32, - "self.rank_in(): {} != pt.cols(): {}", - self.rank_in(), - pt.cols() - ); - assert_eq!( - self.rank_out(), - sk.rank(), - "self.rank_out(): {} != sk.rank(): {}", - self.rank_out(), - sk.rank() - ); - assert_eq!(self.n(), sk.n()); - assert_eq!(pt.n() as u32, sk.n()); - assert!( - scratch.available() >= GGLWECiphertextCompressed::encrypt_sk_scratch_space(module, self), - "scratch.available: {} < GGLWECiphertext::encrypt_sk_scratch_space: {}", - scratch.available(), - GGLWECiphertextCompressed::encrypt_sk_scratch_space(module, self) - ); - assert!( - self.dnum().0 * self.dsize().0 * self.base2k().0 <= self.k().0, - "self.dnum() : {} * self.dsize() : {} * self.base2k() : {} = {} >= self.k() = {}", - self.dnum(), - self.dsize(), - self.base2k(), - self.dnum().0 * self.dsize().0 * self.base2k().0, - self.k() - ); - } - - let dnum: usize = self.dnum().into(); - let dsize: usize = self.dsize().into(); - let base2k: usize = self.base2k().into(); - let rank_in: usize = self.rank_in().into(); - let cols: usize = (self.rank_out() + 1).into(); - - let mut source_xa = Source::new(seed); - - let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(self); - (0..rank_in).for_each(|col_i| { - (0..dnum).for_each(|d_i| { - // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt - tmp_pt.data.zero(); // zeroes for next iteration - module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + d_i * dsize, pt, col_i); - module.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scrach_1); - - let (seed, mut source_xa_tmp) = source_xa.branch(); - self.seed[col_i * dnum + d_i] = seed; - - glwe_encrypt_sk_internal( - module, - self.base2k().into(), - self.k().into(), - &mut self.at_mut(d_i, col_i).data, - cols, - true, - Some((&tmp_pt, 0)), - sk, - &mut source_xa_tmp, - source_xe, - SIGMA, - scrach_1, - ); - }); - }); - } -} diff --git a/poulpy-core/src/encryption/compressed/gglwe_ksk.rs b/poulpy-core/src/encryption/compressed/gglwe_ksk.rs deleted file mode 100644 index 8dd177f..0000000 --- a/poulpy-core/src/encryption/compressed/gglwe_ksk.rs +++ /dev/null @@ -1,108 +0,0 @@ -use poulpy_hal::{ - api::{ - ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, - VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, - VecZnxSubInplace, VecZnxSwitchRing, - }, - layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch}, - source::Source, -}; - -use crate::{ - TakeGLWESecretPrepared, - layouts::{ - Degree, GGLWECiphertext, GGLWEInfos, GLWEInfos, GLWESecret, LWEInfos, compressed::GGLWESwitchingKeyCompressed, - prepared::GLWESecretPrepared, - }, -}; - -impl GGLWESwitchingKeyCompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize - where - A: GGLWEInfos, - Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + SvpPPolAllocBytes, - { - (GGLWECiphertext::encrypt_sk_scratch_space(module, infos) | ScalarZnx::alloc_bytes(module.n(), 1)) - + ScalarZnx::alloc_bytes(module.n(), infos.rank_in().into()) - + GLWESecretPrepared::alloc_bytes_with(module, infos.rank_out()) - } -} - -impl GGLWESwitchingKeyCompressed { - #[allow(clippy::too_many_arguments)] - pub fn encrypt_sk( - &mut self, - module: &Module, - sk_in: &GLWESecret, - sk_out: &GLWESecret, - seed_xa: [u8; 32], - source_xe: &mut Source, - scratch: &mut Scratch, - ) where - Module: SvpPrepare - + SvpPPolAllocBytes - + VecZnxSwitchRing - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + VecZnxAddScalarInplace, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared, - { - #[cfg(debug_assertions)] - { - use crate::layouts::GGLWESwitchingKey; - - assert!(sk_in.n().0 <= module.n() as u32); - assert!(sk_out.n().0 <= module.n() as u32); - assert!( - scratch.available() >= GGLWESwitchingKey::encrypt_sk_scratch_space(module, self), - "scratch.available()={} < GLWESwitchingKey::encrypt_sk_scratch_space={}", - scratch.available(), - GGLWESwitchingKey::encrypt_sk_scratch_space(module, self) - ) - } - - let n: usize = sk_in.n().max(sk_out.n()).into(); - - let (mut sk_in_tmp, scratch_1) = scratch.take_scalar_znx(n, sk_in.rank().into()); - (0..sk_in.rank().into()).for_each(|i| { - module.vec_znx_switch_ring( - &mut sk_in_tmp.as_vec_znx_mut(), - i, - &sk_in.data.as_vec_znx(), - i, - ); - }); - - let (mut sk_out_tmp, scratch_2) = scratch_1.take_glwe_secret_prepared(Degree(n as u32), sk_out.rank()); - { - let (mut tmp, _) = scratch_2.take_scalar_znx(n, 1); - (0..sk_out.rank().into()).for_each(|i| { - module.vec_znx_switch_ring(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i); - module.svp_prepare(&mut sk_out_tmp.data, i, &tmp, 0); - }); - } - - self.key.encrypt_sk( - module, - &sk_in_tmp, - &sk_out_tmp, - seed_xa, - source_xe, - scratch_2, - ); - self.sk_in_n = sk_in.n().into(); - self.sk_out_n = sk_out.n().into(); - } -} diff --git a/poulpy-core/src/encryption/compressed/gglwe_tsk.rs b/poulpy-core/src/encryption/compressed/gglwe_tsk.rs deleted file mode 100644 index 6a75a57..0000000 --- a/poulpy-core/src/encryption/compressed/gglwe_tsk.rs +++ /dev/null @@ -1,114 +0,0 @@ -use poulpy_hal::{ - api::{ - ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, - TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAllocBytes, - VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, - VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, - }, - layouts::{Backend, DataMut, DataRef, Module, Scratch}, - source::Source, -}; - -use crate::{ - TakeGLWESecret, TakeGLWESecretPrepared, - layouts::{ - GGLWEInfos, GGLWETensorKey, GLWEInfos, GLWESecret, LWEInfos, Rank, compressed::GGLWETensorKeyCompressed, - prepared::Prepare, - }, -}; - -impl GGLWETensorKeyCompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize - where - A: GGLWEInfos, - Module: - SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigAllocBytes, - { - GGLWETensorKey::encrypt_sk_scratch_space(module, infos) - } -} - -impl GGLWETensorKeyCompressed { - pub fn encrypt_sk( - &mut self, - module: &Module, - sk: &GLWESecret, - seed_xa: [u8; 32], - source_xe: &mut Source, - scratch: &mut Scratch, - ) where - Module: SvpApplyDftToDft - + VecZnxIdftApplyTmpA - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + VecZnxSwitchRing - + VecZnxAddScalarInplace - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc, - Scratch: ScratchAvailable - + TakeScalarZnx - + TakeVecZnxDft - + TakeGLWESecretPrepared - + ScratchAvailable - + TakeVecZnx - + TakeVecZnxBig, - { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank_out(), sk.rank()); - assert_eq!(self.n(), sk.n()); - } - - let n: usize = sk.n().into(); - let rank: usize = self.rank_out().into(); - - let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(sk.n(), self.rank_out()); - sk_dft_prep.prepare(module, sk, scratch_1); - - let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(n, rank, 1); - - for i in 0..rank { - module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); - } - - let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(n, 1, 1); - let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(sk.n(), Rank(1)); - let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(n, 1, 1); - - let mut source_xa: Source = Source::new(seed_xa); - - for i in 0..rank { - for j in i..rank { - module.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i); - - module.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0); - module.vec_znx_big_normalize( - self.base2k().into(), - &mut sk_ij.data.as_vec_znx_mut(), - 0, - self.base2k().into(), - &sk_ij_big, - 0, - scratch_5, - ); - - let (seed_xa_tmp, _) = source_xa.branch(); - - self.at_mut(i, j) - .encrypt_sk(module, &sk_ij, sk, seed_xa_tmp, source_xe, scratch_5); - } - } - } -} diff --git a/poulpy-core/src/encryption/compressed/ggsw.rs b/poulpy-core/src/encryption/compressed/ggsw.rs new file mode 100644 index 0000000..14b0de5 --- /dev/null +++ b/poulpy-core/src/encryption/compressed/ggsw.rs @@ -0,0 +1,147 @@ +use poulpy_hal::{ + api::{ModuleN, VecZnxAddScalarInplace, VecZnxNormalizeInplace}, + layouts::{Backend, DataMut, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxInfos, ZnxZero}, + source::Source, +}; + +use crate::{ + ScratchTakeCore, + encryption::{GGSWEncryptSk, GLWEEncryptSkInternal, SIGMA}, + layouts::{ + GGSWCompressedSeedMut, GGSWInfos, LWEInfos, + compressed::{GGSWCompressed, GGSWCompressedToMut}, + prepared::{GLWESecretPrepared, GLWESecretPreparedToRef}, + }, +}; + +impl GGSWCompressed> { + pub fn encrypt_sk_tmp_bytes(module: &M, infos: &A) -> usize + where + A: GGSWInfos, + M: GGSWCompressedEncryptSk, + { + module.ggsw_compressed_encrypt_sk_tmp_bytes(infos) + } +} + +impl GGSWCompressed { + #[allow(clippy::too_many_arguments)] + pub fn encrypt_sk( + &mut self, + module: &M, + pt: &P, + sk: &S, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + P: ScalarZnxToRef, + S: GLWESecretPreparedToRef, + M: GGSWCompressedEncryptSk, + { + module.ggsw_compressed_encrypt_sk(self, pt, sk, seed_xa, source_xe, scratch); + } +} + +pub trait GGSWCompressedEncryptSk { + fn ggsw_compressed_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGSWInfos; + + fn ggsw_compressed_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGSWCompressedToMut + GGSWCompressedSeedMut + GGSWInfos, + P: ScalarZnxToRef, + S: GLWESecretPreparedToRef; +} + +impl GGSWCompressedEncryptSk for Module +where + Self: ModuleN + GLWEEncryptSkInternal + GGSWEncryptSk + VecZnxAddScalarInplace + VecZnxNormalizeInplace, + Scratch: ScratchTakeCore, +{ + fn ggsw_compressed_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGSWInfos, + { + self.ggsw_encrypt_sk_tmp_bytes(infos) + } + + fn ggsw_compressed_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGSWCompressedToMut + GGSWCompressedSeedMut + GGSWInfos, + P: ScalarZnxToRef, + S: GLWESecretPreparedToRef, + { + let base2k: usize = res.base2k().into(); + let rank: usize = res.rank().into(); + let cols: usize = rank + 1; + let dsize: usize = res.dsize().into(); + + let sk: &GLWESecretPrepared<&[u8], BE> = &sk.to_ref(); + let pt: &ScalarZnx<&[u8]> = &pt.to_ref(); + + assert_eq!(res.rank(), sk.rank()); + assert_eq!(pt.n(), self.n()); + assert_eq!(res.n(), self.n() as u32); + assert_eq!(sk.n(), self.n() as u32); + + let mut seeds: Vec<[u8; 32]> = vec![[0u8; 32]; res.dnum().as_usize() * (res.rank().as_usize() + 1)]; + + { + let res: &mut GGSWCompressed<&mut [u8]> = &mut res.to_mut(); + + println!("res.seed: {:?}", res.seed); + + let (mut tmp_pt, scratch_1) = scratch.take_glwe_plaintext(res); + + let mut source = Source::new(seed_xa); + + for row_i in 0..res.dnum().into() { + tmp_pt.data.zero(); + + // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt + self.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + row_i * dsize, pt, 0); + self.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scratch_1); + + for col_j in 0..rank + 1 { + // rlwe encrypt of vec_znx_pt into vec_znx_ct + + let (seed, mut source_xa_tmp) = source.branch(); + + seeds[row_i * cols + col_j] = seed; + + self.glwe_encrypt_sk_internal( + res.base2k().into(), + res.k().into(), + &mut res.at_mut(row_i, col_j).data, + cols, + true, + Some((&tmp_pt, col_j)), + sk, + &mut source_xa_tmp, + source_xe, + SIGMA, + scratch_1, + ); + } + } + } + + res.seed_mut().copy_from_slice(&seeds); + } +} diff --git a/poulpy-core/src/encryption/compressed/ggsw_ct.rs b/poulpy-core/src/encryption/compressed/ggsw_ct.rs deleted file mode 100644 index e49f246..0000000 --- a/poulpy-core/src/encryption/compressed/ggsw_ct.rs +++ /dev/null @@ -1,107 +0,0 @@ -use poulpy_hal::{ - api::{ - ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, - VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, - }, - layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, ZnxZero}, - source::Source, -}; - -use crate::{ - TakeGLWEPt, - encryption::{SIGMA, glwe_encrypt_sk_internal}, - layouts::{ - GGSWCiphertext, GGSWInfos, GLWEInfos, LWEInfos, compressed::GGSWCiphertextCompressed, prepared::GLWESecretPrepared, - }, -}; - -impl GGSWCiphertextCompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize - where - A: GGSWInfos, - Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes, - { - GGSWCiphertext::encrypt_sk_scratch_space(module, infos) - } -} - -impl GGSWCiphertextCompressed { - #[allow(clippy::too_many_arguments)] - pub fn encrypt_sk( - &mut self, - module: &Module, - pt: &ScalarZnx, - sk: &GLWESecretPrepared, - seed_xa: [u8; 32], - source_xe: &mut Source, - scratch: &mut Scratch, - ) where - Module: VecZnxAddScalarInplace - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, - { - #[cfg(debug_assertions)] - { - use poulpy_hal::layouts::ZnxInfos; - - assert_eq!(self.rank(), sk.rank()); - assert_eq!(self.n(), sk.n()); - assert_eq!(pt.n() as u32, sk.n()); - } - - let base2k: usize = self.base2k().into(); - let rank: usize = self.rank().into(); - let cols: usize = rank + 1; - let dsize: usize = self.dsize().into(); - - let (mut tmp_pt, scratch_1) = scratch.take_glwe_pt(&self.glwe_layout()); - - let mut source = Source::new(seed_xa); - - self.seed = vec![[0u8; 32]; self.dnum().0 as usize * cols]; - - (0..self.dnum().into()).for_each(|row_i| { - tmp_pt.data.zero(); - - // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt - module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + row_i * dsize, pt, 0); - module.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scratch_1); - - (0..rank + 1).for_each(|col_j| { - // rlwe encrypt of vec_znx_pt into vec_znx_ct - - let (seed, mut source_xa_tmp) = source.branch(); - - self.seed[row_i * cols + col_j] = seed; - - glwe_encrypt_sk_internal( - module, - self.base2k().into(), - self.k().into(), - &mut self.at_mut(row_i, col_j).data, - cols, - true, - Some((&tmp_pt, col_j)), - sk, - &mut source_xa_tmp, - source_xe, - SIGMA, - scratch_1, - ); - }); - }); - } -} diff --git a/poulpy-core/src/encryption/compressed/glwe_automorphism_key.rs b/poulpy-core/src/encryption/compressed/glwe_automorphism_key.rs new file mode 100644 index 0000000..0c15e96 --- /dev/null +++ b/poulpy-core/src/encryption/compressed/glwe_automorphism_key.rs @@ -0,0 +1,126 @@ +use poulpy_hal::{ + api::{ModuleN, ScratchAvailable, VecZnxAutomorphism}, + layouts::{Backend, DataMut, GaloisElement, Module, Scratch}, + source::Source, +}; + +use crate::{ + GGLWECompressedEncryptSk, ScratchTakeCore, + layouts::{ + GGLWECompressedSeedMut, GGLWECompressedToMut, GGLWEInfos, GLWEInfos, GLWESecret, GLWESecretPrepared, + GLWESecretPreparedFactory, GLWESecretToRef, LWEInfos, SetGaloisElement, compressed::GLWEAutomorphismKeyCompressed, + }, +}; + +impl GLWEAutomorphismKeyCompressed> { + pub fn encrypt_sk_tmp_bytes(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: GLWEAutomorphismKeyCompressedEncryptSk, + { + module.glwe_automorphism_key_compressed_encrypt_sk_tmp_bytes(infos) + } +} + +impl GLWEAutomorphismKeyCompressed { + #[allow(clippy::too_many_arguments)] + pub fn encrypt_sk( + &mut self, + module: &M, + p: i64, + sk: &S, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + S: GLWESecretToRef + GLWEInfos, + M: GLWEAutomorphismKeyCompressedEncryptSk, + { + module.glwe_automorphism_key_compressed_encrypt_sk(self, p, sk, seed_xa, source_xe, scratch); + } +} + +pub trait GLWEAutomorphismKeyCompressedEncryptSk { + fn glwe_automorphism_key_compressed_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos; + + fn glwe_automorphism_key_compressed_encrypt_sk( + &self, + res: &mut R, + p: i64, + sk: &S, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWECompressedToMut + GGLWECompressedSeedMut + SetGaloisElement + GGLWEInfos, + S: GLWESecretToRef + GLWEInfos; +} + +impl GLWEAutomorphismKeyCompressedEncryptSk for Module +where + Self: ModuleN + GaloisElement + VecZnxAutomorphism + GGLWECompressedEncryptSk + GLWESecretPreparedFactory, + Scratch: ScratchTakeCore, +{ + fn glwe_automorphism_key_compressed_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + assert_eq!(self.n() as u32, infos.n()); + self.gglwe_compressed_encrypt_sk_tmp_bytes(infos) + .max(GLWESecret::bytes_of_from_infos(infos)) + + GLWESecretPrepared::bytes_of_from_infos(self, infos) + } + + fn glwe_automorphism_key_compressed_encrypt_sk( + &self, + res: &mut R, + p: i64, + sk: &S, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWECompressedToMut + GGLWECompressedSeedMut + SetGaloisElement + GGLWEInfos, + S: GLWESecretToRef + GLWEInfos, + { + let sk: &GLWESecret<&[u8]> = &sk.to_ref(); + assert_eq!(res.n(), sk.n()); + assert_eq!(res.rank_out(), res.rank_in()); + assert_eq!(sk.rank(), res.rank_out()); + assert!( + scratch.available() >= GLWEAutomorphismKeyCompressed::encrypt_sk_tmp_bytes(self, res), + "scratch.available(): {} < AutomorphismKey::encrypt_sk_tmp_bytes: {}", + scratch.available(), + GLWEAutomorphismKeyCompressed::encrypt_sk_tmp_bytes(self, res) + ); + + let (mut sk_out_prepared, scratch_1) = scratch.take_glwe_secret_prepared(self, sk.rank()); + { + let (mut sk_out, _) = scratch_1.take_glwe_secret(self.n().into(), sk.rank()); + sk_out.dist = sk.dist; + for i in 0..sk.rank().into() { + self.vec_znx_automorphism( + self.galois_element_inv(p), + &mut sk_out.data.as_vec_znx_mut(), + i, + &sk.data.as_vec_znx(), + i, + ); + } + sk_out_prepared.prepare(self, &sk_out); + } + + self.gglwe_compressed_encrypt_sk( + res, + &sk.data, + &sk_out_prepared, + seed_xa, + source_xe, + scratch_1, + ); + + res.set_p(p); + } +} diff --git a/poulpy-core/src/encryption/compressed/glwe_ct.rs b/poulpy-core/src/encryption/compressed/glwe_ct.rs index 834f968..47b8565 100644 --- a/poulpy-core/src/encryption/compressed/glwe_ct.rs +++ b/poulpy-core/src/encryption/compressed/glwe_ct.rs @@ -1,100 +1,109 @@ use poulpy_hal::{ - api::{ - ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, - VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, - VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, - }, - layouts::{Backend, DataMut, DataRef, Module, Scratch}, + layouts::{Backend, DataMut, Module, Scratch}, source::Source, }; use crate::{ - encryption::{SIGMA, glwe_ct::glwe_encrypt_sk_internal}, + encryption::{GLWEEncryptSk, GLWEEncryptSkInternal, SIGMA}, layouts::{ - GLWECiphertext, GLWEInfos, GLWEPlaintext, LWEInfos, compressed::GLWECiphertextCompressed, prepared::GLWESecretPrepared, + GLWECompressedSeedMut, GLWEInfos, GLWEPlaintextToRef, LWEInfos, + compressed::{GLWECompressed, GLWECompressedToMut}, + prepared::GLWESecretPreparedToRef, }, }; -impl GLWECiphertextCompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize +impl GLWECompressed> { + pub fn encrypt_sk_tmp_bytes(module: &M, infos: &A) -> usize where A: GLWEInfos, - Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes, + M: GLWECompressedEncryptSk, { - GLWECiphertext::encrypt_sk_scratch_space(module, infos) + module.glwe_compressed_encrypt_sk_tmp_bytes(infos) } } -impl GLWECiphertextCompressed { +impl GLWECompressed { #[allow(clippy::too_many_arguments)] - pub fn encrypt_sk( + pub fn encrypt_sk( &mut self, - module: &Module, - pt: &GLWEPlaintext, - sk: &GLWESecretPrepared, + module: &M, + pt: &P, + sk: &S, seed_xa: [u8; 32], source_xe: &mut Source, - scratch: &mut Scratch, + scratch: &mut Scratch, ) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + M: GLWECompressedEncryptSk, + P: GLWEPlaintextToRef, + S: GLWESecretPreparedToRef, { - self.encrypt_sk_internal(module, Some((pt, 0)), sk, seed_xa, source_xe, scratch); - } - - #[allow(clippy::too_many_arguments)] - pub(crate) fn encrypt_sk_internal( - &mut self, - module: &Module, - pt: Option<(&GLWEPlaintext, usize)>, - sk: &GLWESecretPrepared, - seed_xa: [u8; 32], - source_xe: &mut Source, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, - { - let mut source_xa = Source::new(seed_xa); - let cols: usize = (self.rank() + 1).into(); - glwe_encrypt_sk_internal( - module, - self.base2k().into(), - self.k().into(), - &mut self.data, - cols, - true, - pt, - sk, - &mut source_xa, - source_xe, - SIGMA, - scratch, - ); - self.seed = seed_xa; + module.glwe_compressed_encrypt_sk(self, pt, sk, seed_xa, source_xe, scratch); + } +} + +pub trait GLWECompressedEncryptSk { + fn glwe_compressed_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GLWEInfos; + + fn glwe_compressed_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GLWECompressedToMut + GLWECompressedSeedMut, + P: GLWEPlaintextToRef, + S: GLWESecretPreparedToRef; +} + +impl GLWECompressedEncryptSk for Module +where + Self: GLWEEncryptSkInternal + GLWEEncryptSk, +{ + fn glwe_compressed_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GLWEInfos, + { + self.glwe_encrypt_sk_tmp_bytes(infos) + } + + fn glwe_compressed_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GLWECompressedToMut + GLWECompressedSeedMut, + P: GLWEPlaintextToRef, + S: GLWESecretPreparedToRef, + { + { + let res: &mut GLWECompressed<&mut [u8]> = &mut res.to_mut(); + let mut source_xa: Source = Source::new(seed_xa); + let cols: usize = (res.rank() + 1).into(); + + self.glwe_encrypt_sk_internal( + res.base2k().into(), + res.k().into(), + &mut res.data, + cols, + true, + Some((pt, 0)), + sk, + &mut source_xa, + source_xe, + SIGMA, + scratch, + ); + } + + res.seed_mut().copy_from_slice(&seed_xa); } } diff --git a/poulpy-core/src/encryption/compressed/glwe_switching_key.rs b/poulpy-core/src/encryption/compressed/glwe_switching_key.rs new file mode 100644 index 0000000..b801927 --- /dev/null +++ b/poulpy-core/src/encryption/compressed/glwe_switching_key.rs @@ -0,0 +1,131 @@ +use poulpy_hal::{ + api::{ModuleN, ScratchAvailable, ScratchTakeBasic, SvpPrepare, VecZnxSwitchRing}, + layouts::{Backend, DataMut, Module, ScalarZnx, Scratch}, + source::Source, +}; + +use crate::{ + GGLWECompressedEncryptSk, ScratchTakeCore, + layouts::{ + GGLWECompressedSeedMut, GGLWECompressedToMut, GGLWEInfos, GLWEInfos, GLWESecret, GLWESecretToRef, + GLWESwitchingKeyDegreesMut, LWEInfos, + compressed::GLWESwitchingKeyCompressed, + prepared::{GLWESecretPrepared, GLWESecretPreparedFactory}, + }, +}; + +impl GLWESwitchingKeyCompressed> { + pub fn encrypt_sk_tmp_bytes(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: GLWESwitchingKeyCompressedEncryptSk, + { + module.glwe_switching_key_compressed_encrypt_sk_tmp_bytes(infos) + } +} + +impl GLWESwitchingKeyCompressed { + #[allow(clippy::too_many_arguments)] + pub fn encrypt_sk( + &mut self, + module: &M, + sk_in: &S1, + sk_out: &S2, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + S1: GLWESecretToRef, + S2: GLWESecretToRef, + M: GLWESwitchingKeyCompressedEncryptSk, + { + module.glwe_switching_key_compressed_encrypt_sk(self, sk_in, sk_out, seed_xa, source_xe, scratch); + } +} + +pub trait GLWESwitchingKeyCompressedEncryptSk { + fn glwe_switching_key_compressed_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos; + + fn glwe_switching_key_compressed_encrypt_sk( + &self, + res: &mut R, + sk_in: &S1, + sk_out: &S2, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWECompressedToMut + GGLWECompressedSeedMut + GLWESwitchingKeyDegreesMut + GGLWEInfos, + S1: GLWESecretToRef, + S2: GLWESecretToRef; +} + +impl GLWESwitchingKeyCompressedEncryptSk for Module +where + Self: ModuleN + GGLWECompressedEncryptSk + GLWESecretPreparedFactory + VecZnxSwitchRing, + Scratch: ScratchTakeCore, +{ + fn glwe_switching_key_compressed_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + self.gglwe_compressed_encrypt_sk_tmp_bytes(infos) + .max(ScalarZnx::bytes_of(self.n(), 1)) + + ScalarZnx::bytes_of(self.n(), infos.rank_in().into()) + + GLWESecretPrepared::bytes_of(self, infos.rank_out()) + } + + fn glwe_switching_key_compressed_encrypt_sk( + &self, + res: &mut R, + sk_in: &S1, + sk_out: &S2, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWECompressedToMut + GGLWECompressedSeedMut + GLWESwitchingKeyDegreesMut + GGLWEInfos, + S1: GLWESecretToRef, + S2: GLWESecretToRef, + { + let sk_in: &GLWESecret<&[u8]> = &sk_in.to_ref(); + let sk_out: &GLWESecret<&[u8]> = &sk_out.to_ref(); + + assert!(sk_in.n().0 <= self.n() as u32); + assert!(sk_out.n().0 <= self.n() as u32); + assert!( + scratch.available() >= self.gglwe_compressed_encrypt_sk_tmp_bytes(res), + "scratch.available()={} < GLWESwitchingKey::encrypt_sk_tmp_bytes={}", + scratch.available(), + self.gglwe_compressed_encrypt_sk_tmp_bytes(res) + ); + + let (mut sk_in_tmp, scratch_1) = scratch.take_scalar_znx(self.n(), sk_in.rank().into()); + for i in 0..sk_in.rank().into() { + self.vec_znx_switch_ring( + &mut sk_in_tmp.as_vec_znx_mut(), + i, + &sk_in.data.as_vec_znx(), + i, + ); + } + + let (mut sk_out_tmp, scratch_2) = scratch_1.take_glwe_secret_prepared(self, sk_out.rank()); + { + let (mut tmp, _) = scratch_2.take_scalar_znx(self.n(), 1); + for i in 0..sk_out.rank().into() { + self.vec_znx_switch_ring(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i); + self.svp_prepare(&mut sk_out_tmp.data, i, &tmp, 0); + } + } + + sk_out_tmp.dist = sk_out.dist; + + self.gglwe_compressed_encrypt_sk(res, &sk_in_tmp, &sk_out_tmp, seed_xa, source_xe, scratch_2); + + *res.input_degree() = sk_in.n(); + *res.output_degree() = sk_out.n(); + } +} diff --git a/poulpy-core/src/encryption/compressed/glwe_tensor_key.rs b/poulpy-core/src/encryption/compressed/glwe_tensor_key.rs new file mode 100644 index 0000000..14c9217 --- /dev/null +++ b/poulpy-core/src/encryption/compressed/glwe_tensor_key.rs @@ -0,0 +1,156 @@ +use poulpy_hal::{ + api::{ + ModuleN, ScratchTakeBasic, SvpApplyDftToDft, SvpPPolBytesOf, SvpPrepare, VecZnxBigBytesOf, VecZnxBigNormalize, + VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyTmpA, + }, + layouts::{Backend, DataMut, Module, Scratch}, + source::Source, +}; + +use crate::{ + GGLWECompressedEncryptSk, GLWETensorKeyEncryptSk, GetDistribution, ScratchTakeCore, + layouts::{ + GGLWEInfos, GLWEInfos, GLWESecret, GLWESecretPrepared, GLWESecretPreparedFactory, GLWESecretToRef, + GLWETensorKeyCompressedAtMut, LWEInfos, Rank, compressed::GLWETensorKeyCompressed, + }, +}; + +impl GLWETensorKeyCompressed> { + pub fn encrypt_sk_tmp_bytes(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: GLWETensorKeyCompressedEncryptSk, + { + module.glwe_tensor_key_compressed_encrypt_sk_tmp_bytes(infos) + } +} + +impl GLWETensorKeyCompressed { + pub fn encrypt_sk( + &mut self, + module: &M, + sk: &S, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + S: GLWESecretToRef + GetDistribution, + M: GLWETensorKeyCompressedEncryptSk, + { + module.glwe_tensor_key_compressed_encrypt_sk(self, sk, seed_xa, source_xe, scratch); + } +} + +pub trait GLWETensorKeyCompressedEncryptSk { + fn glwe_tensor_key_compressed_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos; + + fn glwe_tensor_key_compressed_encrypt_sk( + &self, + res: &mut R, + sk: &S, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + D: DataMut, + R: GLWETensorKeyCompressedAtMut + GGLWEInfos, + S: GLWESecretToRef + GetDistribution; +} + +impl GLWETensorKeyCompressedEncryptSk for Module +where + Self: ModuleN + + GGLWECompressedEncryptSk + + GLWETensorKeyEncryptSk + + VecZnxDftApply + + SvpApplyDftToDft + + VecZnxIdftApplyTmpA + + VecZnxBigNormalize + + SvpPrepare + + SvpPPolBytesOf + + VecZnxDftBytesOf + + VecZnxBigBytesOf + + GLWESecretPreparedFactory, + Scratch: ScratchTakeBasic + ScratchTakeCore, +{ + fn glwe_tensor_key_compressed_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + GLWESecretPrepared::bytes_of(self, infos.rank_out()) + + self.bytes_of_vec_znx_dft(infos.rank_out().into(), 1) + + self.bytes_of_vec_znx_big(1, 1) + + self.bytes_of_vec_znx_dft(1, 1) + + GLWESecret::bytes_of(self.n().into(), Rank(1)) + + self.gglwe_compressed_encrypt_sk_tmp_bytes(infos) + } + + fn glwe_tensor_key_compressed_encrypt_sk( + &self, + res: &mut R, + sk: &S, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + D: DataMut, + R: GGLWEInfos + GLWETensorKeyCompressedAtMut, + S: GLWESecretToRef + GetDistribution, + { + let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(self, res.rank()); + sk_dft_prep.prepare(self, sk); + + let sk: &GLWESecret<&[u8]> = &sk.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.rank_out(), sk.rank()); + assert_eq!(res.n(), sk.n()); + } + + // let n: usize = sk.n().into(); + let rank: usize = res.rank_out().into(); + + let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, rank, 1); + + for i in 0..rank { + self.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); + } + + let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(self, 1, 1); + let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(self.n().into(), Rank(1)); + let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(self, 1, 1); + + let mut source_xa: Source = Source::new(seed_xa); + + for i in 0..rank { + for j in i..rank { + self.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i); + + self.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0); + self.vec_znx_big_normalize( + res.base2k().into(), + &mut sk_ij.data.as_vec_znx_mut(), + 0, + res.base2k().into(), + &sk_ij_big, + 0, + scratch_5, + ); + + let (seed_xa_tmp, _) = source_xa.branch(); + + self.gglwe_compressed_encrypt_sk( + res.at_mut(i, j), + &sk_ij.data, + &sk_dft_prep, + seed_xa_tmp, + source_xe, + scratch_5, + ); + } + } + } +} diff --git a/poulpy-core/src/encryption/compressed/mod.rs b/poulpy-core/src/encryption/compressed/mod.rs index e763db4..e96eeb5 100644 --- a/poulpy-core/src/encryption/compressed/mod.rs +++ b/poulpy-core/src/encryption/compressed/mod.rs @@ -1,6 +1,13 @@ -mod gglwe_atk; -mod gglwe_ct; -mod gglwe_ksk; -mod gglwe_tsk; -mod ggsw_ct; +mod gglwe; +mod ggsw; +mod glwe_automorphism_key; mod glwe_ct; +mod glwe_switching_key; +mod glwe_tensor_key; + +pub use gglwe::*; +pub use ggsw::*; +pub use glwe_automorphism_key::*; +pub use glwe_ct::*; +pub use glwe_switching_key::*; +pub use glwe_tensor_key::*; diff --git a/poulpy-core/src/encryption/gglwe.rs b/poulpy-core/src/encryption/gglwe.rs new file mode 100644 index 0000000..ba78cde --- /dev/null +++ b/poulpy-core/src/encryption/gglwe.rs @@ -0,0 +1,174 @@ +use poulpy_hal::{ + api::{ModuleN, ScratchAvailable, VecZnxAddScalarInplace, VecZnxDftBytesOf, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes}, + layouts::{Backend, DataMut, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxInfos, ZnxZero}, + source::Source, +}; + +use crate::{ + GLWEEncryptSk, ScratchTakeCore, + layouts::{ + GGLWE, GGLWEInfos, GGLWEToMut, GLWEPlaintext, LWEInfos, + prepared::{GLWESecretPrepared, GLWESecretPreparedToRef}, + }, +}; + +impl GGLWE> { + pub fn encrypt_sk_tmp_bytes(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: GGLWEEncryptSk, + { + module.gglwe_encrypt_sk_tmp_bytes(infos) + } + + pub fn encrypt_pk_tmp_bytes(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: GGLWEEncryptSk, + { + module.gglwe_encrypt_sk_tmp_bytes(infos) + } +} + +impl GGLWE { + #[allow(clippy::too_many_arguments)] + pub fn encrypt_sk( + &mut self, + module: &M, + pt: &P, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + P: ScalarZnxToRef, + S: GLWESecretPreparedToRef, + M: GGLWEEncryptSk, + Scratch: ScratchTakeCore, + { + module.gglwe_encrypt_sk(self, pt, sk, source_xa, source_xe, scratch); + } +} + +pub trait GGLWEEncryptSk { + fn gglwe_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos; + + fn gglwe_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWEToMut, + P: ScalarZnxToRef, + S: GLWESecretPreparedToRef; +} + +impl GGLWEEncryptSk for Module +where + Self: ModuleN + + GLWEEncryptSk + + VecZnxNormalizeTmpBytes + + VecZnxDftBytesOf + + VecZnxAddScalarInplace + + VecZnxNormalizeInplace, + Scratch: ScratchTakeCore, +{ + fn gglwe_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + self.glwe_encrypt_sk_tmp_bytes(infos) + GLWEPlaintext::bytes_of_from_infos(infos).max(self.vec_znx_normalize_tmp_bytes()) + } + + fn gglwe_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWEToMut, + P: ScalarZnxToRef, + S: GLWESecretPreparedToRef, + { + let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut(); + let pt: &ScalarZnx<&[u8]> = &pt.to_ref(); + let sk: &GLWESecretPrepared<&[u8], BE> = &sk.to_ref(); + + assert_eq!( + res.rank_in(), + pt.cols() as u32, + "res.rank_in(): {} != pt.cols(): {}", + res.rank_in(), + pt.cols() + ); + assert_eq!( + res.rank_out(), + sk.rank(), + "res.rank_out(): {} != sk.rank(): {}", + res.rank_out(), + sk.rank() + ); + assert_eq!(res.n(), sk.n()); + assert_eq!(pt.n() as u32, sk.n()); + assert!( + scratch.available() >= self.gglwe_encrypt_sk_tmp_bytes(res), + "scratch.available: {} < GGLWE::encrypt_sk_tmp_bytes(self, res.rank()={}, res.size()={}): {}", + scratch.available(), + res.rank_out(), + res.size(), + self.gglwe_encrypt_sk_tmp_bytes(res) + ); + assert!( + res.dnum().0 * res.dsize().0 * res.base2k().0 <= res.k().0, + "res.dnum() : {} * res.dsize() : {} * res.base2k() : {} = {} >= res.k() = {}", + res.dnum(), + res.dsize(), + res.base2k(), + res.dnum().0 * res.dsize().0 * res.base2k().0, + res.k() + ); + + let dnum: usize = res.dnum().into(); + let dsize: usize = res.dsize().into(); + let base2k: usize = res.base2k().into(); + let rank_in: usize = res.rank_in().into(); + + let (mut tmp_pt, scrach_1) = scratch.take_glwe_plaintext(res); + // For each input column (i.e. rank) produces a GGLWE of rank_out+1 columns + // + // Example for ksk rank 2 to rank 3: + // + // (-(a0*s0 + a1*s1 + a2*s2) + s0', a0, a1, a2) + // (-(b0*s0 + b1*s1 + b2*s2) + s0', b0, b1, b2) + // + // Example ksk rank 2 to rank 1 + // + // (-(a*s) + s0, a) + // (-(b*s) + s1, b) + for col_i in 0..rank_in { + for row_i in 0..dnum { + // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt + tmp_pt.data.zero(); // zeroes for next iteration + self.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + row_i * dsize, pt, col_i); + self.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scrach_1); + self.glwe_encrypt_sk( + &mut res.at_mut(row_i, col_i), + &tmp_pt, + sk, + source_xa, + source_xe, + scrach_1, + ); + } + } + } +} diff --git a/poulpy-core/src/encryption/gglwe_atk.rs b/poulpy-core/src/encryption/gglwe_atk.rs deleted file mode 100644 index 6d45b37..0000000 --- a/poulpy-core/src/encryption/gglwe_atk.rs +++ /dev/null @@ -1,109 +0,0 @@ -use poulpy_hal::{ - api::{ - ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxBigNormalize, VecZnxDftAllocBytes, - VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, - }, - layouts::{Backend, DataMut, DataRef, Module, Scratch}, - source::Source, -}; - -use crate::{ - TakeGLWESecret, TakeGLWESecretPrepared, - layouts::{GGLWEAutomorphismKey, GGLWEInfos, GGLWESwitchingKey, GLWEInfos, GLWESecret, LWEInfos}, -}; - -impl GGLWEAutomorphismKey> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize - where - A: GGLWEInfos, - Module: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes, - { - assert_eq!( - infos.rank_in(), - infos.rank_out(), - "rank_in != rank_out is not supported for GGLWEAutomorphismKey" - ); - GGLWESwitchingKey::encrypt_sk_scratch_space(module, infos) + GLWESecret::alloc_bytes(&infos.glwe_layout()) - } - - pub fn encrypt_pk_scratch_space(module: &Module, _infos: &A) -> usize - where - A: GGLWEInfos, - { - assert_eq!( - _infos.rank_in(), - _infos.rank_out(), - "rank_in != rank_out is not supported for GGLWEAutomorphismKey" - ); - GGLWESwitchingKey::encrypt_pk_scratch_space(module, _infos) - } -} - -impl GGLWEAutomorphismKey { - #[allow(clippy::too_many_arguments)] - pub fn encrypt_sk( - &mut self, - module: &Module, - p: i64, - sk: &GLWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - scratch: &mut Scratch, - ) where - Module: VecZnxAddScalarInplace - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + VecZnxSwitchRing - + SvpPPolAllocBytes - + VecZnxAutomorphism, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared, - { - #[cfg(debug_assertions)] - { - use crate::layouts::{GLWEInfos, LWEInfos}; - - assert_eq!(self.n(), sk.n()); - assert_eq!(self.rank_out(), self.rank_in()); - assert_eq!(sk.rank(), self.rank_out()); - assert!( - scratch.available() >= GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, self), - "scratch.available(): {} < AutomorphismKey::encrypt_sk_scratch_space: {:?}", - scratch.available(), - GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, self) - ) - } - - let (mut sk_out, scratch_1) = scratch.take_glwe_secret(sk.n(), sk.rank()); - - { - (0..self.rank_out().into()).for_each(|i| { - module.vec_znx_automorphism( - module.galois_element_inv(p), - &mut sk_out.data.as_vec_znx_mut(), - i, - &sk.data.as_vec_znx(), - i, - ); - }); - } - - self.key - .encrypt_sk(module, sk, &sk_out, source_xa, source_xe, scratch_1); - - self.p = p; - } -} diff --git a/poulpy-core/src/encryption/gglwe_ct.rs b/poulpy-core/src/encryption/gglwe_ct.rs deleted file mode 100644 index 51054cb..0000000 --- a/poulpy-core/src/encryption/gglwe_ct.rs +++ /dev/null @@ -1,130 +0,0 @@ -use poulpy_hal::{ - api::{ - ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, - VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, - }, - layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, ZnxZero}, - source::Source, -}; - -use crate::{ - TakeGLWEPt, - layouts::{GGLWECiphertext, GGLWEInfos, GLWECiphertext, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared}, -}; - -impl GGLWECiphertext> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize - where - A: GGLWEInfos, - Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes, - { - GLWECiphertext::encrypt_sk_scratch_space(module, &infos.glwe_layout()) - + (GLWEPlaintext::alloc_bytes(&infos.glwe_layout()) | module.vec_znx_normalize_tmp_bytes()) - } - - pub fn encrypt_pk_scratch_space(_module: &Module, _infos: &A) -> usize - where - A: GGLWEInfos, - { - unimplemented!() - } -} - -impl GGLWECiphertext { - #[allow(clippy::too_many_arguments)] - pub fn encrypt_sk( - &mut self, - module: &Module, - pt: &ScalarZnx, - sk: &GLWESecretPrepared, - source_xa: &mut Source, - source_xe: &mut Source, - scratch: &mut Scratch, - ) where - Module: VecZnxAddScalarInplace - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, - { - #[cfg(debug_assertions)] - { - use poulpy_hal::layouts::ZnxInfos; - - assert_eq!( - self.rank_in(), - pt.cols() as u32, - "self.rank_in(): {} != pt.cols(): {}", - self.rank_in(), - pt.cols() - ); - assert_eq!( - self.rank_out(), - sk.rank(), - "self.rank_out(): {} != sk.rank(): {}", - self.rank_out(), - sk.rank() - ); - assert_eq!(self.n(), sk.n()); - assert_eq!(pt.n() as u32, sk.n()); - assert!( - scratch.available() >= GGLWECiphertext::encrypt_sk_scratch_space(module, self), - "scratch.available: {} < GGLWECiphertext::encrypt_sk_scratch_space(module, self.rank()={}, self.size()={}): {}", - scratch.available(), - self.rank_out(), - self.size(), - GGLWECiphertext::encrypt_sk_scratch_space(module, self) - ); - assert!( - self.dnum().0 * self.dsize().0 * self.base2k().0 <= self.k().0, - "self.dnum() : {} * self.dsize() : {} * self.base2k() : {} = {} >= self.k() = {}", - self.dnum(), - self.dsize(), - self.base2k(), - self.dnum().0 * self.dsize().0 * self.base2k().0, - self.k() - ); - } - - let dnum: usize = self.dnum().into(); - let dsize: usize = self.dsize().into(); - let base2k: usize = self.base2k().into(); - let rank_in: usize = self.rank_in().into(); - - let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(self); - // For each input column (i.e. rank) produces a GGLWE ciphertext of rank_out+1 columns - // - // Example for ksk rank 2 to rank 3: - // - // (-(a0*s0 + a1*s1 + a2*s2) + s0', a0, a1, a2) - // (-(b0*s0 + b1*s1 + b2*s2) + s0', b0, b1, b2) - // - // Example ksk rank 2 to rank 1 - // - // (-(a*s) + s0, a) - // (-(b*s) + s1, b) - (0..rank_in).for_each(|col_i| { - (0..dnum).for_each(|row_i| { - // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt - tmp_pt.data.zero(); // zeroes for next iteration - module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + row_i * dsize, pt, col_i); - module.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scrach_1); - - // rlwe encrypt of vec_znx_pt into vec_znx_ct - self.at_mut(row_i, col_i) - .encrypt_sk(module, &tmp_pt, sk, source_xa, source_xe, scrach_1); - }); - }); - } -} diff --git a/poulpy-core/src/encryption/gglwe_ksk.rs b/poulpy-core/src/encryption/gglwe_ksk.rs deleted file mode 100644 index 0629bec..0000000 --- a/poulpy-core/src/encryption/gglwe_ksk.rs +++ /dev/null @@ -1,112 +0,0 @@ -use poulpy_hal::{ - api::{ - ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, - VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, - VecZnxSubInplace, VecZnxSwitchRing, - }, - layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch}, - source::Source, -}; - -use crate::{ - TakeGLWESecretPrepared, - layouts::{ - Degree, GGLWECiphertext, GGLWEInfos, GGLWESwitchingKey, GLWEInfos, GLWESecret, LWEInfos, prepared::GLWESecretPrepared, - }, -}; - -impl GGLWESwitchingKey> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize - where - A: GGLWEInfos, - Module: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes, - { - (GGLWECiphertext::encrypt_sk_scratch_space(module, infos) | ScalarZnx::alloc_bytes(module.n(), 1)) - + ScalarZnx::alloc_bytes(module.n(), infos.rank_in().into()) - + GLWESecretPrepared::alloc_bytes(module, &infos.glwe_layout()) - } - - pub fn encrypt_pk_scratch_space(module: &Module, _infos: &A) -> usize - where - A: GGLWEInfos, - { - GGLWECiphertext::encrypt_pk_scratch_space(module, _infos) - } -} - -impl GGLWESwitchingKey { - #[allow(clippy::too_many_arguments)] - pub fn encrypt_sk( - &mut self, - module: &Module, - sk_in: &GLWESecret, - sk_out: &GLWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - scratch: &mut Scratch, - ) where - Module: VecZnxAddScalarInplace - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + VecZnxSwitchRing - + SvpPPolAllocBytes, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared, - { - #[cfg(debug_assertions)] - { - assert!(sk_in.n().0 <= module.n() as u32); - assert!(sk_out.n().0 <= module.n() as u32); - assert!( - scratch.available() >= GGLWESwitchingKey::encrypt_sk_scratch_space(module, self), - "scratch.available()={} < GLWESwitchingKey::encrypt_sk_scratch_space={}", - scratch.available(), - GGLWESwitchingKey::encrypt_sk_scratch_space(module, self) - ) - } - - let n: usize = sk_in.n().max(sk_out.n()).into(); - - let (mut sk_in_tmp, scratch_1) = scratch.take_scalar_znx(n, sk_in.rank().into()); - (0..sk_in.rank().into()).for_each(|i| { - module.vec_znx_switch_ring( - &mut sk_in_tmp.as_vec_znx_mut(), - i, - &sk_in.data.as_vec_znx(), - i, - ); - }); - - let (mut sk_out_tmp, scratch_2) = scratch_1.take_glwe_secret_prepared(Degree(n as u32), sk_out.rank()); - { - let (mut tmp, _) = scratch_2.take_scalar_znx(n, 1); - (0..sk_out.rank().into()).for_each(|i| { - module.vec_znx_switch_ring(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i); - module.svp_prepare(&mut sk_out_tmp.data, i, &tmp, 0); - }); - } - - self.key.encrypt_sk( - module, - &sk_in_tmp, - &sk_out_tmp, - source_xa, - source_xe, - scratch_2, - ); - self.sk_in_n = sk_in.n().into(); - self.sk_out_n = sk_out.n().into(); - } -} diff --git a/poulpy-core/src/encryption/gglwe_tsk.rs b/poulpy-core/src/encryption/gglwe_tsk.rs deleted file mode 100644 index 1946929..0000000 --- a/poulpy-core/src/encryption/gglwe_tsk.rs +++ /dev/null @@ -1,109 +0,0 @@ -use poulpy_hal::{ - api::{ - ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, - TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAllocBytes, - VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, - VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, - }, - layouts::{Backend, DataMut, DataRef, Module, Scratch}, - source::Source, -}; - -use crate::{ - TakeGLWESecret, TakeGLWESecretPrepared, - layouts::{ - Degree, GGLWEInfos, GGLWESwitchingKey, GGLWETensorKey, GLWEInfos, GLWESecret, LWEInfos, Rank, - prepared::{GLWESecretPrepared, Prepare}, - }, -}; - -impl GGLWETensorKey> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize - where - A: GGLWEInfos, - Module: - SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigAllocBytes, - { - GLWESecretPrepared::alloc_bytes_with(module, infos.rank_out()) - + module.vec_znx_dft_alloc_bytes(infos.rank_out().into(), 1) - + module.vec_znx_big_alloc_bytes(1, 1) - + module.vec_znx_dft_alloc_bytes(1, 1) - + GLWESecret::alloc_bytes_with(Degree(module.n() as u32), Rank(1)) - + GGLWESwitchingKey::encrypt_sk_scratch_space(module, infos) - } -} - -impl GGLWETensorKey { - pub fn encrypt_sk( - &mut self, - module: &Module, - sk: &GLWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - scratch: &mut Scratch, - ) where - Module: SvpApplyDftToDft - + VecZnxIdftApplyTmpA - + VecZnxAddScalarInplace - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + VecZnxSwitchRing - + SvpPPolAllocBytes, - Scratch: - TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared + TakeVecZnxBig, - { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank_out(), sk.rank()); - assert_eq!(self.n(), sk.n()); - } - - let n: Degree = sk.n(); - let rank: Rank = self.rank_out(); - - let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(n, rank); - sk_dft_prep.prepare(module, sk, scratch_1); - - let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(n.into(), rank.into(), 1); - - (0..rank.into()).for_each(|i| { - module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); - }); - - let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(n.into(), 1, 1); - let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(n, Rank(1)); - let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(n.into(), 1, 1); - - (0..rank.into()).for_each(|i| { - (i..rank.into()).for_each(|j| { - module.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i); - - module.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0); - module.vec_znx_big_normalize( - self.base2k().into(), - &mut sk_ij.data.as_vec_znx_mut(), - 0, - self.base2k().into(), - &sk_ij_big, - 0, - scratch_5, - ); - - self.at_mut(i, j) - .encrypt_sk(module, &sk_ij, sk, source_xa, source_xe, scratch_5); - }); - }) - } -} diff --git a/poulpy-core/src/encryption/ggsw.rs b/poulpy-core/src/encryption/ggsw.rs new file mode 100644 index 0000000..86b810c --- /dev/null +++ b/poulpy-core/src/encryption/ggsw.rs @@ -0,0 +1,136 @@ +use poulpy_hal::{ + api::{ModuleN, VecZnxAddScalarInplace, VecZnxDftBytesOf, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes}, + layouts::{Backend, DataMut, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxInfos, ZnxZero}, + source::Source, +}; + +use crate::{ + GLWEEncryptSk, GLWEEncryptSkInternal, SIGMA, ScratchTakeCore, + layouts::{ + GGSW, GGSWInfos, GGSWToMut, GLWEInfos, GLWEPlaintext, LWEInfos, + prepared::{GLWESecretPrepared, GLWESecretPreparedToRef}, + }, +}; + +impl GGSW> { + pub fn encrypt_sk_tmp_bytes(module: &M, infos: &A) -> usize + where + A: GGSWInfos, + M: GGSWEncryptSk, + { + module.ggsw_encrypt_sk_tmp_bytes(infos) + } +} + +impl GGSW { + #[allow(clippy::too_many_arguments)] + pub fn encrypt_sk( + &mut self, + module: &M, + pt: &P, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + P: ScalarZnxToRef, + S: GLWESecretPreparedToRef, + M: GGSWEncryptSk, + Scratch: ScratchTakeCore, + { + module.ggsw_encrypt_sk(self, pt, sk, source_xa, source_xe, scratch); + } +} + +pub trait GGSWEncryptSk { + fn ggsw_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGSWInfos; + + fn ggsw_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGSWToMut, + P: ScalarZnxToRef, + S: GLWESecretPreparedToRef; +} + +impl GGSWEncryptSk for Module +where + Self: ModuleN + + GLWEEncryptSkInternal + + GLWEEncryptSk + + VecZnxDftBytesOf + + VecZnxNormalizeInplace + + VecZnxAddScalarInplace + + VecZnxNormalizeTmpBytes, + Scratch: ScratchTakeCore, +{ + fn ggsw_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGSWInfos, + { + self.glwe_encrypt_sk_tmp_bytes(infos) + .max(self.vec_znx_normalize_tmp_bytes()) + + GLWEPlaintext::bytes_of_from_infos(infos) + } + + fn ggsw_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGSWToMut, + P: ScalarZnxToRef, + S: GLWESecretPreparedToRef, + { + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + let pt: &ScalarZnx<&[u8]> = &pt.to_ref(); + let sk: &GLWESecretPrepared<&[u8], BE> = &sk.to_ref(); + + assert_eq!(res.rank(), sk.rank()); + assert_eq!(res.n(), self.n() as u32); + assert_eq!(pt.n(), self.n()); + assert_eq!(sk.n(), self.n() as u32); + + let k: usize = res.k().into(); + let base2k: usize = res.base2k().into(); + let rank: usize = res.rank().into(); + let dsize: usize = res.dsize().into(); + let cols: usize = rank + 1; + + let (mut tmp_pt, scratch_1) = scratch.take_glwe_plaintext(res); + + for row_i in 0..res.dnum().into() { + tmp_pt.data.zero(); + // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt + self.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + row_i * dsize, pt, 0); + self.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scratch_1); + for col_j in 0..rank + 1 { + self.glwe_encrypt_sk_internal( + base2k, + k, + res.at_mut(row_i, col_j).data_mut(), + cols, + false, + Some((&tmp_pt, col_j)), + sk, + source_xa, + source_xe, + SIGMA, + scratch_1, + ); + } + } + } +} diff --git a/poulpy-core/src/encryption/ggsw_ct.rs b/poulpy-core/src/encryption/ggsw_ct.rs deleted file mode 100644 index 6195458..0000000 --- a/poulpy-core/src/encryption/ggsw_ct.rs +++ /dev/null @@ -1,93 +0,0 @@ -use poulpy_hal::{ - api::{ - ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, - VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, - }, - layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, VecZnx, ZnxZero}, - source::Source, -}; - -use crate::{ - TakeGLWEPt, - layouts::{GGSWCiphertext, GGSWInfos, GLWECiphertext, GLWEInfos, LWEInfos, prepared::GLWESecretPrepared}, -}; - -impl GGSWCiphertext> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize - where - A: GGSWInfos, - Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes, - { - let size = infos.size(); - GLWECiphertext::encrypt_sk_scratch_space(module, &infos.glwe_layout()) - + VecZnx::alloc_bytes(module.n(), (infos.rank() + 1).into(), size) - + VecZnx::alloc_bytes(module.n(), 1, size) - + module.vec_znx_dft_alloc_bytes((infos.rank() + 1).into(), size) - } -} - -impl GGSWCiphertext { - #[allow(clippy::too_many_arguments)] - pub fn encrypt_sk( - &mut self, - module: &Module, - pt: &ScalarZnx, - sk: &GLWESecretPrepared, - source_xa: &mut Source, - source_xe: &mut Source, - scratch: &mut Scratch, - ) where - Module: VecZnxAddScalarInplace - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, - { - #[cfg(debug_assertions)] - { - use poulpy_hal::layouts::ZnxInfos; - - assert_eq!(self.rank(), sk.rank()); - assert_eq!(self.n(), sk.n()); - assert_eq!(pt.n() as u32, sk.n()); - } - - let base2k: usize = self.base2k().into(); - let rank: usize = self.rank().into(); - let dsize: usize = self.dsize().into(); - - let (mut tmp_pt, scratch_1) = scratch.take_glwe_pt(&self.glwe_layout()); - - (0..self.dnum().into()).for_each(|row_i| { - tmp_pt.data.zero(); - - // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt - module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + row_i * dsize, pt, 0); - module.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scratch_1); - - (0..rank + 1).for_each(|col_j| { - // rlwe encrypt of vec_znx_pt into vec_znx_ct - - self.at_mut(row_i, col_j).encrypt_sk_internal( - module, - Some((&tmp_pt, col_j)), - sk, - source_xa, - source_xe, - scratch_1, - ); - }); - }); - } -} diff --git a/poulpy-core/src/encryption/glwe.rs b/poulpy-core/src/encryption/glwe.rs new file mode 100644 index 0000000..c81833e --- /dev/null +++ b/poulpy-core/src/encryption/glwe.rs @@ -0,0 +1,562 @@ +use poulpy_hal::{ + api::{ + ModuleN, ScratchAvailable, ScratchTakeBasic, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolBytesOf, SvpPrepare, + VecZnxAddInplace, VecZnxAddNormal, VecZnxBigAddNormal, VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize, + VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, + VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, + }, + layouts::{Backend, DataMut, Module, ScalarZnx, Scratch, VecZnx, VecZnxBig, VecZnxToMut, ZnxInfos, ZnxZero}, + source::Source, +}; + +use crate::{ + GetDistribution, ScratchTakeCore, + dist::Distribution, + encryption::{SIGMA, SIGMA_BOUND}, + layouts::{ + GLWE, GLWEInfos, GLWEPlaintext, GLWEPlaintextToRef, GLWEPrepared, GLWEPreparedToRef, GLWEToMut, LWEInfos, + prepared::{GLWESecretPrepared, GLWESecretPreparedToRef}, + }, +}; + +impl GLWE> { + pub fn encrypt_sk_tmp_bytes(module: &M, infos: &A) -> usize + where + A: GLWEInfos, + M: GLWEEncryptSk, + { + module.glwe_encrypt_sk_tmp_bytes(infos) + } + + pub fn encrypt_pk_tmp_bytes(module: &M, infos: &A) -> usize + where + A: GLWEInfos, + M: GLWEEncryptPk, + { + module.glwe_encrypt_pk_tmp_bytes(infos) + } +} + +impl GLWE { + pub fn encrypt_sk( + &mut self, + module: &M, + pt: &P, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + P: GLWEPlaintextToRef, + S: GLWESecretPreparedToRef, + M: GLWEEncryptSk, + Scratch: ScratchTakeCore, + { + module.glwe_encrypt_sk(self, pt, sk, source_xa, source_xe, scratch); + } + + pub fn encrypt_zero_sk( + &mut self, + module: &M, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + S: GLWESecretPreparedToRef, + M: GLWEEncryptSk, + Scratch: ScratchTakeCore, + { + module.glwe_encrypt_zero_sk(self, sk, source_xa, source_xe, scratch); + } + + pub fn encrypt_pk( + &mut self, + module: &M, + pt: &P, + pk: &K, + source_xu: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + P: GLWEPlaintextToRef + GLWEInfos, + K: GLWEPreparedToRef + GetDistribution + GLWEInfos, + M: GLWEEncryptPk, + { + module.glwe_encrypt_pk(self, pt, pk, source_xu, source_xe, scratch); + } + + pub fn encrypt_zero_pk( + &mut self, + module: &M, + pk: &K, + source_xu: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + K: GLWEPreparedToRef + GetDistribution + GLWEInfos, + M: GLWEEncryptPk, + { + module.glwe_encrypt_zero_pk(self, pk, source_xu, source_xe, scratch); + } +} + +pub trait GLWEEncryptSk { + fn glwe_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GLWEInfos; + + fn glwe_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GLWEToMut, + P: GLWEPlaintextToRef, + S: GLWESecretPreparedToRef; + + fn glwe_encrypt_zero_sk( + &self, + res: &mut R, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GLWEToMut, + S: GLWESecretPreparedToRef; +} + +impl GLWEEncryptSk for Module +where + Self: Sized + ModuleN + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + GLWEEncryptSkInternal, + Scratch: ScratchAvailable, +{ + fn glwe_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GLWEInfos, + { + let size: usize = infos.size(); + assert_eq!(self.n() as u32, infos.n()); + self.vec_znx_normalize_tmp_bytes() + 2 * VecZnx::bytes_of(self.n(), 1, size) + self.bytes_of_vec_znx_dft(1, size) + } + + fn glwe_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GLWEToMut, + P: GLWEPlaintextToRef, + S: GLWESecretPreparedToRef, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let pt: &GLWEPlaintext<&[u8]> = &pt.to_ref(); + let sk: &GLWESecretPrepared<&[u8], BE> = &sk.to_ref(); + + assert_eq!(res.rank(), sk.rank()); + assert_eq!(res.n(), self.n() as u32); + assert_eq!(sk.n(), self.n() as u32); + assert_eq!(pt.n(), self.n() as u32); + assert!( + scratch.available() >= self.glwe_encrypt_sk_tmp_bytes(res), + "scratch.available(): {} < GLWE::encrypt_sk_tmp_bytes: {}", + scratch.available(), + self.glwe_encrypt_sk_tmp_bytes(res) + ); + + let cols: usize = (res.rank() + 1).into(); + self.glwe_encrypt_sk_internal( + res.base2k().into(), + res.k().into(), + res.data_mut(), + cols, + false, + Some((pt, 0)), + sk, + source_xa, + source_xe, + SIGMA, + scratch, + ); + } + + fn glwe_encrypt_zero_sk( + &self, + res: &mut R, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GLWEToMut, + S: GLWESecretPreparedToRef, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let sk: &GLWESecretPrepared<&[u8], BE> = &sk.to_ref(); + + assert_eq!(res.rank(), sk.rank()); + assert_eq!(res.n(), self.n() as u32); + assert_eq!(sk.n(), self.n() as u32); + assert!( + scratch.available() >= self.glwe_encrypt_sk_tmp_bytes(res), + "scratch.available(): {} < GLWE::encrypt_sk_tmp_bytes: {}", + scratch.available(), + self.glwe_encrypt_sk_tmp_bytes(res) + ); + + let cols: usize = (res.rank() + 1).into(); + self.glwe_encrypt_sk_internal( + res.base2k().into(), + res.k().into(), + res.data_mut(), + cols, + false, + None::<(&GLWEPlaintext>, usize)>, + sk, + source_xa, + source_xe, + SIGMA, + scratch, + ); + } +} + +pub trait GLWEEncryptPk { + fn glwe_encrypt_pk_tmp_bytes(&self, infos: &A) -> usize + where + A: GLWEInfos; + + fn glwe_encrypt_pk( + &self, + res: &mut R, + pt: &P, + pk: &K, + source_xu: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GLWEToMut, + P: GLWEPlaintextToRef + GLWEInfos, + K: GLWEPreparedToRef + GetDistribution + GLWEInfos; + + fn glwe_encrypt_zero_pk( + &self, + res: &mut R, + pk: &K, + source_xu: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GLWEToMut, + K: GLWEPreparedToRef + GetDistribution + GLWEInfos; +} + +impl GLWEEncryptPk for Module +where + Self: GLWEEncryptPkInternal + VecZnxDftBytesOf + SvpPPolBytesOf + VecZnxBigBytesOf + VecZnxNormalizeTmpBytes, +{ + fn glwe_encrypt_pk_tmp_bytes(&self, infos: &A) -> usize + where + A: GLWEInfos, + { + let size: usize = infos.size(); + assert_eq!(self.n() as u32, infos.n()); + ((self.bytes_of_vec_znx_dft(1, size) + self.bytes_of_vec_znx_big(1, size)).max(ScalarZnx::bytes_of(self.n(), 1))) + + self.bytes_of_svp_ppol(1) + + self.vec_znx_normalize_tmp_bytes() + } + + fn glwe_encrypt_pk( + &self, + res: &mut R, + pt: &P, + pk: &K, + source_xu: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GLWEToMut, + P: GLWEPlaintextToRef + GLWEInfos, + K: GLWEPreparedToRef + GetDistribution + GLWEInfos, + { + self.glwe_encrypt_pk_internal(res, Some((pt, 0)), pk, source_xu, source_xe, scratch); + } + + fn glwe_encrypt_zero_pk( + &self, + res: &mut R, + pk: &K, + source_xu: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GLWEToMut, + K: GLWEPreparedToRef + GetDistribution + GLWEInfos, + { + self.glwe_encrypt_pk_internal( + res, + None::<(&GLWEPlaintext>, usize)>, + pk, + source_xu, + source_xe, + scratch, + ); + } +} + +pub(crate) trait GLWEEncryptPkInternal { + fn glwe_encrypt_pk_internal( + &self, + res: &mut R, + pt: Option<(&P, usize)>, + pk: &K, + source_xu: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GLWEToMut, + P: GLWEPlaintextToRef + GLWEInfos, + K: GLWEPreparedToRef + GetDistribution + GLWEInfos; +} + +impl GLWEEncryptPkInternal for Module +where + Self: SvpPrepare + + SvpApplyDftToDft + + VecZnxIdftApplyConsume + + VecZnxBigAddNormal + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize + + SvpPPolBytesOf + + ModuleN + + VecZnxDftBytesOf, + Scratch: ScratchTakeBasic, +{ + fn glwe_encrypt_pk_internal( + &self, + res: &mut R, + pt: Option<(&P, usize)>, + pk: &K, + source_xu: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GLWEToMut, + P: GLWEPlaintextToRef + GLWEInfos, + K: GLWEPreparedToRef + GetDistribution + GLWEInfos, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + + assert_eq!(res.base2k(), pk.base2k()); + assert_eq!(res.n(), pk.n()); + assert_eq!(res.rank(), pk.rank()); + if let Some((pt, _)) = pt { + assert_eq!(pt.base2k(), pk.base2k()); + assert_eq!(pt.n(), pk.n()); + } + + let base2k: usize = pk.base2k().into(); + let size_pk: usize = pk.size(); + let cols: usize = (res.rank() + 1).into(); + + // Generates u according to the underlying secret distribution. + let (mut u_dft, scratch_1) = scratch.take_svp_ppol(self, 1); + + { + let (mut u, _) = scratch_1.take_scalar_znx(self.n(), 1); + match pk.dist() { + Distribution::NONE => panic!( + "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through \ + Self::generate" + ), + Distribution::TernaryFixed(hw) => u.fill_ternary_hw(0, *hw, source_xu), + Distribution::TernaryProb(prob) => u.fill_ternary_prob(0, *prob, source_xu), + Distribution::BinaryFixed(hw) => u.fill_binary_hw(0, *hw, source_xu), + Distribution::BinaryProb(prob) => u.fill_binary_prob(0, *prob, source_xu), + Distribution::BinaryBlock(block_size) => u.fill_binary_block(0, *block_size, source_xu), + Distribution::ZERO => {} + } + + self.svp_prepare(&mut u_dft, 0, &u, 0); + } + + { + let pk: &GLWEPrepared<&[u8], BE> = &pk.to_ref(); + + // ct[i] = pk[i] * u + ei (+ m if col = i) + for i in 0..cols { + let (mut ci_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, 1, size_pk); + // ci_dft = DFT(u) * DFT(pk[i]) + self.svp_apply_dft_to_dft(&mut ci_dft, 0, &u_dft, 0, &pk.data, i); + + // ci_big = u * p[i] + let mut ci_big = self.vec_znx_idft_apply_consume(ci_dft); + + // ci_big = u * pk[i] + e + self.vec_znx_big_add_normal( + base2k, + &mut ci_big, + 0, + pk.k().into(), + source_xe, + SIGMA, + SIGMA_BOUND, + ); + + // ci_big = u * pk[i] + e + m (if col = i) + if let Some((pt, col)) = pt + && col == i + { + self.vec_znx_big_add_small_inplace(&mut ci_big, 0, &pt.to_ref().data, 0); + } + + // ct[i] = norm(ci_big) + self.vec_znx_big_normalize(base2k, &mut res.data, i, base2k, &ci_big, 0, scratch_2); + } + } + } +} + +pub(crate) trait GLWEEncryptSkInternal { + #[allow(clippy::too_many_arguments)] + fn glwe_encrypt_sk_internal( + &self, + base2k: usize, + k: usize, + res: &mut R, + cols: usize, + compressed: bool, + pt: Option<(&P, usize)>, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + P: GLWEPlaintextToRef, + S: GLWESecretPreparedToRef; +} + +impl GLWEEncryptSkInternal for Module +where + Self: ModuleN + + VecZnxDftBytesOf + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxNormalizeTmpBytes + + VecZnxFillUniform + + VecZnxSubInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub, + Scratch: ScratchTakeBasic, +{ + fn glwe_encrypt_sk_internal( + &self, + base2k: usize, + k: usize, + res: &mut R, + cols: usize, + compressed: bool, + pt: Option<(&P, usize)>, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + P: GLWEPlaintextToRef, + S: GLWESecretPreparedToRef, + { + let ct: &mut VecZnx<&mut [u8]> = &mut res.to_mut(); + let sk: GLWESecretPrepared<&[u8], BE> = sk.to_ref(); + + if compressed { + assert_eq!( + ct.cols(), + 1, + "invalid glwe: compressed tag=true but #cols={} != 1", + ct.cols() + ) + } + + assert!( + sk.dist != Distribution::NONE, + "glwe secret distribution is NONE (have you prepared the key?)" + ); + + let size: usize = ct.size(); + + let (mut c0, scratch_1) = scratch.take_vec_znx(self.n(), 1, size); + c0.zero(); + + { + let (mut ci, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, size); + + // ct[i] = uniform + // ct[0] -= c[i] * s[i], + (1..cols).for_each(|i| { + let col_ct: usize = if compressed { 0 } else { i }; + + // ct[i] = uniform (+ pt) + self.vec_znx_fill_uniform(base2k, ct, col_ct, source_xa); + + // println!("vec_znx_fill_uniform: {}", ct); + + let (mut ci_dft, scratch_3) = scratch_2.take_vec_znx_dft(self, 1, size); + + // ci = ct[i] - pt + // i.e. we act as we sample ct[i] already as uniform + pt + // and if there is a pt, then we subtract it before applying DFT + if let Some((pt, col)) = pt { + if i == col { + self.vec_znx_sub(&mut ci, 0, ct, col_ct, &pt.to_ref().data, 0); + self.vec_znx_normalize_inplace(base2k, &mut ci, 0, scratch_3); + self.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, &ci, 0); + } else { + self.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, ct, col_ct); + } + } else { + self.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, ct, col_ct); + } + + self.svp_apply_dft_to_dft_inplace(&mut ci_dft, 0, &sk.data, i - 1); + let ci_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(ci_dft); + + // use c[0] as buffer, which is overwritten later by the normalization step + self.vec_znx_big_normalize(base2k, &mut ci, 0, base2k, &ci_big, 0, scratch_3); + + // c0_tmp = -c[i] * s[i] (use c[0] as buffer) + self.vec_znx_sub_inplace(&mut c0, 0, &ci, 0); + }); + } + + // c[0] += e + self.vec_znx_add_normal(base2k, &mut c0, 0, k, source_xe, sigma, SIGMA_BOUND); + + // c[0] += m if col = 0 + if let Some((pt, col)) = pt + && col == 0 + { + self.vec_znx_add_inplace(&mut c0, 0, &pt.to_ref().data, 0); + } + + // c[0] = norm(c[0]) + self.vec_znx_normalize(base2k, ct, 0, base2k, &c0, 0, scratch_1); + } +} diff --git a/poulpy-core/src/encryption/glwe_automorphism_key.rs b/poulpy-core/src/encryption/glwe_automorphism_key.rs new file mode 100644 index 0000000..4feaeb6 --- /dev/null +++ b/poulpy-core/src/encryption/glwe_automorphism_key.rs @@ -0,0 +1,163 @@ +use poulpy_hal::{ + api::{ScratchAvailable, SvpPPolBytesOf, VecZnxAutomorphism}, + layouts::{Backend, DataMut, GaloisElement, Module, Scratch}, + source::Source, +}; + +use crate::{ + GGLWEEncryptSk, ScratchTakeCore, + layouts::{ + GGLWEInfos, GGLWEToMut, GGLWEToRef, GLWEAutomorphismKey, GLWEInfos, GLWESecret, GLWESecretPrepared, + GLWESecretPreparedFactory, GLWESecretToRef, LWEInfos, SetGaloisElement, + }, +}; + +impl GLWEAutomorphismKey> { + pub fn encrypt_sk_tmp_bytes(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: GLWEAutomorphismKeyEncryptSk, + { + module.glwe_automorphism_key_encrypt_sk_tmp_bytes(infos) + } + + pub fn encrypt_pk_tmp_bytes(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: GLWEAutomorphismKeyEncryptPk, + { + module.glwe_automorphism_key_encrypt_pk_tmp_bytes(infos) + } +} + +impl GLWEAutomorphismKey +where + Self: GGLWEToRef, +{ + pub fn encrypt_sk( + &mut self, + module: &M, + p: i64, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + S: GLWESecretToRef, + M: GLWEAutomorphismKeyEncryptSk, + { + module.glwe_automorphism_key_encrypt_sk(self, p, sk, source_xa, source_xe, scratch); + } +} + +pub trait GLWEAutomorphismKeyEncryptSk { + fn glwe_automorphism_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos; + + fn glwe_automorphism_key_encrypt_sk( + &self, + res: &mut R, + p: i64, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWEToMut + SetGaloisElement + GGLWEInfos, + S: GLWESecretToRef; +} + +impl GLWEAutomorphismKeyEncryptSk for Module +where + Self: GGLWEEncryptSk + VecZnxAutomorphism + GaloisElement + SvpPPolBytesOf + GLWESecretPreparedFactory, + Scratch: ScratchTakeCore, +{ + fn glwe_automorphism_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + assert_eq!( + infos.rank_in(), + infos.rank_out(), + "rank_in != rank_out is not supported for GGLWEAutomorphismKey" + ); + GLWESecretPrepared::bytes_of_from_infos(self, infos) + + self + .gglwe_encrypt_sk_tmp_bytes(infos) + .max(GLWESecret::bytes_of_from_infos(infos)) + } + + fn glwe_automorphism_key_encrypt_sk( + &self, + res: &mut R, + p: i64, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWEToMut + SetGaloisElement + GGLWEInfos, + S: GLWESecretToRef, + { + let sk: &GLWESecret<&[u8]> = &sk.to_ref(); + + assert_eq!(res.n(), sk.n()); + assert_eq!(res.rank_out(), res.rank_in()); + assert_eq!(sk.rank(), res.rank_out()); + assert!( + scratch.available() >= self.glwe_automorphism_key_encrypt_sk_tmp_bytes(res), + "scratch.available(): {} < AutomorphismKey::encrypt_sk_tmp_bytes: {:?}", + scratch.available(), + self.glwe_automorphism_key_encrypt_sk_tmp_bytes(res) + ); + + let (mut sk_out_prepared, scratch_1) = scratch.take_glwe_secret_prepared(self, sk.rank()); + + { + let (mut sk_out, _) = scratch_1.take_glwe_secret(sk.n(), sk.rank()); + sk_out.dist = sk.dist; + + for i in 0..sk.rank().into() { + self.vec_znx_automorphism( + self.galois_element_inv(p), + &mut sk_out.data.as_vec_znx_mut(), + i, + &sk.data.as_vec_znx(), + i, + ); + } + sk_out_prepared.prepare(self, &sk_out); + } + + self.gglwe_encrypt_sk( + res, + &sk.data, + &sk_out_prepared, + source_xa, + source_xe, + scratch_1, + ); + + res.set_p(p); + } +} + +pub trait GLWEAutomorphismKeyEncryptPk { + fn glwe_automorphism_key_encrypt_pk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos; +} + +impl GLWEAutomorphismKeyEncryptPk for Module +where + Self:, + Scratch: ScratchTakeCore, +{ + fn glwe_automorphism_key_encrypt_pk_tmp_bytes(&self, _infos: &A) -> usize + where + A: GGLWEInfos, + { + unimplemented!() + } +} diff --git a/poulpy-core/src/encryption/glwe_ct.rs b/poulpy-core/src/encryption/glwe_ct.rs deleted file mode 100644 index 8ecacc6..0000000 --- a/poulpy-core/src/encryption/glwe_ct.rs +++ /dev/null @@ -1,407 +0,0 @@ -use poulpy_hal::{ - api::{ - ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeSvpPPol, - TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxBigAddNormal, VecZnxBigAddSmallInplace, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, - VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, - }, - layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, VecZnx, VecZnxBig, ZnxInfos, ZnxZero}, - source::Source, -}; - -use crate::{ - dist::Distribution, - encryption::{SIGMA, SIGMA_BOUND}, - layouts::{ - GLWECiphertext, GLWEInfos, GLWEPlaintext, LWEInfos, - prepared::{GLWEPublicKeyPrepared, GLWESecretPrepared}, - }, -}; - -impl GLWECiphertext> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize - where - A: GLWEInfos, - Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes, - { - let size: usize = infos.size(); - assert_eq!(module.n() as u32, infos.n()); - module.vec_znx_normalize_tmp_bytes() - + 2 * VecZnx::alloc_bytes(module.n(), 1, size) - + module.vec_znx_dft_alloc_bytes(1, size) - } - pub fn encrypt_pk_scratch_space(module: &Module, infos: &A) -> usize - where - A: GLWEInfos, - Module: VecZnxDftAllocBytes + SvpPPolAllocBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes, - { - let size: usize = infos.size(); - assert_eq!(module.n() as u32, infos.n()); - ((module.vec_znx_dft_alloc_bytes(1, size) + module.vec_znx_big_alloc_bytes(1, size)) - | ScalarZnx::alloc_bytes(module.n(), 1)) - + module.svp_ppol_alloc_bytes(1) - + module.vec_znx_normalize_tmp_bytes() - } -} - -impl GLWECiphertext { - #[allow(clippy::too_many_arguments)] - pub fn encrypt_sk( - &mut self, - module: &Module, - pt: &GLWEPlaintext, - sk: &GLWESecretPrepared, - source_xa: &mut Source, - source_xe: &mut Source, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, - { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank(), sk.rank()); - assert_eq!(sk.n(), self.n()); - assert_eq!(pt.n(), self.n()); - assert!( - scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self), - "scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}", - scratch.available(), - GLWECiphertext::encrypt_sk_scratch_space(module, self) - ) - } - - self.encrypt_sk_internal(module, Some((pt, 0)), sk, source_xa, source_xe, scratch); - } - - pub fn encrypt_zero_sk( - &mut self, - module: &Module, - sk: &GLWESecretPrepared, - source_xa: &mut Source, - source_xe: &mut Source, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, - { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank(), sk.rank()); - assert_eq!(sk.n(), self.n()); - assert!( - scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self), - "scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}", - scratch.available(), - GLWECiphertext::encrypt_sk_scratch_space(module, self) - ) - } - self.encrypt_sk_internal( - module, - None::<(&GLWEPlaintext>, usize)>, - sk, - source_xa, - source_xe, - scratch, - ); - } - - #[allow(clippy::too_many_arguments)] - pub(crate) fn encrypt_sk_internal( - &mut self, - module: &Module, - pt: Option<(&GLWEPlaintext, usize)>, - sk: &GLWESecretPrepared, - source_xa: &mut Source, - source_xe: &mut Source, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, - { - let cols: usize = (self.rank() + 1).into(); - glwe_encrypt_sk_internal( - module, - self.base2k().into(), - self.k().into(), - &mut self.data, - cols, - false, - pt, - sk, - source_xa, - source_xe, - SIGMA, - scratch, - ); - } - - #[allow(clippy::too_many_arguments)] - pub fn encrypt_pk( - &mut self, - module: &Module, - pt: &GLWEPlaintext, - pk: &GLWEPublicKeyPrepared, - source_xu: &mut Source, - source_xe: &mut Source, - scratch: &mut Scratch, - ) where - Module: SvpPrepare - + SvpApplyDftToDft - + VecZnxIdftApplyConsume - + VecZnxBigAddNormal - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize, - Scratch: TakeSvpPPol + TakeScalarZnx + TakeVecZnxDft, - { - self.encrypt_pk_internal::(module, Some((pt, 0)), pk, source_xu, source_xe, scratch); - } - - pub fn encrypt_zero_pk( - &mut self, - module: &Module, - pk: &GLWEPublicKeyPrepared, - source_xu: &mut Source, - source_xe: &mut Source, - scratch: &mut Scratch, - ) where - Module: SvpPrepare - + SvpApplyDftToDft - + VecZnxIdftApplyConsume - + VecZnxBigAddNormal - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize, - Scratch: TakeSvpPPol + TakeScalarZnx + TakeVecZnxDft, - { - self.encrypt_pk_internal::, DataPk, B>( - module, - None::<(&GLWEPlaintext>, usize)>, - pk, - source_xu, - source_xe, - scratch, - ); - } - - #[allow(clippy::too_many_arguments)] - pub(crate) fn encrypt_pk_internal( - &mut self, - module: &Module, - pt: Option<(&GLWEPlaintext, usize)>, - pk: &GLWEPublicKeyPrepared, - source_xu: &mut Source, - source_xe: &mut Source, - scratch: &mut Scratch, - ) where - Module: SvpPrepare - + SvpApplyDftToDft - + VecZnxIdftApplyConsume - + VecZnxBigAddNormal - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize, - Scratch: TakeSvpPPol + TakeScalarZnx + TakeVecZnxDft, - { - #[cfg(debug_assertions)] - { - assert_eq!(self.base2k(), pk.base2k()); - assert_eq!(self.n(), pk.n()); - assert_eq!(self.rank(), pk.rank()); - if let Some((pt, _)) = pt { - assert_eq!(pt.base2k(), pk.base2k()); - assert_eq!(pt.n(), pk.n()); - } - } - - let base2k: usize = pk.base2k().into(); - let size_pk: usize = pk.size(); - let cols: usize = (self.rank() + 1).into(); - - // Generates u according to the underlying secret distribution. - let (mut u_dft, scratch_1) = scratch.take_svp_ppol(self.n().into(), 1); - - { - let (mut u, _) = scratch_1.take_scalar_znx(self.n().into(), 1); - match pk.dist { - Distribution::NONE => panic!( - "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through \ - Self::generate" - ), - Distribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu), - Distribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu), - Distribution::BinaryFixed(hw) => u.fill_binary_hw(0, hw, source_xu), - Distribution::BinaryProb(prob) => u.fill_binary_prob(0, prob, source_xu), - Distribution::BinaryBlock(block_size) => u.fill_binary_block(0, block_size, source_xu), - Distribution::ZERO => {} - } - - module.svp_prepare(&mut u_dft, 0, &u, 0); - } - - // ct[i] = pk[i] * u + ei (+ m if col = i) - (0..cols).for_each(|i| { - let (mut ci_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n().into(), 1, size_pk); - // ci_dft = DFT(u) * DFT(pk[i]) - module.svp_apply_dft_to_dft(&mut ci_dft, 0, &u_dft, 0, &pk.data, i); - - // ci_big = u * p[i] - let mut ci_big = module.vec_znx_idft_apply_consume(ci_dft); - - // ci_big = u * pk[i] + e - module.vec_znx_big_add_normal( - base2k, - &mut ci_big, - 0, - pk.k().into(), - source_xe, - SIGMA, - SIGMA_BOUND, - ); - - // ci_big = u * pk[i] + e + m (if col = i) - if let Some((pt, col)) = pt - && col == i - { - module.vec_znx_big_add_small_inplace(&mut ci_big, 0, &pt.data, 0); - } - - // ct[i] = norm(ci_big) - module.vec_znx_big_normalize(base2k, &mut self.data, i, base2k, &ci_big, 0, scratch_2); - }); - } -} - -#[allow(clippy::too_many_arguments)] -pub(crate) fn glwe_encrypt_sk_internal( - module: &Module, - base2k: usize, - k: usize, - ct: &mut VecZnx, - cols: usize, - compressed: bool, - pt: Option<(&GLWEPlaintext, usize)>, - sk: &GLWESecretPrepared, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - scratch: &mut Scratch, -) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, -{ - #[cfg(debug_assertions)] - { - if compressed { - assert_eq!( - ct.cols(), - 1, - "invalid ciphertext: compressed tag=true but #cols={} != 1", - ct.cols() - ) - } - } - - let size: usize = ct.size(); - - let (mut c0, scratch_1) = scratch.take_vec_znx(ct.n(), 1, size); - c0.zero(); - - { - let (mut ci, scratch_2) = scratch_1.take_vec_znx(ct.n(), 1, size); - - // ct[i] = uniform - // ct[0] -= c[i] * s[i], - (1..cols).for_each(|i| { - let col_ct: usize = if compressed { 0 } else { i }; - - // ct[i] = uniform (+ pt) - module.vec_znx_fill_uniform(base2k, ct, col_ct, source_xa); - - let (mut ci_dft, scratch_3) = scratch_2.take_vec_znx_dft(ct.n(), 1, size); - - // ci = ct[i] - pt - // i.e. we act as we sample ct[i] already as uniform + pt - // and if there is a pt, then we subtract it before applying DFT - if let Some((pt, col)) = pt { - if i == col { - module.vec_znx_sub(&mut ci, 0, ct, col_ct, &pt.data, 0); - module.vec_znx_normalize_inplace(base2k, &mut ci, 0, scratch_3); - module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, &ci, 0); - } else { - module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, ct, col_ct); - } - } else { - module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, ct, col_ct); - } - - module.svp_apply_dft_to_dft_inplace(&mut ci_dft, 0, &sk.data, i - 1); - let ci_big: VecZnxBig<&mut [u8], B> = module.vec_znx_idft_apply_consume(ci_dft); - - // use c[0] as buffer, which is overwritten later by the normalization step - module.vec_znx_big_normalize(base2k, &mut ci, 0, base2k, &ci_big, 0, scratch_3); - - // c0_tmp = -c[i] * s[i] (use c[0] as buffer) - module.vec_znx_sub_inplace(&mut c0, 0, &ci, 0); - }); - } - - // c[0] += e - module.vec_znx_add_normal(base2k, &mut c0, 0, k, source_xe, sigma, SIGMA_BOUND); - - // c[0] += m if col = 0 - if let Some((pt, col)) = pt - && col == 0 - { - module.vec_znx_add_inplace(&mut c0, 0, &pt.data, 0); - } - - // c[0] = norm(c[0]) - module.vec_znx_normalize(base2k, ct, 0, base2k, &c0, 0, scratch_1); -} diff --git a/poulpy-core/src/encryption/glwe_pk.rs b/poulpy-core/src/encryption/glwe_pk.rs deleted file mode 100644 index c7cdaeb..0000000 --- a/poulpy-core/src/encryption/glwe_pk.rs +++ /dev/null @@ -1,61 +0,0 @@ -use poulpy_hal::{ - api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, VecZnxAddInplace, VecZnxAddNormal, VecZnxBigNormalize, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, - }, - layouts::{Backend, DataMut, DataRef, Module, ScratchOwned}, - oep::{ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxDftImpl, TakeVecZnxImpl}, - source::Source, -}; - -use crate::layouts::{GLWECiphertext, GLWEPublicKey, prepared::GLWESecretPrepared}; - -impl GLWEPublicKey { - pub fn generate_from_sk( - &mut self, - module: &Module, - sk: &GLWESecretPrepared, - source_xa: &mut Source, - source_xe: &mut Source, - ) where - Module:, - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - B: Backend - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + TakeVecZnxDftImpl - + ScratchAvailableImpl - + TakeVecZnxImpl, - { - #[cfg(debug_assertions)] - { - use crate::{Distribution, layouts::LWEInfos}; - - assert_eq!(self.n(), sk.n()); - - if sk.dist == Distribution::NONE { - panic!("invalid sk: SecretDistribution::NONE") - } - } - - // Its ok to allocate scratch space here since pk is usually generated only once. - let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWECiphertext::encrypt_sk_scratch_space(module, self)); - - let mut tmp: GLWECiphertext> = GLWECiphertext::alloc(self); - tmp.encrypt_zero_sk(module, sk, source_xa, source_xe, scratch.borrow()); - self.dist = sk.dist; - } -} diff --git a/poulpy-core/src/encryption/glwe_public_key.rs b/poulpy-core/src/encryption/glwe_public_key.rs new file mode 100644 index 0000000..460b037 --- /dev/null +++ b/poulpy-core/src/encryption/glwe_public_key.rs @@ -0,0 +1,62 @@ +use poulpy_hal::{ + api::{ScratchOwnedAlloc, ScratchOwnedBorrow}, + layouts::{Backend, DataMut, Module, Scratch, ScratchOwned}, + source::Source, +}; + +use crate::{ + Distribution, GLWEEncryptSk, GetDistribution, GetDistributionMut, ScratchTakeCore, + layouts::{ + GLWE, GLWEInfos, GLWEPublicKey, GLWEToMut, + prepared::{GLWESecretPrepared, GLWESecretPreparedToRef}, + }, +}; + +impl GLWEPublicKey { + pub fn generate(&mut self, module: &M, sk: &S, source_xa: &mut Source, source_xe: &mut Source) + where + S: GLWESecretPreparedToRef + GetDistribution, + M: GLWEPublicKeyGenerate, + { + module.glwe_public_key_generate(self, sk, source_xa, source_xe); + } +} + +pub trait GLWEPublicKeyGenerate { + fn glwe_public_key_generate(&self, res: &mut R, sk: &S, source_xa: &mut Source, source_xe: &mut Source) + where + R: GLWEToMut + GetDistributionMut + GLWEInfos, + S: GLWESecretPreparedToRef + GetDistribution; +} + +impl GLWEPublicKeyGenerate for Module +where + Self: GLWEEncryptSk, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchTakeCore, +{ + fn glwe_public_key_generate(&self, res: &mut R, sk: &S, source_xa: &mut Source, source_xe: &mut Source) + where + R: GLWEToMut + GetDistributionMut + GLWEInfos, + S: GLWESecretPreparedToRef + GetDistribution, + { + { + let sk: &GLWESecretPrepared<&[u8], BE> = &sk.to_ref(); + + assert_eq!(res.n(), self.n() as u32); + assert_eq!(sk.n(), self.n() as u32); + + if sk.dist == Distribution::NONE { + panic!("invalid sk: SecretDistribution::NONE") + } + + // Its ok to allocate scratch space here since pk is usually generated only once. + let mut scratch: ScratchOwned = ScratchOwned::alloc(self.glwe_encrypt_sk_tmp_bytes(res)); + + let mut tmp: GLWE> = GLWE::alloc_from_infos(res); + + tmp.encrypt_zero_sk(self, sk, source_xa, source_xe, scratch.borrow()); + } + *res.dist_mut() = *sk.dist(); + } +} diff --git a/poulpy-core/src/encryption/glwe_switching_key.rs b/poulpy-core/src/encryption/glwe_switching_key.rs new file mode 100644 index 0000000..e1467f3 --- /dev/null +++ b/poulpy-core/src/encryption/glwe_switching_key.rs @@ -0,0 +1,160 @@ +use poulpy_hal::{ + api::{ModuleN, ScratchAvailable, ScratchTakeBasic, SvpPrepare, VecZnxSwitchRing}, + layouts::{Backend, DataMut, Module, ScalarZnx, Scratch}, + source::Source, +}; + +use crate::{ + ScratchTakeCore, + encryption::gglwe::GGLWEEncryptSk, + layouts::{ + GGLWEInfos, GGLWEToMut, GLWEInfos, GLWESecret, GLWESecretToRef, GLWESwitchingKey, GLWESwitchingKeyDegreesMut, LWEInfos, + prepared::GLWESecretPreparedFactory, + }, +}; + +impl GLWESwitchingKey> { + pub fn encrypt_sk_tmp_bytes(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: GLWESwitchingKeyEncryptSk, + { + module.glwe_switching_key_encrypt_sk_tmp_bytes(infos) + } + + pub fn encrypt_pk_tmp_bytes(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: GLWESwitchingKeyEncryptPk, + { + module.glwe_switching_key_encrypt_pk_tmp_bytes(infos) + } +} + +impl GLWESwitchingKey { + pub fn encrypt_sk( + &mut self, + module: &M, + sk_in: &S1, + sk_out: &S2, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + S1: GLWESecretToRef, + S2: GLWESecretToRef, + M: GLWESwitchingKeyEncryptSk, + Scratch: ScratchTakeCore, + { + module.glwe_switching_key_encrypt_sk(self, sk_in, sk_out, source_xa, source_xe, scratch); + } +} + +pub trait GLWESwitchingKeyEncryptSk { + fn glwe_switching_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos; + + fn glwe_switching_key_encrypt_sk( + &self, + res: &mut R, + sk_in: &S1, + sk_out: &S2, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWEToMut + GLWESwitchingKeyDegreesMut + GGLWEInfos, + S1: GLWESecretToRef, + S2: GLWESecretToRef; +} + +impl GLWESwitchingKeyEncryptSk for Module +where + Self: ModuleN + GGLWEEncryptSk + GLWESecretPreparedFactory + VecZnxSwitchRing + SvpPrepare, + Scratch: ScratchTakeCore, +{ + fn glwe_switching_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + self.gglwe_encrypt_sk_tmp_bytes(infos) + .max(ScalarZnx::bytes_of(self.n(), 1)) + + ScalarZnx::bytes_of(self.n(), infos.rank_in().into()) + + self.bytes_of_glwe_secret_prepared_from_infos(infos) + } + + fn glwe_switching_key_encrypt_sk( + &self, + res: &mut R, + sk_in: &S1, + sk_out: &S2, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWEToMut + GLWESwitchingKeyDegreesMut + GGLWEInfos, + S1: GLWESecretToRef, + S2: GLWESecretToRef, + { + let sk_in: &GLWESecret<&[u8]> = &sk_in.to_ref(); + let sk_out: &GLWESecret<&[u8]> = &sk_out.to_ref(); + + assert!(sk_in.n().0 <= self.n() as u32); + assert!(sk_out.n().0 <= self.n() as u32); + assert!( + scratch.available() >= self.glwe_switching_key_encrypt_sk_tmp_bytes(res), + "scratch.available()={} < GLWESwitchingKey::encrypt_sk_tmp_bytes={}", + scratch.available(), + self.glwe_switching_key_encrypt_sk_tmp_bytes(res) + ); + + let (mut sk_in_tmp, scratch_1) = scratch.take_scalar_znx(self.n(), sk_in.rank().into()); + for i in 0..sk_in.rank().into() { + self.vec_znx_switch_ring( + &mut sk_in_tmp.as_vec_znx_mut(), + i, + &sk_in.data.as_vec_znx(), + i, + ); + } + + let (mut sk_out_tmp, scratch_2) = scratch_1.take_glwe_secret_prepared(self, sk_out.rank()); + { + let (mut tmp, _) = scratch_2.take_scalar_znx(self.n(), 1); + for i in 0..sk_out.rank().into() { + self.vec_znx_switch_ring(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i); + self.svp_prepare(&mut sk_out_tmp.data, i, &tmp, 0); + } + } + + sk_out_tmp.dist = sk_out.dist; + + self.gglwe_encrypt_sk( + res, + &sk_in_tmp, + &sk_out_tmp, + source_xa, + source_xe, + scratch_2, + ); + + *res.input_degree() = sk_in.n(); + *res.output_degree() = sk_out.n(); + } +} + +pub trait GLWESwitchingKeyEncryptPk { + fn glwe_switching_key_encrypt_pk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos; +} + +impl GLWESwitchingKeyEncryptPk for Module { + fn glwe_switching_key_encrypt_pk_tmp_bytes(&self, _infos: &A) -> usize + where + A: GGLWEInfos, + { + unimplemented!() + } +} diff --git a/poulpy-core/src/encryption/glwe_tensor_key.rs b/poulpy-core/src/encryption/glwe_tensor_key.rs new file mode 100644 index 0000000..b7afae5 --- /dev/null +++ b/poulpy-core/src/encryption/glwe_tensor_key.rs @@ -0,0 +1,147 @@ +use poulpy_hal::{ + api::{ + ModuleN, ScratchTakeBasic, SvpApplyDftToDft, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, + VecZnxIdftApplyTmpA, + }, + layouts::{Backend, DataMut, Module, Scratch}, + source::Source, +}; + +use crate::{ + GGLWEEncryptSk, GetDistribution, ScratchTakeCore, + layouts::{ + GGLWE, GGLWEInfos, GLWEInfos, GLWESecret, GLWESecretToRef, GLWETensorKey, GLWETensorKeyToMut, LWEInfos, Rank, + prepared::{GLWESecretPrepared, GLWESecretPreparedFactory}, + }, +}; + +impl GLWETensorKey> { + pub fn encrypt_sk_tmp_bytes(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: GLWETensorKeyEncryptSk, + { + module.glwe_tensor_key_encrypt_sk_tmp_bytes(infos) + } +} + +impl GLWETensorKey { + pub fn encrypt_sk( + &mut self, + module: &M, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + M: GLWETensorKeyEncryptSk, + S: GLWESecretToRef + GetDistribution + GLWEInfos, + Scratch: ScratchTakeCore, + { + module.glwe_tensor_key_encrypt_sk(self, sk, source_xa, source_xe, scratch); + } +} + +pub trait GLWETensorKeyEncryptSk { + fn glwe_tensor_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos; + + fn glwe_tensor_key_encrypt_sk( + &self, + res: &mut R, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GLWETensorKeyToMut, + S: GLWESecretToRef + GetDistribution + GLWEInfos; +} + +impl GLWETensorKeyEncryptSk for Module +where + Self: ModuleN + + GGLWEEncryptSk + + VecZnxDftBytesOf + + VecZnxBigBytesOf + + GLWESecretPreparedFactory + + VecZnxDftApply + + SvpApplyDftToDft + + VecZnxIdftApplyTmpA + + VecZnxBigNormalize, + Scratch: ScratchTakeCore, +{ + fn glwe_tensor_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + GLWESecretPrepared::bytes_of(self, infos.rank_out()) + + self.bytes_of_vec_znx_dft(infos.rank_out().into(), 1) + + self.bytes_of_vec_znx_big(1, 1) + + self.bytes_of_vec_znx_dft(1, 1) + + GLWESecret::bytes_of(self.n().into(), Rank(1)) + + GGLWE::encrypt_sk_tmp_bytes(self, infos) + } + + fn glwe_tensor_key_encrypt_sk( + &self, + res: &mut R, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GLWETensorKeyToMut, + S: GLWESecretToRef + GetDistribution + GLWEInfos, + { + let res: &mut GLWETensorKey<&mut [u8]> = &mut res.to_mut(); + + // let n: RingDegree = sk.n(); + let rank: Rank = res.rank_out(); + + let (mut sk_prepared, scratch_1) = scratch.take_glwe_secret_prepared(self, sk.rank()); + sk_prepared.prepare(self, sk); + + let sk: &GLWESecret<&[u8]> = &sk.to_ref(); + + assert_eq!(res.rank_out(), sk.rank()); + assert_eq!(res.n(), sk.n()); + + let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, rank.into(), 1); + + (0..rank.into()).for_each(|i| { + self.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); + }); + + let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(self, 1, 1); + let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(self.n().into(), Rank(1)); + let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(self, 1, 1); + + (0..rank.into()).for_each(|i| { + (i..rank.into()).for_each(|j| { + self.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_prepared.data, j, &sk_dft, i); + + self.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0); + self.vec_znx_big_normalize( + res.base2k().into(), + &mut sk_ij.data.as_vec_znx_mut(), + 0, + res.base2k().into(), + &sk_ij_big, + 0, + scratch_5, + ); + + res.at_mut(i, j).encrypt_sk( + self, + &sk_ij.data, + &sk_prepared, + source_xa, + source_xe, + scratch_5, + ); + }); + }) + } +} diff --git a/poulpy-core/src/encryption/glwe_to_lwe_ksk.rs b/poulpy-core/src/encryption/glwe_to_lwe_ksk.rs deleted file mode 100644 index b65ce4e..0000000 --- a/poulpy-core/src/encryption/glwe_to_lwe_ksk.rs +++ /dev/null @@ -1,81 +0,0 @@ -use poulpy_hal::{ - api::{ - ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigNormalize, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, - }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero}, - source::Source, -}; - -use crate::{ - TakeGLWESecret, TakeGLWESecretPrepared, - layouts::{GGLWEInfos, GGLWESwitchingKey, GLWESecret, GLWEToLWEKey, LWEInfos, LWESecret, Rank, prepared::GLWESecretPrepared}, -}; - -impl GLWEToLWEKey> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize - where - A: GGLWEInfos, - Module: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes, - { - GLWESecretPrepared::alloc_bytes_with(module, infos.rank_in()) - + (GGLWESwitchingKey::encrypt_sk_scratch_space(module, infos) - | GLWESecret::alloc_bytes_with(infos.n(), infos.rank_in())) - } -} - -impl GLWEToLWEKey { - #[allow(clippy::too_many_arguments)] - pub fn encrypt_sk( - &mut self, - module: &Module, - sk_lwe: &LWESecret, - sk_glwe: &GLWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - scratch: &mut Scratch, - ) where - DLwe: DataRef, - DGlwe: DataRef, - Module: VecZnxAutomorphismInplace - + VecZnxAddScalarInplace - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + VecZnxSwitchRing - + SvpPPolAllocBytes, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared, - { - #[cfg(debug_assertions)] - { - assert!(sk_lwe.n().0 <= module.n() as u32); - } - - let (mut sk_lwe_as_glwe, scratch_1) = scratch.take_glwe_secret(sk_glwe.n(), Rank(1)); - sk_lwe_as_glwe.data.zero(); - sk_lwe_as_glwe.data.at_mut(0, 0)[..sk_lwe.n().into()].copy_from_slice(sk_lwe.data.at(0, 0)); - module.vec_znx_automorphism_inplace(-1, &mut sk_lwe_as_glwe.data.as_vec_znx_mut(), 0, scratch_1); - - self.0.encrypt_sk( - module, - sk_glwe, - &sk_lwe_as_glwe, - source_xa, - source_xe, - scratch_1, - ); - } -} diff --git a/poulpy-core/src/encryption/glwe_to_lwe_switching_key.rs b/poulpy-core/src/encryption/glwe_to_lwe_switching_key.rs new file mode 100644 index 0000000..71877a4 --- /dev/null +++ b/poulpy-core/src/encryption/glwe_to_lwe_switching_key.rs @@ -0,0 +1,120 @@ +use poulpy_hal::{ + api::{ModuleN, VecZnxAutomorphismInplace, VecZnxAutomorphismInplaceTmpBytes}, + layouts::{Backend, DataMut, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero}, + source::Source, +}; + +use crate::{ + GGLWEEncryptSk, ScratchTakeCore, + layouts::{ + GGLWE, GGLWEInfos, GGLWEToMut, GLWESecret, GLWESecretToRef, GLWEToLWESwitchingKey, LWEInfos, LWESecret, LWESecretToRef, + Rank, + prepared::{GLWESecretPrepared, GLWESecretPreparedFactory}, + }, +}; + +impl GLWEToLWESwitchingKey> { + pub fn encrypt_sk_tmp_bytes(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: GLWEToLWESwitchingKeyEncryptSk, + { + module.glwe_to_lwe_switching_key_encrypt_sk_tmp_bytes(infos) + } +} + +impl GLWEToLWESwitchingKey { + pub fn encrypt_sk( + &mut self, + module: &M, + sk_lwe: &S1, + sk_glwe: &S2, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + M: GLWEToLWESwitchingKeyEncryptSk, + S1: LWESecretToRef, + S2: GLWESecretToRef, + Scratch: ScratchTakeCore, + { + module.glwe_to_lwe_switching_key_encrypt_sk(self, sk_lwe, sk_glwe, source_xa, source_xe, scratch); + } +} + +pub trait GLWEToLWESwitchingKeyEncryptSk { + fn glwe_to_lwe_switching_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos; + + fn glwe_to_lwe_switching_key_encrypt_sk( + &self, + res: &mut R, + sk_lwe: &S1, + sk_glwe: &S2, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + S1: LWESecretToRef, + S2: GLWESecretToRef, + R: GGLWEToMut; +} + +impl GLWEToLWESwitchingKeyEncryptSk for Module +where + Self: ModuleN + + GGLWEEncryptSk + + GLWESecretPreparedFactory + + VecZnxAutomorphismInplace + + VecZnxAutomorphismInplaceTmpBytes, + Scratch: ScratchTakeCore, +{ + fn glwe_to_lwe_switching_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + GLWESecretPrepared::bytes_of(self, infos.rank_in()) + + GGLWE::encrypt_sk_tmp_bytes(self, infos) + .max(GLWESecret::bytes_of(self.n().into(), infos.rank_in()) + self.vec_znx_automorphism_inplace_tmp_bytes()) + } + + fn glwe_to_lwe_switching_key_encrypt_sk( + &self, + res: &mut R, + sk_lwe: &S1, + sk_glwe: &S2, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + S1: LWESecretToRef, + S2: GLWESecretToRef, + R: GGLWEToMut, + { + let sk_lwe: &LWESecret<&[u8]> = &sk_lwe.to_ref(); + let sk_glwe: &GLWESecret<&[u8]> = &sk_glwe.to_ref(); + + assert!(sk_lwe.n().0 <= self.n() as u32); + + let (mut sk_lwe_as_glwe_prep, scratch_1) = scratch.take_glwe_secret_prepared(self, Rank(1)); + + { + let (mut sk_lwe_as_glwe, scratch_2) = scratch_1.take_glwe_secret(self.n().into(), sk_lwe_as_glwe_prep.rank()); + sk_lwe_as_glwe.dist = sk_lwe.dist; + sk_lwe_as_glwe.data.zero(); + sk_lwe_as_glwe.data.at_mut(0, 0)[..sk_lwe.n().into()].copy_from_slice(sk_lwe.data.at(0, 0)); + self.vec_znx_automorphism_inplace(-1, &mut sk_lwe_as_glwe.data.as_vec_znx_mut(), 0, scratch_2); + sk_lwe_as_glwe_prep.prepare(self, &sk_lwe_as_glwe); + } + + self.gglwe_encrypt_sk( + res, + &sk_glwe.data, + &sk_lwe_as_glwe_prep, + source_xa, + source_xe, + scratch_1, + ); + } +} diff --git a/poulpy-core/src/encryption/lwe.rs b/poulpy-core/src/encryption/lwe.rs new file mode 100644 index 0000000..7651e8d --- /dev/null +++ b/poulpy-core/src/encryption/lwe.rs @@ -0,0 +1,100 @@ +use poulpy_hal::{ + api::{ScratchOwnedAlloc, ScratchOwnedBorrow, ZnAddNormal, ZnFillUniform, ZnNormalizeInplace}, + layouts::{Backend, DataMut, Module, ScratchOwned, Zn, ZnxView, ZnxViewMut}, + source::Source, +}; + +use crate::{ + encryption::{SIGMA, SIGMA_BOUND}, + layouts::{LWE, LWEInfos, LWEPlaintext, LWEPlaintextToRef, LWESecret, LWESecretToRef, LWEToMut}, +}; + +impl LWE { + pub fn encrypt_sk(&mut self, module: &M, pt: &P, sk: &S, source_xa: &mut Source, source_xe: &mut Source) + where + P: LWEPlaintextToRef, + S: LWESecretToRef, + M: LWEEncryptSk, + { + module.lwe_encrypt_sk(self, pt, sk, source_xa, source_xe); + } +} + +pub trait LWEEncryptSk { + fn lwe_encrypt_sk(&self, res: &mut R, pt: &P, sk: &S, source_xa: &mut Source, source_xe: &mut Source) + where + R: LWEToMut, + P: LWEPlaintextToRef, + S: LWESecretToRef; +} + +impl LWEEncryptSk for Module +where + Self: Sized + ZnFillUniform + ZnAddNormal + ZnNormalizeInplace, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + fn lwe_encrypt_sk(&self, res: &mut R, pt: &P, sk: &S, source_xa: &mut Source, source_xe: &mut Source) + where + R: LWEToMut, + P: LWEPlaintextToRef, + S: LWESecretToRef, + { + let res: &mut LWE<&mut [u8]> = &mut res.to_mut(); + let pt: &LWEPlaintext<&[u8]> = &pt.to_ref(); + let sk: &LWESecret<&[u8]> = &sk.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.n(), sk.n()) + } + + let base2k: usize = res.base2k().into(); + let k: usize = res.k().into(); + + self.zn_fill_uniform((res.n() + 1).into(), base2k, &mut res.data, 0, source_xa); + + let mut tmp_znx: Zn> = Zn::alloc(1, 1, res.size()); + + let min_size = res.size().min(pt.size()); + + (0..min_size).for_each(|i| { + tmp_znx.at_mut(0, i)[0] = pt.data.at(0, i)[0] + - res.data.at(0, i)[1..] + .iter() + .zip(sk.data.at(0, 0)) + .map(|(x, y)| x * y) + .sum::(); + }); + + (min_size..res.size()).for_each(|i| { + tmp_znx.at_mut(0, i)[0] -= res.data.at(0, i)[1..] + .iter() + .zip(sk.data.at(0, 0)) + .map(|(x, y)| x * y) + .sum::(); + }); + + self.zn_add_normal( + 1, + base2k, + &mut res.data, + 0, + k, + source_xe, + SIGMA, + SIGMA_BOUND, + ); + + self.zn_normalize_inplace( + 1, + base2k, + &mut tmp_znx, + 0, + ScratchOwned::alloc(size_of::()).borrow(), + ); + + (0..res.size()).for_each(|i| { + res.data.at_mut(0, i)[0] = tmp_znx.at(0, i)[0]; + }); + } +} diff --git a/poulpy-core/src/encryption/lwe_ct.rs b/poulpy-core/src/encryption/lwe_ct.rs deleted file mode 100644 index 4dd09ac..0000000 --- a/poulpy-core/src/encryption/lwe_ct.rs +++ /dev/null @@ -1,81 +0,0 @@ -use poulpy_hal::{ - api::{ScratchOwnedAlloc, ScratchOwnedBorrow, ZnAddNormal, ZnFillUniform, ZnNormalizeInplace}, - layouts::{Backend, DataMut, DataRef, Module, ScratchOwned, Zn, ZnxView, ZnxViewMut}, - oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl}, - source::Source, -}; - -use crate::{ - encryption::{SIGMA, SIGMA_BOUND}, - layouts::{LWECiphertext, LWEInfos, LWEPlaintext, LWESecret}, -}; - -impl LWECiphertext { - pub fn encrypt_sk( - &mut self, - module: &Module, - pt: &LWEPlaintext, - sk: &LWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - ) where - DataPt: DataRef, - DataSk: DataRef, - Module: ZnFillUniform + ZnAddNormal + ZnNormalizeInplace, - B: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, - { - #[cfg(debug_assertions)] - { - assert_eq!(self.n(), sk.n()) - } - - let base2k: usize = self.base2k().into(); - let k: usize = self.k().into(); - - module.zn_fill_uniform((self.n() + 1).into(), base2k, &mut self.data, 0, source_xa); - - let mut tmp_znx: Zn> = Zn::alloc(1, 1, self.size()); - - let min_size = self.size().min(pt.size()); - - (0..min_size).for_each(|i| { - tmp_znx.at_mut(0, i)[0] = pt.data.at(0, i)[0] - - self.data.at(0, i)[1..] - .iter() - .zip(sk.data.at(0, 0)) - .map(|(x, y)| x * y) - .sum::(); - }); - - (min_size..self.size()).for_each(|i| { - tmp_znx.at_mut(0, i)[0] -= self.data.at(0, i)[1..] - .iter() - .zip(sk.data.at(0, 0)) - .map(|(x, y)| x * y) - .sum::(); - }); - - module.zn_add_normal( - 1, - base2k, - &mut self.data, - 0, - k, - source_xe, - SIGMA, - SIGMA_BOUND, - ); - - module.zn_normalize_inplace( - 1, - base2k, - &mut tmp_znx, - 0, - ScratchOwned::alloc(size_of::()).borrow(), - ); - - (0..self.size()).for_each(|i| { - self.data.at_mut(0, i)[0] = tmp_znx.at(0, i)[0]; - }); - } -} diff --git a/poulpy-core/src/encryption/lwe_ksk.rs b/poulpy-core/src/encryption/lwe_ksk.rs deleted file mode 100644 index 66ae685..0000000 --- a/poulpy-core/src/encryption/lwe_ksk.rs +++ /dev/null @@ -1,107 +0,0 @@ -use poulpy_hal::{ - api::{ - ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigNormalize, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, - }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut}, - source::Source, -}; - -use crate::{ - TakeGLWESecret, TakeGLWESecretPrepared, - layouts::{ - Degree, GGLWEInfos, GGLWESwitchingKey, GLWESecret, LWEInfos, LWESecret, LWESwitchingKey, Rank, - prepared::GLWESecretPrepared, - }, -}; - -impl LWESwitchingKey> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize - where - A: GGLWEInfos, - Module: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes, - { - debug_assert_eq!( - infos.dsize().0, - 1, - "dsize > 1 is not supported for LWESwitchingKey" - ); - debug_assert_eq!( - infos.rank_in().0, - 1, - "rank_in > 1 is not supported for LWESwitchingKey" - ); - debug_assert_eq!( - infos.rank_out().0, - 1, - "rank_out > 1 is not supported for LWESwitchingKey" - ); - GLWESecret::alloc_bytes_with(Degree(module.n() as u32), Rank(1)) - + GLWESecretPrepared::alloc_bytes_with(module, Rank(1)) - + GGLWESwitchingKey::encrypt_sk_scratch_space(module, infos) - } -} - -impl LWESwitchingKey { - #[allow(clippy::too_many_arguments)] - pub fn encrypt_sk( - &mut self, - module: &Module, - sk_lwe_in: &LWESecret, - sk_lwe_out: &LWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - scratch: &mut Scratch, - ) where - DIn: DataRef, - DOut: DataRef, - Module: VecZnxAutomorphismInplace - + VecZnxAddScalarInplace - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + VecZnxSwitchRing - + SvpPPolAllocBytes, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared, - { - #[cfg(debug_assertions)] - { - assert!(sk_lwe_in.n().0 <= self.n().0); - assert!(sk_lwe_out.n().0 <= self.n().0); - assert!(self.n().0 <= module.n() as u32); - } - - let (mut sk_in_glwe, scratch_1) = scratch.take_glwe_secret(self.n(), Rank(1)); - let (mut sk_out_glwe, scratch_2) = scratch_1.take_glwe_secret(self.n(), Rank(1)); - - sk_out_glwe.data.at_mut(0, 0)[..sk_lwe_out.n().into()].copy_from_slice(sk_lwe_out.data.at(0, 0)); - sk_out_glwe.data.at_mut(0, 0)[sk_lwe_out.n().into()..].fill(0); - module.vec_znx_automorphism_inplace(-1, &mut sk_out_glwe.data.as_vec_znx_mut(), 0, scratch_2); - - sk_in_glwe.data.at_mut(0, 0)[..sk_lwe_in.n().into()].copy_from_slice(sk_lwe_in.data.at(0, 0)); - sk_in_glwe.data.at_mut(0, 0)[sk_lwe_in.n().into()..].fill(0); - module.vec_znx_automorphism_inplace(-1, &mut sk_in_glwe.data.as_vec_znx_mut(), 0, scratch_2); - - self.0.encrypt_sk( - module, - &sk_in_glwe, - &sk_out_glwe, - source_xa, - source_xe, - scratch_2, - ); - } -} diff --git a/poulpy-core/src/encryption/lwe_switching_key.rs b/poulpy-core/src/encryption/lwe_switching_key.rs new file mode 100644 index 0000000..431545b --- /dev/null +++ b/poulpy-core/src/encryption/lwe_switching_key.rs @@ -0,0 +1,137 @@ +use poulpy_hal::{ + api::{ModuleN, VecZnxAutomorphismInplace}, + layouts::{Backend, DataMut, Module, Scratch, ZnxView, ZnxViewMut}, + source::Source, +}; + +use crate::{ + ScratchTakeCore, + encryption::glwe_switching_key::GLWESwitchingKeyEncryptSk, + layouts::{ + GGLWEInfos, GGLWEToMut, GLWESecret, GLWESwitchingKey, GLWESwitchingKeyDegreesMut, LWEInfos, LWESecret, LWESecretToRef, + LWESwitchingKey, Rank, + prepared::{GLWESecretPrepared, GLWESecretPreparedFactory}, + }, +}; + +impl LWESwitchingKey> { + pub fn encrypt_sk_tmp_bytes(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: LWESwitchingKeyEncrypt, + { + module.lwe_switching_key_encrypt_sk_tmp_bytes(infos) + } +} + +impl LWESwitchingKey { + pub fn encrypt_sk( + &mut self, + module: &M, + sk_lwe_in: &S1, + sk_lwe_out: &S2, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + S1: LWESecretToRef, + S2: LWESecretToRef, + M: LWESwitchingKeyEncrypt, + { + module.lwe_switching_key_encrypt_sk(self, sk_lwe_in, sk_lwe_out, source_xa, source_xe, scratch); + } +} + +pub trait LWESwitchingKeyEncrypt { + fn lwe_switching_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos; + + fn lwe_switching_key_encrypt_sk( + &self, + res: &mut R, + sk_lwe_in: &S1, + sk_lwe_out: &S2, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWEToMut + GLWESwitchingKeyDegreesMut + GGLWEInfos, + S1: LWESecretToRef, + S2: LWESecretToRef; +} + +impl LWESwitchingKeyEncrypt for Module +where + Self: ModuleN + GLWESwitchingKeyEncryptSk + GLWESecretPreparedFactory + VecZnxAutomorphismInplace, + Scratch: ScratchTakeCore, +{ + fn lwe_switching_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + assert_eq!( + infos.dsize().0, + 1, + "dsize > 1 is not supported for LWESwitchingKey" + ); + assert_eq!( + infos.rank_in().0, + 1, + "rank_in > 1 is not supported for LWESwitchingKey" + ); + assert_eq!( + infos.rank_out().0, + 1, + "rank_out > 1 is not supported for LWESwitchingKey" + ); + GLWESecret::bytes_of(self.n().into(), Rank(1)) + + GLWESecretPrepared::bytes_of(self, Rank(1)) + + GLWESwitchingKey::encrypt_sk_tmp_bytes(self, infos) + } + + #[allow(clippy::too_many_arguments)] + fn lwe_switching_key_encrypt_sk( + &self, + res: &mut R, + sk_lwe_in: &S1, + sk_lwe_out: &S2, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWEToMut + GLWESwitchingKeyDegreesMut + GGLWEInfos, + S1: LWESecretToRef, + S2: LWESecretToRef, + { + let sk_lwe_in: &LWESecret<&[u8]> = &sk_lwe_in.to_ref(); + let sk_lwe_out: &LWESecret<&[u8]> = &sk_lwe_out.to_ref(); + + assert!(sk_lwe_in.n().0 <= res.n().0); + assert!(sk_lwe_out.n().0 <= res.n().0); + assert!(res.n() <= self.n() as u32); + + let (mut sk_glwe_in, scratch_1) = scratch.take_glwe_secret(self.n().into(), Rank(1)); + let (mut sk_glwe_out, scratch_2) = scratch_1.take_glwe_secret(self.n().into(), Rank(1)); + + sk_glwe_in.dist = sk_lwe_in.dist; + sk_glwe_out.dist = sk_lwe_out.dist; + + sk_glwe_out.data.at_mut(0, 0)[..sk_lwe_out.n().into()].copy_from_slice(sk_lwe_out.data.at(0, 0)); + sk_glwe_out.data.at_mut(0, 0)[sk_lwe_out.n().into()..].fill(0); + self.vec_znx_automorphism_inplace(-1, &mut sk_glwe_out.data.as_vec_znx_mut(), 0, scratch_2); + + sk_glwe_in.data.at_mut(0, 0)[..sk_lwe_in.n().into()].copy_from_slice(sk_lwe_in.data.at(0, 0)); + sk_glwe_in.data.at_mut(0, 0)[sk_lwe_in.n().into()..].fill(0); + self.vec_znx_automorphism_inplace(-1, &mut sk_glwe_in.data.as_vec_znx_mut(), 0, scratch_2); + + self.glwe_switching_key_encrypt_sk( + res, + &sk_glwe_in, + &sk_glwe_out, + source_xa, + source_xe, + scratch_2, + ); + } +} diff --git a/poulpy-core/src/encryption/lwe_to_glwe_ksk.rs b/poulpy-core/src/encryption/lwe_to_glwe_ksk.rs deleted file mode 100644 index 204e84b..0000000 --- a/poulpy-core/src/encryption/lwe_to_glwe_ksk.rs +++ /dev/null @@ -1,87 +0,0 @@ -use poulpy_hal::{ - api::{ - ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigNormalize, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, - }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut}, - source::Source, -}; - -use crate::{ - TakeGLWESecret, TakeGLWESecretPrepared, - layouts::{Degree, GGLWEInfos, GGLWESwitchingKey, GLWESecret, LWEInfos, LWESecret, LWEToGLWESwitchingKey, Rank}, -}; - -impl LWEToGLWESwitchingKey> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize - where - A: GGLWEInfos, - Module: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes, - { - debug_assert_eq!( - infos.rank_in(), - Rank(1), - "rank_in != 1 is not supported for LWEToGLWESwitchingKey" - ); - GGLWESwitchingKey::encrypt_sk_scratch_space(module, infos) - + GLWESecret::alloc_bytes_with(Degree(module.n() as u32), infos.rank_in()) - } -} - -impl LWEToGLWESwitchingKey { - #[allow(clippy::too_many_arguments)] - pub fn encrypt_sk( - &mut self, - module: &Module, - sk_lwe: &LWESecret, - sk_glwe: &GLWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - scratch: &mut Scratch, - ) where - DLwe: DataRef, - DGlwe: DataRef, - Module: VecZnxAutomorphismInplace - + VecZnxAddScalarInplace - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + VecZnxSwitchRing - + SvpPPolAllocBytes, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared, - { - #[cfg(debug_assertions)] - { - use crate::layouts::LWEInfos; - - assert!(sk_lwe.n().0 <= module.n() as u32); - } - - let (mut sk_lwe_as_glwe, scratch_1) = scratch.take_glwe_secret(sk_glwe.n(), Rank(1)); - sk_lwe_as_glwe.data.at_mut(0, 0)[..sk_lwe.n().into()].copy_from_slice(sk_lwe.data.at(0, 0)); - sk_lwe_as_glwe.data.at_mut(0, 0)[sk_lwe.n().into()..].fill(0); - module.vec_znx_automorphism_inplace(-1, &mut sk_lwe_as_glwe.data.as_vec_znx_mut(), 0, scratch_1); - - self.0.encrypt_sk( - module, - &sk_lwe_as_glwe, - sk_glwe, - source_xa, - source_xe, - scratch_1, - ); - } -} diff --git a/poulpy-core/src/encryption/lwe_to_glwe_switching_key.rs b/poulpy-core/src/encryption/lwe_to_glwe_switching_key.rs new file mode 100644 index 0000000..af31420 --- /dev/null +++ b/poulpy-core/src/encryption/lwe_to_glwe_switching_key.rs @@ -0,0 +1,118 @@ +use poulpy_hal::{ + api::{ModuleN, VecZnxAutomorphismInplace, VecZnxAutomorphismInplaceTmpBytes}, + layouts::{Backend, DataMut, Module, Scratch, ZnxView, ZnxViewMut}, + source::Source, +}; + +use crate::{ + GGLWEEncryptSk, ScratchTakeCore, + layouts::{ + GGLWE, GGLWEInfos, GGLWEToMut, GLWESecret, GLWESecretPreparedFactory, GLWESecretPreparedToRef, LWEInfos, LWESecret, + LWESecretToRef, LWEToGLWESwitchingKey, Rank, + }, +}; + +impl LWEToGLWESwitchingKey> { + pub fn encrypt_sk_tmp_bytes(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: LWEToGLWESwitchingKeyEncryptSk, + { + module.lwe_to_glwe_switching_key_encrypt_sk_tmp_bytes(infos) + } +} + +impl LWEToGLWESwitchingKey { + pub fn encrypt_sk( + &mut self, + module: &M, + sk_lwe: &S1, + sk_glwe: &S2, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + S1: LWESecretToRef, + S2: GLWESecretPreparedToRef, + M: LWEToGLWESwitchingKeyEncryptSk, + Scratch: ScratchTakeCore, + { + module.lwe_to_glwe_switching_key_encrypt_sk(self, sk_lwe, sk_glwe, source_xa, source_xe, scratch); + } +} + +pub trait LWEToGLWESwitchingKeyEncryptSk { + fn lwe_to_glwe_switching_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos; + + fn lwe_to_glwe_switching_key_encrypt_sk( + &self, + res: &mut R, + sk_lwe: &S1, + sk_glwe: &S2, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + S1: LWESecretToRef, + S2: GLWESecretPreparedToRef, + R: GGLWEToMut; +} + +impl LWEToGLWESwitchingKeyEncryptSk for Module +where + Self: ModuleN + + GGLWEEncryptSk + + VecZnxAutomorphismInplace + + GLWESecretPreparedFactory + + VecZnxAutomorphismInplaceTmpBytes, + Scratch: ScratchTakeCore, +{ + fn lwe_to_glwe_switching_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + debug_assert_eq!( + infos.rank_in(), + Rank(1), + "rank_in != 1 is not supported for LWEToGLWESwitchingKey" + ); + GLWESecret::bytes_of(self.n().into(), infos.rank_in()) + + GGLWE::encrypt_sk_tmp_bytes(self, infos).max(self.vec_znx_automorphism_inplace_tmp_bytes()) + } + + fn lwe_to_glwe_switching_key_encrypt_sk( + &self, + res: &mut R, + sk_lwe: &S1, + sk_glwe: &S2, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + S1: LWESecretToRef, + S2: GLWESecretPreparedToRef, + R: GGLWEToMut, + { + let sk_lwe: &LWESecret<&[u8]> = &sk_lwe.to_ref(); + + assert!(sk_lwe.n().0 <= self.n() as u32); + + let (mut sk_lwe_as_glwe, scratch_1) = scratch.take_glwe_secret(self.n().into(), Rank(1)); + sk_lwe_as_glwe.dist = sk_lwe.dist; + + sk_lwe_as_glwe.data.at_mut(0, 0)[..sk_lwe.n().into()].copy_from_slice(sk_lwe.data.at(0, 0)); + sk_lwe_as_glwe.data.at_mut(0, 0)[sk_lwe.n().into()..].fill(0); + self.vec_znx_automorphism_inplace(-1, &mut sk_lwe_as_glwe.data.as_vec_znx_mut(), 0, scratch_1); + + self.gglwe_encrypt_sk( + res, + &sk_lwe_as_glwe.data, + sk_glwe, + source_xa, + source_xe, + scratch_1, + ); + } +} diff --git a/poulpy-core/src/encryption/mod.rs b/poulpy-core/src/encryption/mod.rs index 9380933..7a391a6 100644 --- a/poulpy-core/src/encryption/mod.rs +++ b/poulpy-core/src/encryption/mod.rs @@ -1,17 +1,28 @@ mod compressed; -mod gglwe_atk; -mod gglwe_ct; -mod gglwe_ksk; -mod gglwe_tsk; -mod ggsw_ct; -mod glwe_ct; -mod glwe_pk; -mod glwe_to_lwe_ksk; -mod lwe_ct; -mod lwe_ksk; -mod lwe_to_glwe_ksk; +mod gglwe; +mod ggsw; +mod glwe; +mod glwe_automorphism_key; +mod glwe_public_key; +mod glwe_switching_key; +mod glwe_tensor_key; +mod glwe_to_lwe_switching_key; +mod lwe; +mod lwe_switching_key; +mod lwe_to_glwe_switching_key; -pub(crate) use glwe_ct::glwe_encrypt_sk_internal; +pub use compressed::*; +pub use gglwe::*; +pub use ggsw::*; +pub use glwe::*; +pub use glwe_automorphism_key::*; +pub use glwe_public_key::*; +pub use glwe_switching_key::*; +pub use glwe_tensor_key::*; +pub use glwe_to_lwe_switching_key::*; +pub use lwe::*; +pub use lwe_switching_key::*; +pub use lwe_to_glwe_switching_key::*; pub const SIGMA: f64 = 3.2; pub(crate) const SIGMA_BOUND: f64 = 6.0 * SIGMA; diff --git a/poulpy-core/src/external_product/gglwe.rs b/poulpy-core/src/external_product/gglwe.rs new file mode 100644 index 0000000..437cf39 --- /dev/null +++ b/poulpy-core/src/external_product/gglwe.rs @@ -0,0 +1,167 @@ +use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch, ZnxZero}; + +use crate::{ + GLWEExternalProduct, ScratchTakeCore, + layouts::{ + GGLWE, GGLWEInfos, GGLWEToMut, GGLWEToRef, GGSWInfos, GGSWPrepared, GLWEAutomorphismKey, GLWEInfos, GLWESwitchingKey, + prepared::GGSWPreparedToRef, + }, +}; + +impl GLWEAutomorphismKey> { + pub fn external_product_tmp_bytes( + &self, + module: &M, + res_infos: &R, + a_infos: &A, + b_infos: &B, + ) -> usize + where + R: GGLWEInfos, + A: GGLWEInfos, + B: GGSWInfos, + M: GGLWEExternalProduct, + { + module.gglwe_external_product_tmp_bytes(res_infos, a_infos, b_infos) + } +} + +impl GLWEAutomorphismKey { + pub fn external_product(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch) + where + M: GGLWEExternalProduct, + A: GGLWEToRef, + B: GGSWPreparedToRef, + Scratch: ScratchTakeCore, + { + module.gglwe_external_product(self, a, b, scratch); + } + + pub fn external_product_inplace(&mut self, module: &M, a: &A, scratch: &mut Scratch) + where + M: GGLWEExternalProduct, + A: GGSWPreparedToRef, + Scratch: ScratchTakeCore, + { + module.gglwe_external_product_inplace(self, a, scratch); + } +} + +pub trait GGLWEExternalProduct +where + Self: GLWEExternalProduct, +{ + fn gglwe_external_product_tmp_bytes(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize + where + R: GGLWEInfos, + A: GGLWEInfos, + B: GGSWInfos, + { + self.glwe_external_product_tmp_bytes(res_infos, a_infos, b_infos) + } + + fn gglwe_external_product(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) + where + R: GGLWEToMut, + A: GGLWEToRef, + B: GGSWPreparedToRef, + Scratch: ScratchTakeCore, + { + let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GGLWE<&[u8]> = &a.to_ref(); + let b: &GGSWPrepared<&[u8], BE> = &b.to_ref(); + + assert_eq!( + res.rank_in(), + a.rank_in(), + "res input rank_in: {} != a input rank_in: {}", + res.rank_in(), + a.rank_in() + ); + assert_eq!( + a.rank_out(), + b.rank(), + "a output rank_out: {} != b rank: {}", + a.rank_out(), + b.rank() + ); + assert_eq!( + res.rank_out(), + b.rank(), + "res output rank_out: {} != b rank: {}", + res.rank_out(), + b.rank() + ); + + for row in 0..res.dnum().into() { + for col in 0..res.rank_in().into() { + self.glwe_external_product(&mut res.at_mut(row, col), &a.at(row, col), b, scratch); + } + } + + for row in res.dnum().min(a.dnum()).into()..res.dnum().into() { + for col in 0..res.rank_in().into() { + res.at_mut(row, col).data_mut().zero(); + } + } + } + + fn gglwe_external_product_inplace(&self, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: GGLWEToMut, + A: GGSWPreparedToRef, + Scratch: ScratchTakeCore, + { + let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GGSWPrepared<&[u8], BE> = &a.to_ref(); + + assert_eq!( + res.rank_out(), + a.rank(), + "res output rank: {} != a rank: {}", + res.rank_out(), + a.rank() + ); + + for row in 0..res.dnum().into() { + for col in 0..res.rank_in().into() { + self.glwe_external_product_inplace(&mut res.at_mut(row, col), a, scratch); + } + } + } +} + +impl GGLWEExternalProduct for Module where Self: GLWEExternalProduct {} + +impl GLWESwitchingKey> { + pub fn external_product_tmp_bytes(module: &M, res_infos: &R, a_infos: &A, b_infos: &B) -> usize + where + R: GGLWEInfos, + A: GGLWEInfos, + B: GGSWInfos, + M: GGLWEExternalProduct, + { + module.gglwe_external_product_tmp_bytes(res_infos, a_infos, b_infos) + } +} + +impl GLWESwitchingKey { + pub fn external_product(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch) + where + M: GGLWEExternalProduct, + A: GGLWEToRef, + B: GGSWPreparedToRef, + Scratch: ScratchTakeCore, + { + module.gglwe_external_product(self, a, b, scratch); + } + + pub fn external_product_inplace(&mut self, module: &M, a: &A, scratch: &mut Scratch) + where + M: GGLWEExternalProduct, + A: GGSWPreparedToRef, + Scratch: ScratchTakeCore, + { + module.gglwe_external_product_inplace(self, a, scratch); + } +} diff --git a/poulpy-core/src/external_product/gglwe_atk.rs b/poulpy-core/src/external_product/gglwe_atk.rs deleted file mode 100644 index cb35a4c..0000000 --- a/poulpy-core/src/external_product/gglwe_atk.rs +++ /dev/null @@ -1,83 +0,0 @@ -use poulpy_hal::{ - api::{ - ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, - }, - layouts::{Backend, DataMut, DataRef, Module, Scratch}, -}; - -use crate::layouts::{GGLWEAutomorphismKey, GGLWEInfos, GGLWESwitchingKey, GGSWInfos, prepared::GGSWCiphertextPrepared}; - -impl GGLWEAutomorphismKey> { - pub fn external_product_scratch_space( - module: &Module, - out_infos: &OUT, - in_infos: &IN, - ggsw_infos: &GGSW, - ) -> usize - where - OUT: GGLWEInfos, - IN: GGLWEInfos, - GGSW: GGSWInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes, - { - GGLWESwitchingKey::external_product_scratch_space(module, out_infos, in_infos, ggsw_infos) - } - - pub fn external_product_inplace_scratch_space( - module: &Module, - out_infos: &OUT, - ggsw_infos: &GGSW, - ) -> usize - where - OUT: GGLWEInfos, - GGSW: GGSWInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes, - { - GGLWESwitchingKey::external_product_inplace_scratch_space(module, out_infos, ggsw_infos) - } -} - -impl GGLWEAutomorphismKey { - pub fn external_product( - &mut self, - module: &Module, - lhs: &GGLWEAutomorphismKey, - rhs: &GGSWCiphertextPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxNormalizeTmpBytes - + VecZnxDftApply - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxIdftApplyConsume - + VecZnxBigNormalize - + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, - { - self.key.external_product(module, &lhs.key, rhs, scratch); - } - - pub fn external_product_inplace( - &mut self, - module: &Module, - rhs: &GGSWCiphertextPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxNormalizeTmpBytes - + VecZnxDftApply - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxIdftApplyConsume - + VecZnxBigNormalize - + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, - { - self.key.external_product_inplace(module, rhs, scratch); - } -} diff --git a/poulpy-core/src/external_product/gglwe_ksk.rs b/poulpy-core/src/external_product/gglwe_ksk.rs deleted file mode 100644 index 2eff45c..0000000 --- a/poulpy-core/src/external_product/gglwe_ksk.rs +++ /dev/null @@ -1,144 +0,0 @@ -use poulpy_hal::{ - api::{ - ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, - }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero}, -}; - -use crate::layouts::{GGLWEInfos, GGLWESwitchingKey, GGSWInfos, GLWECiphertext, prepared::GGSWCiphertextPrepared}; - -impl GGLWESwitchingKey> { - pub fn external_product_scratch_space( - module: &Module, - out_infos: &OUT, - in_infos: &IN, - ggsw_infos: &GGSW, - ) -> usize - where - OUT: GGLWEInfos, - IN: GGLWEInfos, - GGSW: GGSWInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes, - { - GLWECiphertext::external_product_scratch_space( - module, - &out_infos.glwe_layout(), - &in_infos.glwe_layout(), - ggsw_infos, - ) - } - - pub fn external_product_inplace_scratch_space( - module: &Module, - out_infos: &OUT, - ggsw_infos: &GGSW, - ) -> usize - where - OUT: GGLWEInfos, - GGSW: GGSWInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes, - { - GLWECiphertext::external_product_inplace_scratch_space(module, &out_infos.glwe_layout(), ggsw_infos) - } -} - -impl GGLWESwitchingKey { - pub fn external_product( - &mut self, - module: &Module, - lhs: &GGLWESwitchingKey, - rhs: &GGSWCiphertextPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxNormalizeTmpBytes - + VecZnxDftApply - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxIdftApplyConsume - + VecZnxBigNormalize - + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, - { - #[cfg(debug_assertions)] - { - use crate::layouts::GLWEInfos; - - assert_eq!( - self.rank_in(), - lhs.rank_in(), - "ksk_out input rank: {} != ksk_in input rank: {}", - self.rank_in(), - lhs.rank_in() - ); - assert_eq!( - lhs.rank_out(), - rhs.rank(), - "ksk_in output rank: {} != ggsw rank: {}", - self.rank_out(), - rhs.rank() - ); - assert_eq!( - self.rank_out(), - rhs.rank(), - "ksk_out output rank: {} != ggsw rank: {}", - self.rank_out(), - rhs.rank() - ); - } - - (0..self.rank_in().into()).for_each(|col_i| { - (0..self.dnum().into()).for_each(|row_j| { - self.at_mut(row_j, col_i) - .external_product(module, &lhs.at(row_j, col_i), rhs, scratch); - }); - }); - - (self.dnum().min(lhs.dnum()).into()..self.dnum().into()).for_each(|row_i| { - (0..self.rank_in().into()).for_each(|col_j| { - self.at_mut(row_i, col_j).data.zero(); - }); - }); - } - - pub fn external_product_inplace( - &mut self, - module: &Module, - rhs: &GGSWCiphertextPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxNormalizeTmpBytes - + VecZnxDftApply - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxIdftApplyConsume - + VecZnxBigNormalize - + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, - { - #[cfg(debug_assertions)] - { - use crate::layouts::GLWEInfos; - - assert_eq!( - self.rank_out(), - rhs.rank(), - "ksk_out output rank: {} != ggsw rank: {}", - self.rank_out(), - rhs.rank() - ); - } - - (0..self.rank_in().into()).for_each(|col_i| { - (0..self.dnum().into()).for_each(|row_j| { - self.at_mut(row_j, col_i) - .external_product_inplace(module, rhs, scratch); - }); - }); - } -} diff --git a/poulpy-core/src/external_product/ggsw.rs b/poulpy-core/src/external_product/ggsw.rs new file mode 100644 index 0000000..f9659fe --- /dev/null +++ b/poulpy-core/src/external_product/ggsw.rs @@ -0,0 +1,130 @@ +use poulpy_hal::{ + api::ScratchAvailable, + layouts::{Backend, DataMut, Module, Scratch, ZnxZero}, +}; + +use crate::{ + GLWEExternalProduct, ScratchTakeCore, + layouts::{ + GGSW, GGSWInfos, GGSWToMut, GGSWToRef, GLWEInfos, LWEInfos, + prepared::{GGSWPrepared, GGSWPreparedToRef}, + }, +}; + +pub trait GGSWExternalProduct +where + Self: GLWEExternalProduct, +{ + fn ggsw_external_product_tmp_bytes(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize + where + R: GGSWInfos, + A: GGSWInfos, + B: GGSWInfos, + { + self.glwe_external_product_tmp_bytes(res_infos, a_infos, b_infos) + } + + fn ggsw_external_product(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) + where + R: GGSWToMut, + A: GGSWToRef, + B: GGSWPreparedToRef, + Scratch: ScratchTakeCore, + { + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + let a: &GGSW<&[u8]> = &a.to_ref(); + let b: &GGSWPrepared<&[u8], BE> = &b.to_ref(); + + assert_eq!( + res.rank(), + a.rank(), + "res rank: {} != a rank: {}", + res.rank(), + a.rank() + ); + assert_eq!( + res.rank(), + b.rank(), + "res rank: {} != b rank: {}", + res.rank(), + b.rank() + ); + + assert!(scratch.available() >= self.ggsw_external_product_tmp_bytes(res, a, b)); + + let min_dnum: usize = res.dnum().min(a.dnum()).into(); + + for row in 0..min_dnum { + for col in 0..(res.rank() + 1).into() { + self.glwe_external_product(&mut res.at_mut(row, col), &a.at(row, col), b, scratch); + } + } + + for row in min_dnum..res.dnum().into() { + for col in 0..(res.rank() + 1).into() { + res.at_mut(row, col).data.zero(); + } + } + } + + fn ggsw_external_product_inplace(&self, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: GGSWToMut, + A: GGSWPreparedToRef, + Scratch: ScratchTakeCore, + { + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + let a: &GGSWPrepared<&[u8], BE> = &a.to_ref(); + + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.n(), self.n() as u32); + assert_eq!( + res.rank(), + a.rank(), + "res rank: {} != a rank: {}", + res.rank(), + a.rank() + ); + + for row in 0..res.dnum().into() { + for col in 0..(res.rank() + 1).into() { + self.glwe_external_product_inplace(&mut res.at_mut(row, col), a, scratch); + } + } + } +} + +impl GGSWExternalProduct for Module where Self: GLWEExternalProduct {} + +impl GGSW> { + pub fn external_product_tmp_bytes(module: &M, res_infos: &R, a_infos: &A, b_infos: &B) -> usize + where + R: GGSWInfos, + A: GGSWInfos, + B: GGSWInfos, + M: GGSWExternalProduct, + { + module.ggsw_external_product_tmp_bytes(res_infos, a_infos, b_infos) + } +} + +impl GGSW { + pub fn external_product(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch) + where + M: GGSWExternalProduct, + A: GGSWToRef, + B: GGSWPreparedToRef, + Scratch: ScratchTakeCore, + { + module.ggsw_external_product(self, a, b, scratch); + } + + pub fn external_product_inplace(&mut self, module: &M, a: &A, scratch: &mut Scratch) + where + M: GGSWExternalProduct, + A: GGSWPreparedToRef, + Scratch: ScratchTakeCore, + { + module.ggsw_external_product_inplace(self, a, scratch); + } +} diff --git a/poulpy-core/src/external_product/ggsw_ct.rs b/poulpy-core/src/external_product/ggsw_ct.rs deleted file mode 100644 index a458de1..0000000 --- a/poulpy-core/src/external_product/ggsw_ct.rs +++ /dev/null @@ -1,143 +0,0 @@ -use poulpy_hal::{ - api::{ - ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, - }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero}, -}; - -use crate::layouts::{GGSWCiphertext, GGSWInfos, GLWECiphertext, GLWEInfos, prepared::GGSWCiphertextPrepared}; - -impl GGSWCiphertext> { - #[allow(clippy::too_many_arguments)] - pub fn external_product_scratch_space( - module: &Module, - out_infos: &OUT, - in_infos: &IN, - apply_infos: &GGSW, - ) -> usize - where - OUT: GGSWInfos, - IN: GGSWInfos, - GGSW: GGSWInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes, - { - GLWECiphertext::external_product_scratch_space( - module, - &out_infos.glwe_layout(), - &in_infos.glwe_layout(), - apply_infos, - ) - } - - pub fn external_product_inplace_scratch_space( - module: &Module, - out_infos: &OUT, - apply_infos: &GGSW, - ) -> usize - where - OUT: GGSWInfos, - GGSW: GGSWInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes, - { - GLWECiphertext::external_product_inplace_scratch_space(module, &out_infos.glwe_layout(), apply_infos) - } -} - -impl GGSWCiphertext { - pub fn external_product( - &mut self, - module: &Module, - lhs: &GGSWCiphertext, - rhs: &GGSWCiphertextPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxNormalizeTmpBytes - + VecZnxDftApply - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxIdftApplyConsume - + VecZnxBigNormalize - + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, - { - #[cfg(debug_assertions)] - { - use crate::layouts::LWEInfos; - - assert_eq!(lhs.n(), self.n()); - assert_eq!(rhs.n(), self.n()); - - assert_eq!( - self.rank(), - lhs.rank(), - "ggsw_out rank: {} != ggsw_in rank: {}", - self.rank(), - lhs.rank() - ); - assert_eq!( - self.rank(), - rhs.rank(), - "ggsw_in rank: {} != ggsw_apply rank: {}", - self.rank(), - rhs.rank() - ); - - assert!(scratch.available() >= GGSWCiphertext::external_product_scratch_space(module, self, lhs, rhs)) - } - - let min_dnum: usize = self.dnum().min(lhs.dnum()).into(); - - (0..(self.rank() + 1).into()).for_each(|col_i| { - (0..min_dnum).for_each(|row_j| { - self.at_mut(row_j, col_i) - .external_product(module, &lhs.at(row_j, col_i), rhs, scratch); - }); - (min_dnum..self.dnum().into()).for_each(|row_i| { - self.at_mut(row_i, col_i).data.zero(); - }); - }); - } - - pub fn external_product_inplace( - &mut self, - module: &Module, - rhs: &GGSWCiphertextPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxNormalizeTmpBytes - + VecZnxDftApply - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxIdftApplyConsume - + VecZnxBigNormalize - + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, - { - #[cfg(debug_assertions)] - { - use crate::layouts::LWEInfos; - - assert_eq!(rhs.n(), self.n()); - assert_eq!( - self.rank(), - rhs.rank(), - "ggsw_out rank: {} != ggsw_apply: {}", - self.rank(), - rhs.rank() - ); - } - - (0..(self.rank() + 1).into()).for_each(|col_i| { - (0..self.dnum().into()).for_each(|row_j| { - self.at_mut(row_j, col_i) - .external_product_inplace(module, rhs, scratch); - }); - }); - } -} diff --git a/poulpy-core/src/external_product/glwe_ct.rs b/poulpy-core/src/external_product/glwe.rs similarity index 66% rename from poulpy-core/src/external_product/glwe_ct.rs rename to poulpy-core/src/external_product/glwe.rs index d764507..a1cd5ee 100644 --- a/poulpy-core/src/external_product/glwe_ct.rs +++ b/poulpy-core/src/external_product/glwe.rs @@ -1,102 +1,57 @@ use poulpy_hal::{ api::{ - ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, + ModuleN, ScratchTakeBasic, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, + VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, - layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch, VecZnx, VecZnxBig}, + layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig}, }; use crate::{ - GLWEExternalProduct, GLWEExternalProductInplace, + ScratchTakeCore, layouts::{ - GGSWInfos, GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEInfos, LWEInfos, - prepared::{GGSWCiphertextPrepared, GGSWCiphertextPreparedToRef}, + GGSWInfos, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, + prepared::{GGSWPrepared, GGSWPreparedToRef}, }, }; -impl GLWECiphertext> { - #[allow(clippy::too_many_arguments)] - pub fn external_product_scratch_space( - module: &Module, - out_infos: &OUT, - in_infos: &IN, - apply_infos: &GGSW, - ) -> usize +impl GLWE> { + pub fn external_product_tmp_bytes(module: &M, res_infos: &R, a_infos: &A, b_infos: &B) -> usize where - OUT: GLWEInfos, - IN: GLWEInfos, - GGSW: GGSWInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes, + R: GLWEInfos, + A: GLWEInfos, + B: GGSWInfos, + M: GLWEExternalProduct, { - let in_size: usize = in_infos - .k() - .div_ceil(apply_infos.base2k()) - .div_ceil(apply_infos.dsize().into()) as usize; - let out_size: usize = out_infos.size(); - let ggsw_size: usize = apply_infos.size(); - let res_dft: usize = module.vec_znx_dft_alloc_bytes((apply_infos.rank() + 1).into(), ggsw_size); - let a_dft: usize = module.vec_znx_dft_alloc_bytes((apply_infos.rank() + 1).into(), in_size); - let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes( - out_size, - in_size, - in_size, // rows - (apply_infos.rank() + 1).into(), // cols in - (apply_infos.rank() + 1).into(), // cols out - ggsw_size, - ); - let normalize_big: usize = module.vec_znx_normalize_tmp_bytes(); - - if in_infos.base2k() == apply_infos.base2k() { - res_dft + a_dft + (vmp | normalize_big) - } else { - let normalize_conv: usize = VecZnx::alloc_bytes(module.n(), (apply_infos.rank() + 1).into(), in_size); - res_dft + ((a_dft + normalize_conv + (module.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big) - } - } - - pub fn external_product_inplace_scratch_space( - module: &Module, - out_infos: &OUT, - apply_infos: &GGSW, - ) -> usize - where - OUT: GLWEInfos, - GGSW: GGSWInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes, - { - Self::external_product_scratch_space(module, out_infos, out_infos, apply_infos) + module.glwe_external_product_tmp_bytes(res_infos, a_infos, b_infos) } } -impl GLWECiphertext { - pub fn external_product( - &mut self, - module: &Module, - lhs: &GLWECiphertext, - rhs: &GGSWCiphertextPrepared, - scratch: &mut Scratch, - ) where - Module: GLWEExternalProduct, +impl GLWE { + pub fn external_product(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch) + where + A: GLWEToRef, + B: GGSWPreparedToRef, + M: GLWEExternalProduct, + Scratch: ScratchTakeCore, { - module.external_product(self, lhs, rhs, scratch); + module.glwe_external_product(self, a, b, scratch); } - pub fn external_product_inplace( - &mut self, - module: &Module, - rhs: &GGSWCiphertextPrepared, - scratch: &mut Scratch, - ) where - Module: GLWEExternalProductInplace, + pub fn external_product_inplace(&mut self, module: &M, a: &A, scratch: &mut Scratch) + where + A: GGSWPreparedToRef, + M: GLWEExternalProduct, + Scratch: ScratchTakeCore, { - module.external_product_inplace(self, rhs, scratch); + module.glwe_external_product_inplace(self, a, scratch); } } -impl GLWEExternalProductInplace for Module +pub trait GLWEExternalProduct where - Module: VecZnxDftAllocBytes + Self: Sized + + ModuleN + + VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes + VecZnxDftApply @@ -105,15 +60,47 @@ where + VecZnxIdftApplyConsume + VecZnxBigNormalize + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { - fn external_product_inplace(&self, res: &mut R, ggsw: &D, scratch: &mut Scratch) + fn glwe_external_product_tmp_bytes(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize where - R: GLWECiphertextToMut, - D: GGSWCiphertextPreparedToRef, + R: GLWEInfos, + A: GLWEInfos, + B: GGSWInfos, { - let res: &mut GLWECiphertext<&mut [u8]> = &mut res.to_mut(); - let rhs: &GGSWCiphertextPrepared<&[u8], BE> = &ggsw.to_ref(); + let in_size: usize = a_infos + .k() + .div_ceil(b_infos.base2k()) + .div_ceil(b_infos.dsize().into()) as usize; + let out_size: usize = res_infos.size(); + let ggsw_size: usize = b_infos.size(); + let res_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), ggsw_size); + let a_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), in_size); + let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes( + out_size, + in_size, + in_size, // rows + (b_infos.rank() + 1).into(), // cols in + (b_infos.rank() + 1).into(), // cols out + ggsw_size, + ); + let normalize_big: usize = self.vec_znx_normalize_tmp_bytes(); + + if a_infos.base2k() == b_infos.base2k() { + res_dft + a_dft + (vmp | normalize_big) + } else { + let normalize_conv: usize = VecZnx::bytes_of(self.n(), (b_infos.rank() + 1).into(), in_size); + res_dft + ((a_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big) + } + } + + fn glwe_external_product_inplace(&self, res: &mut R, a: &D, scratch: &mut Scratch) + where + R: GLWEToMut, + D: GGSWPreparedToRef, + Scratch: ScratchTakeCore, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let rhs: &GGSWPrepared<&[u8], BE> = &a.to_ref(); let basek_in: usize = res.base2k().into(); let basek_ggsw: usize = rhs.base2k().into(); @@ -124,15 +111,15 @@ where assert_eq!(rhs.rank(), res.rank()); assert_eq!(rhs.n(), res.n()); - assert!(scratch.available() >= GLWECiphertext::external_product_inplace_scratch_space(self, res, rhs)); + assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, res, rhs)); } let cols: usize = (rhs.rank() + 1).into(); let dsize: usize = rhs.dsize().into(); let a_size: usize = (res.size() * basek_in).div_ceil(basek_ggsw); - let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(res.n().into(), cols, rhs.size()); // Todo optimise - let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(res.n().into(), cols, a_size.div_ceil(dsize)); + let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, rhs.size()); // Todo optimise + let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize)); a_dft.data_mut().fill(0); if basek_in == basek_ggsw { @@ -213,31 +200,18 @@ where ); } } -} -impl GLWEExternalProduct for Module -where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxNormalizeTmpBytes - + VecZnxDftApply - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxIdftApplyConsume - + VecZnxBigNormalize - + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, -{ - fn external_product(&self, res: &mut R, lhs: &A, rhs: &D, scratch: &mut Scratch) + fn glwe_external_product(&self, res: &mut R, lhs: &A, rhs: &D, scratch: &mut Scratch) where - R: GLWECiphertextToMut, - A: GLWECiphertextToRef, - D: GGSWCiphertextPreparedToRef, + R: GLWEToMut, + A: GLWEToRef, + D: GGSWPreparedToRef, + Scratch: ScratchTakeCore, { - let res: &mut GLWECiphertext<&mut [u8]> = &mut res.to_mut(); - let lhs: &GLWECiphertext<&[u8]> = &lhs.to_ref(); + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let lhs: &GLWE<&[u8]> = &lhs.to_ref(); - let rhs: &GGSWCiphertextPrepared<&[u8], BE> = &rhs.to_ref(); + let rhs: &GGSWPrepared<&[u8], BE> = &rhs.to_ref(); let basek_in: usize = lhs.base2k().into(); let basek_ggsw: usize = rhs.base2k().into(); @@ -251,7 +225,7 @@ where assert_eq!(rhs.rank(), res.rank()); assert_eq!(rhs.n(), res.n()); assert_eq!(lhs.n(), res.n()); - assert!(scratch.available() >= GLWECiphertext::external_product_scratch_space(self, res, lhs, rhs)); + assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, lhs, rhs)); } let cols: usize = (rhs.rank() + 1).into(); @@ -259,8 +233,8 @@ where let a_size: usize = (lhs.size() * basek_in).div_ceil(basek_ggsw); - let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), cols, rhs.size()); // Todo optimise - let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n(), cols, a_size.div_ceil(dsize)); + let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, rhs.size()); // Todo optimise + let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize)); a_dft.data_mut().fill(0); if basek_in == basek_ggsw { @@ -342,3 +316,20 @@ where }); } } + +impl GLWEExternalProduct for Module where + Self: ModuleN + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VecZnxNormalizeTmpBytes + + VecZnxDftApply + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxIdftApplyConsume + + VecZnxBigNormalize + + VecZnxNormalize + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VecZnxNormalizeTmpBytes +{ +} diff --git a/poulpy-core/src/external_product/mod.rs b/poulpy-core/src/external_product/mod.rs index f46b82c..cf5e6c0 100644 --- a/poulpy-core/src/external_product/mod.rs +++ b/poulpy-core/src/external_product/mod.rs @@ -1,23 +1,7 @@ -use poulpy_hal::layouts::{Backend, Scratch}; +mod gglwe; +mod ggsw; +mod glwe; -use crate::layouts::{GLWECiphertextToMut, GLWECiphertextToRef, prepared::GGSWCiphertextPreparedToRef}; - -mod gglwe_atk; -mod gglwe_ksk; -mod ggsw_ct; -mod glwe_ct; - -pub trait GLWEExternalProduct { - fn external_product(&self, res: &mut R, a: &A, ggsw: &D, scratch: &mut Scratch) - where - R: GLWECiphertextToMut, - A: GLWECiphertextToRef, - D: GGSWCiphertextPreparedToRef; -} - -pub trait GLWEExternalProductInplace { - fn external_product_inplace(&self, res: &mut R, ggsw: &D, scratch: &mut Scratch) - where - R: GLWECiphertextToMut, - D: GGSWCiphertextPreparedToRef; -} +pub use gglwe::*; +pub use ggsw::*; +pub use glwe::*; diff --git a/poulpy-core/src/glwe_packing.rs b/poulpy-core/src/glwe_packing.rs index 3da962a..09540b2 100644 --- a/poulpy-core/src/glwe_packing.rs +++ b/poulpy-core/src/glwe_packing.rs @@ -1,19 +1,14 @@ use std::collections::HashMap; use poulpy_hal::{ - api::{ - ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, - VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNegateInplace, - VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, - VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, - }, - layouts::{Backend, DataMut, DataRef, Module, Scratch}, + api::ModuleLogN, + layouts::{Backend, GaloisElement, Module, Scratch}, }; use crate::{ - GLWEOperations, TakeGLWECt, - layouts::{GGLWEInfos, GLWECiphertext, GLWEInfos, LWEInfos, prepared::GGLWEAutomorphismKeyPrepared}, + GLWEAdd, GLWEAutomorphism, GLWECopy, GLWENormalize, GLWERotate, GLWEShift, GLWESub, ScratchTakeCore, + glwe_trace::GLWETrace, + layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos}, }; /// [GLWEPacker] enables only the fly GLWE packing @@ -29,7 +24,7 @@ pub struct GLWEPacker { /// [Accumulator] stores intermediate packing result. /// There are Log(N) such accumulators in a [GLWEPacker]. struct Accumulator { - data: GLWECiphertext>, + data: GLWE>, value: bool, // Implicit flag for zero ciphertext control: bool, // Can be combined with incoming value } @@ -48,7 +43,7 @@ impl Accumulator { A: GLWEInfos, { Self { - data: GLWECiphertext::alloc(infos), + data: GLWE::alloc_from_infos(infos), value: false, control: false, } @@ -60,20 +55,19 @@ impl GLWEPacker { /// /// # Arguments /// - /// * `module`: static backend FFT tables. /// * `log_batch`: packs coefficients which are multiples of X^{N/2^log_batch}. /// i.e. with `log_batch=0` only the constant coefficient is packed /// and N GLWE ciphertext can be packed. With `log_batch=2` all coefficients /// which are multiples of X^{N/4} are packed. Meaning that N/4 ciphertexts /// can be packed. - pub fn new(infos: &A, log_batch: usize) -> Self + pub fn alloc(infos: &A, log_batch: usize) -> Self where A: GLWEInfos, { let mut accumulators: Vec = Vec::::new(); let log_n: usize = infos.n().log2(); (0..log_n - log_batch).for_each(|_| accumulators.push(Accumulator::alloc(infos))); - Self { + GLWEPacker { accumulators, log_batch, counter: 0, @@ -90,17 +84,23 @@ impl GLWEPacker { } /// Number of scratch space bytes required to call [Self::add]. - pub fn scratch_space(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize + pub fn tmp_bytes(module: &M, res_infos: &R, key_infos: &K) -> usize where - OUT: GLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + R: GLWEInfos, + K: GGLWEInfos, + M: GLWEPacking, { - pack_core_scratch_space(module, out_infos, key_infos) + GLWE::bytes_of_from_infos(res_infos) + + module + .glwe_rsh_tmp_byte() + .max(module.glwe_automorphism_tmp_bytes(res_infos, res_infos, key_infos)) } - pub fn galois_elements(module: &Module) -> Vec { - GLWECiphertext::trace_galois_elements(module) + pub fn galois_elements(module: &M) -> Vec + where + M: GLWETrace, + { + module.glwe_trace_galois_elements() } /// Adds a GLWE ciphertext to the [GLWEPacker]. @@ -111,38 +111,13 @@ impl GLWEPacker { /// of packed ciphertexts reaches N/2^log_batch is a result written. /// * `a`: ciphertext to pack. Can optionally give None to pack a 0 ciphertext. /// * `auto_keys`: a [HashMap] containing the [AutomorphismKeyExec]s. - /// * `scratch`: scratch space of size at least [Self::scratch_space]. - pub fn add( - &mut self, - module: &Module, - a: Option<&GLWECiphertext>, - auto_keys: &HashMap>, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxCopy - + VecZnxRotateInplace - + VecZnxSub - + VecZnxNegateInplace - + VecZnxRshInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxSubInplace - + VecZnxRotate - + VecZnxAutomorphismInplace - + VecZnxBigSubSmallNegateInplace - + VecZnxBigAutomorphismInplace - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + /// * `scratch`: scratch space of size at least [Self::tmp_bytes]. + pub fn add(&mut self, module: &M, a: Option<&A>, auto_keys: &HashMap, scratch: &mut Scratch) + where + A: GLWEToRef + GLWEInfos, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + M: GLWEPacking, + Scratch: ScratchTakeCore, { assert!( (self.counter as u32) < self.accumulators[0].data.n(), @@ -162,14 +137,15 @@ impl GLWEPacker { } /// Flush result to`res`. - pub fn flush(&mut self, module: &Module, res: &mut GLWECiphertext) + pub fn flush(&mut self, module: &M, res: &mut R) where - Module: VecZnxCopy, + R: GLWEToMut, + M: GLWEPacking, { assert!(self.counter as u32 == self.accumulators[0].data.n()); // Copy result GLWE into res GLWE - res.copy( - module, + module.glwe_copy( + res, &self.accumulators[module.log_n() - self.log_batch - 1].data, ); @@ -177,47 +153,96 @@ impl GLWEPacker { } } -fn pack_core_scratch_space(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize -where - OUT: GLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, +impl GLWEPacking for Module where + Self: GLWEAutomorphism + + GaloisElement + + ModuleLogN + + GLWERotate + + GLWESub + + GLWEShift + + GLWEAdd + + GLWENormalize + + GLWECopy { - combine_scratch_space(module, out_infos, key_infos) } -fn pack_core( - module: &Module, - a: Option<&GLWECiphertext>, +pub trait GLWEPacking +where + Self: GLWEAutomorphism + + GaloisElement + + ModuleLogN + + GLWERotate + + GLWESub + + GLWEShift + + GLWEAdd + + GLWENormalize + + GLWECopy, +{ + /// Packs [x_0: GLWE(m_0), x_1: GLWE(m_1), ..., x_i: GLWE(m_i)] + /// to [0: GLWE(m_0 * X^x_0 + m_1 * X^x_1 + ... + m_i * X^x_i)] + fn glwe_pack( + &self, + cts: &mut HashMap, + log_gap_out: usize, + keys: &HashMap, + scratch: &mut Scratch, + ) where + R: GLWEToMut + GLWEToRef + GLWEInfos, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + Scratch: ScratchTakeCore, + { + #[cfg(debug_assertions)] + { + assert!(*cts.keys().max().unwrap() < self.n()) + } + + let log_n: usize = self.log_n(); + + for i in 0..(log_n - log_gap_out) { + let t: usize = (1 << log_n).min(1 << (log_n - 1 - i)); + + let key: &K = if i == 0 { + keys.get(&-1).unwrap() + } else { + keys.get(&self.galois_element(1 << (i - 1))).unwrap() + }; + + for j in 0..t { + let mut a: Option<&mut R> = cts.remove(&j); + let mut b: Option<&mut R> = cts.remove(&(j + t)); + + pack_internal(self, &mut a, &mut b, i, key, scratch); + + if let Some(a) = a { + cts.insert(j, a); + } else if let Some(b) = b { + cts.insert(j, b); + } + } + } + } +} + +fn pack_core( + module: &M, + a: Option<&A>, accumulators: &mut [Accumulator], i: usize, - auto_keys: &HashMap>, - scratch: &mut Scratch, + auto_keys: &HashMap, + scratch: &mut Scratch, ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxCopy - + VecZnxRotateInplace - + VecZnxSub - + VecZnxNegateInplace - + VecZnxRshInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxSubInplace - + VecZnxRotate - + VecZnxAutomorphismInplace - + VecZnxBigSubSmallNegateInplace - + VecZnxBigAutomorphismInplace - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + A: GLWEToRef + GLWEInfos, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + M: ModuleLogN + + GLWEAutomorphism + + GaloisElement + + GLWERotate + + GLWESub + + GLWEShift + + GLWEAdd + + GLWENormalize + + GLWECopy, + Scratch: ScratchTakeCore, { let log_n: usize = module.log_n(); @@ -234,7 +259,7 @@ fn pack_core( // No previous value -> copies and sets flags accordingly if let Some(a_ref) = a { - acc_mut_ref.data.copy(module, a_ref); + module.glwe_copy(&mut acc_mut_ref.data, a_ref); acc_mut_ref.value = true } else { acc_mut_ref.value = false @@ -258,7 +283,7 @@ fn pack_core( } else { pack_core( module, - None::<&GLWECiphertext>>, + None::<&GLWE>>, acc_next, i + 1, auto_keys, @@ -268,53 +293,23 @@ fn pack_core( } } -fn combine_scratch_space(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize -where - OUT: GLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, -{ - GLWECiphertext::alloc_bytes(out_infos) - + (GLWECiphertext::rsh_scratch_space(module.n()) - | GLWECiphertext::automorphism_inplace_scratch_space(module, out_infos, key_infos)) -} - /// [combine] merges two ciphertexts together. -fn combine( - module: &Module, +fn combine( + module: &M, acc: &mut Accumulator, - b: Option<&GLWECiphertext>, + b: Option<&B>, i: usize, - auto_keys: &HashMap>, - scratch: &mut Scratch, + auto_keys: &HashMap, + scratch: &mut Scratch, ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxCopy - + VecZnxRotateInplace - + VecZnxSub - + VecZnxNegateInplace - + VecZnxRshInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxSubInplace - + VecZnxRotate - + VecZnxAutomorphismInplace - + VecZnxBigSubSmallNegateInplace - + VecZnxBigAutomorphismInplace - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeGLWECt, + B: GLWEToRef + GLWEInfos, + M: GLWEAutomorphism + GaloisElement + GLWERotate + GLWESub + GLWEShift + GLWEAdd + GLWENormalize, + B: GLWEToRef + GLWEInfos, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + Scratch: ScratchTakeCore, { let log_n: usize = acc.data.n().log2(); - let a: &mut GLWECiphertext> = &mut acc.data; + let a: &mut GLWE> = &mut acc.data; let gal_el: i64 = if i == 0 { -1 @@ -336,53 +331,53 @@ fn combine( // since 2*(I(X) * Q/2) = I(X) * Q = 0 mod Q. if acc.value { if let Some(b) = b { - let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(a); + let (mut tmp_b, scratch_1) = scratch.take_glwe(a); // a = a * X^-t - a.rotate_inplace(module, -t, scratch_1); + module.glwe_rotate_inplace(-t, a, scratch_1); // tmp_b = a * X^-t - b - tmp_b.sub(module, a, b); - tmp_b.rsh(module, 1, scratch_1); + module.glwe_sub(&mut tmp_b, a, b); + module.glwe_rsh(1, &mut tmp_b, scratch_1); // a = a * X^-t + b - a.add_inplace(module, b); - a.rsh(module, 1, scratch_1); + module.glwe_add_inplace(a, b); + module.glwe_rsh(1, a, scratch_1); - tmp_b.normalize_inplace(module, scratch_1); + module.glwe_normalize_inplace(&mut tmp_b, scratch_1); // tmp_b = phi(a * X^-t - b) - if let Some(key) = auto_keys.get(&gal_el) { - tmp_b.automorphism_inplace(module, key, scratch_1); + if let Some(auto_key) = auto_keys.get(&gal_el) { + module.glwe_automorphism_inplace(&mut tmp_b, auto_key, scratch_1); } else { panic!("auto_key[{gal_el}] not found"); } // a = a * X^-t + b - phi(a * X^-t - b) - a.sub_inplace_ab(module, &tmp_b); - a.normalize_inplace(module, scratch_1); + module.glwe_sub_inplace(a, &tmp_b); + module.glwe_normalize_inplace(a, scratch_1); // a = a + b * X^t - phi(a * X^-t - b) * X^t // = a + b * X^t - phi(a * X^-t - b) * - phi(X^t) // = a + b * X^t + phi(a - b * X^t) - a.rotate_inplace(module, t, scratch_1); + module.glwe_rotate_inplace(t, a, scratch_1); } else { - a.rsh(module, 1, scratch); + module.glwe_rsh(1, a, scratch); // a = a + phi(a) - if let Some(key) = auto_keys.get(&gal_el) { - a.automorphism_add_inplace(module, key, scratch); + if let Some(auto_key) = auto_keys.get(&gal_el) { + module.glwe_automorphism_add_inplace(a, auto_key, scratch); } else { panic!("auto_key[{gal_el}] not found"); } } } else if let Some(b) = b { - let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(a); - tmp_b.rotate(module, 1 << (log_n - i - 1), b); - tmp_b.rsh(module, 1, scratch_1); + let (mut tmp_b, scratch_1) = scratch.take_glwe(a); + module.glwe_rotate(t, &mut tmp_b, b); + module.glwe_rsh(1, &mut tmp_b, scratch_1); // a = (b* X^t - phi(b* X^t)) - if let Some(key) = auto_keys.get(&gal_el) { - a.automorphism_sub_negate(module, &tmp_b, key, scratch_1); + if let Some(auto_key) = auto_keys.get(&gal_el) { + module.glwe_automorphism_sub_negate(a, &tmp_b, auto_key, scratch_1); } else { panic!("auto_key[{gal_el}] not found"); } @@ -391,110 +386,20 @@ fn combine( } } -/// Packs [x_0: GLWE(m_0), x_1: GLWE(m_1), ..., x_i: GLWE(m_i)] -/// to [0: GLWE(m_0 * X^x_0 + m_1 * X^x_1 + ... + m_i * X^x_i)] -pub fn glwe_packing( - module: &Module, - cts: &mut HashMap>, - log_gap_out: usize, - auto_keys: &HashMap>, - scratch: &mut Scratch, -) where - ATK: DataRef, - Module: VecZnxRotateInplace - + VecZnxNormalizeInplace - + VecZnxNormalizeTmpBytes - + VecZnxSwitchRing - + VecZnxBigAutomorphismInplace - + VecZnxRshInplace - + VecZnxDftCopy - + VecZnxIdftApplyTmpA - + VecZnxSub - + VecZnxAddInplace - + VecZnxNegateInplace - + VecZnxCopy - + VecZnxSubInplace - + VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxAutomorphismInplace - + VecZnxBigSubSmallNegateInplace - + VecZnxRotate - + VecZnxNormalize, - Scratch: TakeVecZnx + TakeVecZnxDft + ScratchAvailable, -{ - #[cfg(debug_assertions)] - { - assert!(*cts.keys().max().unwrap() < module.n()) - } - - let log_n: usize = module.log_n(); - - (0..log_n - log_gap_out).for_each(|i| { - let t: usize = (1 << log_n).min(1 << (log_n - 1 - i)); - - let auto_key: &GGLWEAutomorphismKeyPrepared = if i == 0 { - auto_keys.get(&-1).unwrap() - } else { - auto_keys.get(&module.galois_element(1 << (i - 1))).unwrap() - }; - - (0..t).for_each(|j| { - let mut a: Option<&mut GLWECiphertext> = cts.remove(&j); - let mut b: Option<&mut GLWECiphertext> = cts.remove(&(j + t)); - - pack_internal(module, &mut a, &mut b, i, auto_key, scratch); - - if let Some(a) = a { - cts.insert(j, a); - } else if let Some(b) = b { - cts.insert(j, b); - } - }); - }); -} - #[allow(clippy::too_many_arguments)] -fn pack_internal( - module: &Module, - a: &mut Option<&mut GLWECiphertext>, - b: &mut Option<&mut GLWECiphertext>, +fn pack_internal( + module: &M, + a: &mut Option<&mut A>, + b: &mut Option<&mut B>, i: usize, - auto_key: &GGLWEAutomorphismKeyPrepared, - scratch: &mut Scratch, + auto_key: &K, + scratch: &mut Scratch, ) where - Module: VecZnxRotateInplace - + VecZnxNormalizeInplace - + VecZnxNormalizeTmpBytes - + VecZnxBigAutomorphismInplace - + VecZnxRshInplace - + VecZnxDftCopy - + VecZnxIdftApplyTmpA - + VecZnxSub - + VecZnxAddInplace - + VecZnxNegateInplace - + VecZnxCopy - + VecZnxSubInplace - + VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxAutomorphismInplace - + VecZnxBigSubSmallNegateInplace - + VecZnxRotate - + VecZnxNormalize, - Scratch: TakeVecZnx + TakeVecZnxDft + ScratchAvailable, + M: GLWEAutomorphism + GLWERotate + GLWESub + GLWEShift + GLWEAdd + GLWENormalize, + A: GLWEToMut + GLWEToRef + GLWEInfos, + B: GLWEToMut + GLWEToRef + GLWEInfos, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + Scratch: ScratchTakeCore, { // Goal is to evaluate: a = a + b*X^t + phi(a - b*X^t)) // We also use the identity: AUTO(a * X^t, g) = -X^t * AUTO(a, g) @@ -510,45 +415,45 @@ fn pack_internal( let t: i64 = 1 << (a.n().log2() - i - 1); if let Some(b) = b.as_deref_mut() { - let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(a); + let (mut tmp_b, scratch_1) = scratch.take_glwe(a); // a = a * X^-t - a.rotate_inplace(module, -t, scratch_1); + module.glwe_rotate_inplace(-t, a, scratch_1); // tmp_b = a * X^-t - b - tmp_b.sub(module, a, b); - tmp_b.rsh(module, 1, scratch_1); + module.glwe_sub(&mut tmp_b, a, b); + module.glwe_rsh(1, &mut tmp_b, scratch_1); // a = a * X^-t + b - a.add_inplace(module, b); - a.rsh(module, 1, scratch_1); + module.glwe_add_inplace(a, b); + module.glwe_rsh(1, a, scratch_1); - tmp_b.normalize_inplace(module, scratch_1); + module.glwe_normalize_inplace(&mut tmp_b, scratch_1); // tmp_b = phi(a * X^-t - b) - tmp_b.automorphism_inplace(module, auto_key, scratch_1); + module.glwe_automorphism_inplace(&mut tmp_b, auto_key, scratch_1); // a = a * X^-t + b - phi(a * X^-t - b) - a.sub_inplace_ab(module, &tmp_b); - a.normalize_inplace(module, scratch_1); + module.glwe_sub_inplace(a, &tmp_b); + module.glwe_normalize_inplace(a, scratch_1); // a = a + b * X^t - phi(a * X^-t - b) * X^t // = a + b * X^t - phi(a * X^-t - b) * - phi(X^t) // = a + b * X^t + phi(a - b * X^t) - a.rotate_inplace(module, t, scratch_1); + module.glwe_rotate_inplace(t, a, scratch_1); } else { - a.rsh(module, 1, scratch); + module.glwe_rsh(1, a, scratch); // a = a + phi(a) - a.automorphism_add_inplace(module, auto_key, scratch); + module.glwe_automorphism_add_inplace(a, auto_key, scratch); } } else if let Some(b) = b.as_deref_mut() { let t: i64 = 1 << (b.n().log2() - i - 1); - let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(b); - tmp_b.rotate(module, t, b); - tmp_b.rsh(module, 1, scratch_1); + let (mut tmp_b, scratch_1) = scratch.take_glwe(b); + module.glwe_rotate(t, &mut tmp_b, b); + module.glwe_rsh(1, &mut tmp_b, scratch_1); // a = (b* X^t - phi(b* X^t)) - b.automorphism_sub_negate(module, &tmp_b, auto_key, scratch_1); + module.glwe_automorphism_sub_negate(b, &tmp_b, auto_key, scratch_1); } } diff --git a/poulpy-core/src/glwe_trace.rs b/poulpy-core/src/glwe_trace.rs index 4e1769e..c2ba15c 100644 --- a/poulpy-core/src/glwe_trace.rs +++ b/poulpy-core/src/glwe_trace.rs @@ -1,181 +1,189 @@ use std::collections::HashMap; use poulpy_hal::{ - api::{ - ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize, - VecZnxNormalizeTmpBytes, VecZnxRshInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, - }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx}, + api::ModuleLogN, + layouts::{Backend, DataMut, GaloisElement, Module, Scratch, VecZnx, galois_element}, }; use crate::{ - TakeGLWECt, + GLWEAutomorphism, GLWECopy, GLWEShift, ScratchTakeCore, layouts::{ - Base2K, GGLWEInfos, GLWECiphertext, GLWECiphertextLayout, GLWEInfos, LWEInfos, prepared::GGLWEAutomorphismKeyPrepared, + Base2K, GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWELayout, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos, }, - operations::GLWEOperations, }; -impl GLWECiphertext> { - pub fn trace_galois_elements(module: &Module) -> Vec { - let mut gal_els: Vec = Vec::new(); - (0..module.log_n()).for_each(|i| { - if i == 0 { - gal_els.push(-1); - } else { - gal_els.push(module.galois_element(1 << (i - 1))); - } - }); - gal_els +impl GLWE> { + pub fn trace_galois_elements(module: &M) -> Vec + where + M: GLWETrace, + { + module.glwe_trace_galois_elements() } - pub fn trace_scratch_space( - module: &Module, - out_infos: &OUT, - in_infos: &IN, - key_infos: &KEY, - ) -> usize + pub fn trace_tmp_bytes(module: &M, res_infos: &R, a_infos: &A, key_infos: &K) -> usize where - OUT: GLWEInfos, - IN: GLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + R: GLWEInfos, + A: GLWEInfos, + K: GGLWEInfos, + M: GLWETrace, { - let trace: usize = Self::automorphism_inplace_scratch_space(module, out_infos, key_infos); - if in_infos.base2k() != key_infos.base2k() { - let glwe_conv: usize = VecZnx::alloc_bytes( - module.n(), + module.glwe_automorphism_tmp_bytes(res_infos, a_infos, key_infos) + } +} + +impl GLWE { + pub fn trace( + &mut self, + module: &M, + start: usize, + end: usize, + a: &A, + keys: &HashMap, + scratch: &mut Scratch, + ) where + A: GLWEToRef, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + Scratch: ScratchTakeCore, + M: GLWETrace, + { + module.glwe_trace(self, start, end, a, keys, scratch); + } + + pub fn trace_inplace( + &mut self, + module: &M, + start: usize, + end: usize, + keys: &HashMap, + scratch: &mut Scratch, + ) where + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + Scratch: ScratchTakeCore, + M: GLWETrace, + { + module.glwe_trace_inplace(self, start, end, keys, scratch); + } +} + +impl GLWETrace for Module where + Self: ModuleLogN + GaloisElement + GLWEAutomorphism + GLWEShift + GLWECopy +{ +} + +#[inline(always)] +pub fn trace_galois_elements(log_n: usize, cyclotomic_order: i64) -> Vec { + (0..log_n) + .map(|i| { + if i == 0 { + -1 + } else { + galois_element(1 << (i - 1), cyclotomic_order) + } + }) + .collect() +} + +pub trait GLWETrace +where + Self: ModuleLogN + GaloisElement + GLWEAutomorphism + GLWEShift + GLWECopy, +{ + fn glwe_trace_galois_elements(&self) -> Vec { + trace_galois_elements(self.log_n(), self.cyclotomic_order()) + } + + fn glwe_trace_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize + where + R: GLWEInfos, + A: GLWEInfos, + K: GGLWEInfos, + { + let trace: usize = self.glwe_automorphism_tmp_bytes(res_infos, a_infos, key_infos); + if a_infos.base2k() != key_infos.base2k() { + let glwe_conv: usize = VecZnx::bytes_of( + self.n(), (key_infos.rank_out() + 1).into(), - out_infos.k().min(in_infos.k()).div_ceil(key_infos.base2k()) as usize, - ) + module.vec_znx_normalize_tmp_bytes(); + res_infos.k().min(a_infos.k()).div_ceil(key_infos.base2k()) as usize, + ) + self.vec_znx_normalize_tmp_bytes(); return glwe_conv + trace; } trace } - pub fn trace_inplace_scratch_space(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize + fn glwe_trace(&self, res: &mut R, start: usize, end: usize, a: &A, keys: &HashMap, scratch: &mut Scratch) where - OUT: GLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + R: GLWEToMut, + A: GLWEToRef, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + Scratch: ScratchTakeCore, { - Self::trace_scratch_space(module, out_infos, out_infos, key_infos) - } -} - -impl GLWECiphertext { - pub fn trace( - &mut self, - module: &Module, - start: usize, - end: usize, - lhs: &GLWECiphertext, - auto_keys: &HashMap>, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxBigAutomorphismInplace - + VecZnxRshInplace - + VecZnxCopy - + VecZnxNormalizeTmpBytes - + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, - { - self.copy(module, lhs); - self.trace_inplace(module, start, end, auto_keys, scratch); + self.glwe_copy(res, a); + self.glwe_trace_inplace(res, start, end, keys, scratch); } - pub fn trace_inplace( - &mut self, - module: &Module, - start: usize, - end: usize, - auto_keys: &HashMap>, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxBigAutomorphismInplace - + VecZnxRshInplace - + VecZnxNormalizeTmpBytes - + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + fn glwe_trace_inplace(&self, res: &mut R, start: usize, end: usize, keys: &HashMap, scratch: &mut Scratch) + where + R: GLWEToMut, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + Scratch: ScratchTakeCore, { - let basek_ksk: Base2K = auto_keys - .get(auto_keys.keys().next().unwrap()) - .unwrap() - .base2k(); + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + + let basek_ksk: Base2K = keys.get(keys.keys().next().unwrap()).unwrap().base2k(); #[cfg(debug_assertions)] { - assert_eq!(self.n(), module.n() as u32); + assert_eq!(res.n(), self.n() as u32); assert!(start < end); - assert!(end <= module.log_n()); - for key in auto_keys.values() { - assert_eq!(key.n(), module.n() as u32); + assert!(end <= self.log_n()); + for key in keys.values() { + assert_eq!(key.n(), self.n() as u32); assert_eq!(key.base2k(), basek_ksk); - assert_eq!(key.rank_in(), self.rank()); - assert_eq!(key.rank_out(), self.rank()); + assert_eq!(key.rank_in(), res.rank()); + assert_eq!(key.rank_out(), res.rank()); } } - if self.base2k() != basek_ksk { - let (mut self_conv, scratch_1) = scratch.take_glwe_ct(&GLWECiphertextLayout { - n: module.n().into(), + if res.base2k() != basek_ksk { + let (mut self_conv, scratch_1) = scratch.take_glwe(&GLWELayout { + n: self.n().into(), base2k: basek_ksk, - k: self.k(), - rank: self.rank(), + k: res.k(), + rank: res.rank(), }); - for j in 0..(self.rank() + 1).into() { - module.vec_znx_normalize( + for j in 0..(res.rank() + 1).into() { + self.vec_znx_normalize( basek_ksk.into(), &mut self_conv.data, j, basek_ksk.into(), - &self.data, + res.data(), j, scratch_1, ); } for i in start..end { - self_conv.rsh(module, 1, scratch_1); + self.glwe_rsh(1, &mut self_conv, scratch_1); let p: i64 = if i == 0 { -1 } else { - module.galois_element(1 << (i - 1)) + self.galois_element(1 << (i - 1)) }; - if let Some(key) = auto_keys.get(&p) { - self_conv.automorphism_add_inplace(module, key, scratch_1); + if let Some(key) = keys.get(&p) { + self.glwe_automorphism_add_inplace(&mut self_conv, key, scratch_1); } else { - panic!("auto_keys[{p}] is empty") + panic!("keys[{p}] is empty") } } - for j in 0..(self.rank() + 1).into() { - module.vec_znx_normalize( - self.base2k().into(), - &mut self.data, + for j in 0..(res.rank() + 1).into() { + self.vec_znx_normalize( + res.base2k().into(), + res.data_mut(), j, basek_ksk.into(), &self_conv.data, @@ -184,19 +192,21 @@ impl GLWECiphertext { ); } } else { + // println!("res: {}", res); + for i in start..end { - self.rsh(module, 1, scratch); + self.glwe_rsh(1, res, scratch); let p: i64 = if i == 0 { -1 } else { - module.galois_element(1 << (i - 1)) + self.galois_element(1 << (i - 1)) }; - if let Some(key) = auto_keys.get(&p) { - self.automorphism_add_inplace(module, key, scratch); + if let Some(key) = keys.get(&p) { + self.glwe_automorphism_add_inplace(res, key, scratch); } else { - panic!("auto_keys[{p}] is empty") + panic!("keys[{p}] is empty") } } } diff --git a/poulpy-core/src/keyswitching/gglwe.rs b/poulpy-core/src/keyswitching/gglwe.rs new file mode 100644 index 0000000..d837002 --- /dev/null +++ b/poulpy-core/src/keyswitching/gglwe.rs @@ -0,0 +1,199 @@ +use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch}; + +use crate::{ + ScratchTakeCore, + keyswitching::GLWEKeyswitch, + layouts::{GGLWE, GGLWEInfos, GGLWEPreparedToRef, GGLWEToMut, GGLWEToRef, GLWEAutomorphismKey, GLWESwitchingKey}, +}; + +impl GLWEAutomorphismKey> { + pub fn keyswitch_tmp_bytes(module: &M, res_infos: &R, a_infos: &A, key_infos: &K) -> usize + where + R: GGLWEInfos, + A: GGLWEInfos, + K: GGLWEInfos, + M: GGLWEKeyswitch, + { + module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) + } +} + +impl GLWEAutomorphismKey { + pub fn keyswitch(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch) + where + A: GGLWEToRef + GGLWEToRef, + B: GGLWEPreparedToRef + GGLWEInfos, + Scratch: ScratchTakeCore, + M: GGLWEKeyswitch, + { + module.gglwe_keyswitch(self, a, b, scratch); + } + + pub fn keyswitch_inplace(&mut self, module: &M, a: &A, scratch: &mut Scratch) + where + A: GGLWEPreparedToRef + GGLWEInfos, + Scratch: ScratchTakeCore, + M: GGLWEKeyswitch, + { + module.gglwe_keyswitch_inplace(self, a, scratch); + } +} + +impl GLWESwitchingKey> { + pub fn keyswitch_tmp_bytes(module: &M, res_infos: &R, a_infos: &A, key_infos: &K) -> usize + where + R: GGLWEInfos, + A: GGLWEInfos, + K: GGLWEInfos, + M: GGLWEKeyswitch, + { + module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) + } +} + +impl GLWESwitchingKey { + pub fn keyswitch(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch) + where + A: GGLWEToRef, + B: GGLWEPreparedToRef + GGLWEInfos, + Scratch: ScratchTakeCore, + M: GGLWEKeyswitch, + { + module.gglwe_keyswitch(self, a, b, scratch); + } + + pub fn keyswitch_inplace(&mut self, module: &M, a: &A, scratch: &mut Scratch) + where + A: GGLWEPreparedToRef + GGLWEInfos, + Scratch: ScratchTakeCore, + M: GGLWEKeyswitch, + { + module.gglwe_keyswitch_inplace(self, a, scratch); + } +} + +impl GGLWE> { + pub fn keyswitch_tmp_bytes(module: &M, res_infos: &R, a_infos: &A, key_infos: &K) -> usize + where + R: GGLWEInfos, + A: GGLWEInfos, + K: GGLWEInfos, + M: GGLWEKeyswitch, + { + module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) + } +} + +impl GGLWE { + pub fn keyswitch(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch) + where + A: GGLWEToRef, + B: GGLWEPreparedToRef + GGLWEInfos, + Scratch: ScratchTakeCore, + M: GGLWEKeyswitch, + { + module.gglwe_keyswitch(self, a, b, scratch); + } + + pub fn keyswitch_inplace(&mut self, module: &M, a: &A, scratch: &mut Scratch) + where + A: GGLWEPreparedToRef + GGLWEInfos, + Scratch: ScratchTakeCore, + M: GGLWEKeyswitch, + { + module.gglwe_keyswitch_inplace(self, a, scratch); + } +} + +impl GGLWEKeyswitch for Module where Self: GLWEKeyswitch {} + +pub trait GGLWEKeyswitch +where + Self: GLWEKeyswitch, +{ + fn gglwe_keyswitch_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize + where + R: GGLWEInfos, + A: GGLWEInfos, + K: GGLWEInfos, + { + self.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) + } + + fn gglwe_keyswitch(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) + where + R: GGLWEToMut, + A: GGLWEToRef, + B: GGLWEPreparedToRef + GGLWEInfos, + Scratch: ScratchTakeCore, + { + let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GGLWE<&[u8]> = &a.to_ref(); + + assert_eq!( + res.rank_in(), + a.rank_in(), + "res input rank: {} != a input rank: {}", + res.rank_in(), + a.rank_in() + ); + assert_eq!( + a.rank_out(), + b.rank_in(), + "res output rank: {} != b input rank: {}", + a.rank_out(), + b.rank_in() + ); + assert_eq!( + res.rank_out(), + b.rank_out(), + "res output rank: {} != b output rank: {}", + res.rank_out(), + b.rank_out() + ); + assert!( + res.dnum() <= a.dnum(), + "res.dnum()={} > a.dnum()={}", + res.dnum(), + a.dnum() + ); + assert_eq!( + res.dsize(), + a.dsize(), + "res dsize: {} != a dsize: {}", + res.dsize(), + a.dsize() + ); + + for row in 0..res.dnum().into() { + for col in 0..res.rank_in().into() { + self.glwe_keyswitch(&mut res.at_mut(row, col), &a.at(row, col), b, scratch); + } + } + } + + fn gglwe_keyswitch_inplace(&self, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: GGLWEToMut, + A: GGLWEPreparedToRef + GGLWEInfos, + Scratch: ScratchTakeCore, + { + let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut(); + + assert_eq!( + res.rank_out(), + a.rank_out(), + "res output rank: {} != a output rank: {}", + res.rank_out(), + a.rank_out() + ); + + for row in 0..res.dnum().into() { + for col in 0..res.rank_in().into() { + self.glwe_keyswitch_inplace(&mut res.at_mut(row, col), a, scratch); + } + } + } +} + +impl GLWESwitchingKey {} diff --git a/poulpy-core/src/keyswitching/gglwe_ct.rs b/poulpy-core/src/keyswitching/gglwe_ct.rs deleted file mode 100644 index 9f6caa5..0000000 --- a/poulpy-core/src/keyswitching/gglwe_ct.rs +++ /dev/null @@ -1,224 +0,0 @@ -use poulpy_hal::{ - api::{ - ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, - VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, - }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero}, -}; - -use crate::layouts::{ - GGLWEAutomorphismKey, GGLWEInfos, GGLWESwitchingKey, GLWECiphertext, GLWEInfos, - prepared::{GGLWEAutomorphismKeyPrepared, GGLWESwitchingKeyPrepared}, -}; - -impl GGLWEAutomorphismKey> { - pub fn keyswitch_scratch_space( - module: &Module, - out_infos: &OUT, - in_infos: &IN, - key_infos: &KEY, - ) -> usize - where - OUT: GGLWEInfos, - IN: GGLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, - { - GGLWESwitchingKey::keyswitch_scratch_space(module, out_infos, in_infos, key_infos) - } - - pub fn keyswitch_inplace_scratch_space(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize - where - OUT: GGLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, - { - GGLWESwitchingKey::keyswitch_inplace_scratch_space(module, out_infos, key_infos) - } -} - -impl GGLWEAutomorphismKey { - pub fn keyswitch( - &mut self, - module: &Module, - lhs: &GGLWEAutomorphismKey, - rhs: &GGLWESwitchingKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, - { - self.key.keyswitch(module, &lhs.key, rhs, scratch); - } - - pub fn keyswitch_inplace( - &mut self, - module: &Module, - rhs: &GGLWEAutomorphismKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, - { - self.key.keyswitch_inplace(module, &rhs.key, scratch); - } -} - -impl GGLWESwitchingKey> { - pub fn keyswitch_scratch_space( - module: &Module, - out_infos: &OUT, - in_infos: &IN, - key_apply: &KEY, - ) -> usize - where - OUT: GGLWEInfos, - IN: GGLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, - { - GLWECiphertext::keyswitch_scratch_space(module, out_infos, in_infos, key_apply) - } - - pub fn keyswitch_inplace_scratch_space(module: &Module, out_infos: &OUT, key_apply: &KEY) -> usize - where - OUT: GGLWEInfos + GLWEInfos, - KEY: GGLWEInfos + GLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, - { - GLWECiphertext::keyswitch_inplace_scratch_space(module, out_infos, key_apply) - } -} - -impl GGLWESwitchingKey { - pub fn keyswitch( - &mut self, - module: &Module, - lhs: &GGLWESwitchingKey, - rhs: &GGLWESwitchingKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnx, - { - #[cfg(debug_assertions)] - { - assert_eq!( - self.rank_in(), - lhs.rank_in(), - "ksk_out input rank: {} != ksk_in input rank: {}", - self.rank_in(), - lhs.rank_in() - ); - assert_eq!( - lhs.rank_out(), - rhs.rank_in(), - "ksk_in output rank: {} != ksk_apply input rank: {}", - self.rank_out(), - rhs.rank_in() - ); - assert_eq!( - self.rank_out(), - rhs.rank_out(), - "ksk_out output rank: {} != ksk_apply output rank: {}", - self.rank_out(), - rhs.rank_out() - ); - assert!( - self.dnum() <= lhs.dnum(), - "self.dnum()={} > lhs.dnum()={}", - self.dnum(), - lhs.dnum() - ); - assert_eq!( - self.dsize(), - lhs.dsize(), - "ksk_out dsize: {} != ksk_in dsize: {}", - self.dsize(), - lhs.dsize() - ) - } - - (0..self.rank_in().into()).for_each(|col_i| { - (0..self.dnum().into()).for_each(|row_j| { - self.at_mut(row_j, col_i) - .keyswitch(module, &lhs.at(row_j, col_i), rhs, scratch); - }); - }); - - (self.dnum().min(lhs.dnum()).into()..self.dnum().into()).for_each(|row_i| { - (0..self.rank_in().into()).for_each(|col_j| { - self.at_mut(row_i, col_j).data.zero(); - }); - }); - } - - pub fn keyswitch_inplace( - &mut self, - module: &Module, - rhs: &GGLWESwitchingKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnx, - { - #[cfg(debug_assertions)] - { - assert_eq!( - self.rank_out(), - rhs.rank_out(), - "ksk_out output rank: {} != ksk_apply output rank: {}", - self.rank_out(), - rhs.rank_out() - ); - } - - (0..self.rank_in().into()).for_each(|col_i| { - (0..self.dnum().into()).for_each(|row_j| { - self.at_mut(row_j, col_i) - .keyswitch_inplace(module, rhs, scratch) - }); - }); - } -} diff --git a/poulpy-core/src/keyswitching/ggsw.rs b/poulpy-core/src/keyswitching/ggsw.rs new file mode 100644 index 0000000..231b071 --- /dev/null +++ b/poulpy-core/src/keyswitching/ggsw.rs @@ -0,0 +1,129 @@ +use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch, VecZnx}; + +use crate::{ + GGSWExpandRows, ScratchTakeCore, + keyswitching::GLWEKeyswitch, + layouts::{GGLWEInfos, GGLWEPreparedToRef, GGSW, GGSWInfos, GGSWToMut, GGSWToRef, prepared::GLWETensorKeyPreparedToRef}, +}; + +impl GGSW> { + pub fn keyswitch_tmp_bytes( + module: &M, + res_infos: &R, + a_infos: &A, + key_infos: &K, + tsk_infos: &T, + ) -> usize + where + R: GGSWInfos, + A: GGSWInfos, + K: GGLWEInfos, + T: GGLWEInfos, + M: GGSWKeyswitch, + { + module.ggsw_keyswitch_tmp_bytes(res_infos, a_infos, key_infos, tsk_infos) + } +} + +impl GGSW { + pub fn keyswitch(&mut self, module: &M, a: &A, key: &K, tsk: &T, scratch: &mut Scratch) + where + A: GGSWToRef, + K: GGLWEPreparedToRef, + T: GLWETensorKeyPreparedToRef, + Scratch: ScratchTakeCore, + M: GGSWKeyswitch, + { + module.ggsw_keyswitch(self, a, key, tsk, scratch); + } + + pub fn keyswitch_inplace(&mut self, module: &M, key: &K, tsk: &T, scratch: &mut Scratch) + where + K: GGLWEPreparedToRef, + T: GLWETensorKeyPreparedToRef, + Scratch: ScratchTakeCore, + M: GGSWKeyswitch, + { + module.ggsw_keyswitch_inplace(self, key, tsk, scratch); + } +} + +impl GGSWKeyswitch for Module where Self: GLWEKeyswitch + GGSWExpandRows {} + +pub trait GGSWKeyswitch +where + Self: GLWEKeyswitch + GGSWExpandRows, +{ + fn ggsw_keyswitch_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K, tsk_infos: &T) -> usize + where + R: GGSWInfos, + A: GGSWInfos, + K: GGLWEInfos, + T: GGLWEInfos, + { + assert_eq!(key_infos.rank_in(), key_infos.rank_out()); + assert_eq!(tsk_infos.rank_in(), tsk_infos.rank_out()); + assert_eq!(key_infos.rank_in(), tsk_infos.rank_in()); + + let rank: usize = key_infos.rank_out().into(); + + let size_out: usize = res_infos.k().div_ceil(res_infos.base2k()) as usize; + let res_znx: usize = VecZnx::bytes_of(self.n(), rank + 1, size_out); + let ci_dft: usize = self.bytes_of_vec_znx_dft(rank + 1, size_out); + let ks: usize = self.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos); + let expand_rows: usize = self.ggsw_expand_rows_tmp_bytes(res_infos, tsk_infos); + let res_dft: usize = self.bytes_of_vec_znx_dft(rank + 1, size_out); + + if a_infos.base2k() == tsk_infos.base2k() { + res_znx + ci_dft + (ks | expand_rows | res_dft) + } else { + let a_conv: usize = VecZnx::bytes_of( + self.n(), + 1, + res_infos.k().div_ceil(tsk_infos.base2k()) as usize, + ) + self.vec_znx_normalize_tmp_bytes(); + res_znx + ci_dft + (a_conv | ks | expand_rows | res_dft) + } + } + + fn ggsw_keyswitch(&self, res: &mut R, a: &A, key: &K, tsk: &T, scratch: &mut Scratch) + where + R: GGSWToMut, + A: GGSWToRef, + K: GGLWEPreparedToRef, + T: GLWETensorKeyPreparedToRef, + Scratch: ScratchTakeCore, + { + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + let a: &GGSW<&[u8]> = &a.to_ref(); + + assert!(res.dnum() <= a.dnum()); + assert_eq!(res.dsize(), a.dsize()); + + for row in 0..a.dnum().into() { + // Key-switch column 0, i.e. + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) + self.glwe_keyswitch(&mut res.at_mut(row, 0), &a.at(row, 0), key, scratch); + } + + self.ggsw_expand_row(res, tsk, scratch); + } + + fn ggsw_keyswitch_inplace(&self, res: &mut R, key: &K, tsk: &T, scratch: &mut Scratch) + where + R: GGSWToMut, + K: GGLWEPreparedToRef, + T: GLWETensorKeyPreparedToRef, + Scratch: ScratchTakeCore, + { + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + + for row in 0..res.dnum().into() { + // Key-switch column 0, i.e. + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) + self.glwe_keyswitch_inplace(&mut res.at_mut(row, 0), key, scratch); + } + + self.ggsw_expand_row(res, tsk, scratch); + } +} diff --git a/poulpy-core/src/keyswitching/ggsw_ct.rs b/poulpy-core/src/keyswitching/ggsw_ct.rs deleted file mode 100644 index d261f03..0000000 --- a/poulpy-core/src/keyswitching/ggsw_ct.rs +++ /dev/null @@ -1,366 +0,0 @@ -use poulpy_hal::{ - api::{ - ScratchAvailable, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, - VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftApply, - VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, - VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, - }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx, VmpPMat, ZnxInfos}, -}; - -use crate::{ - layouts::{ - GGLWECiphertext, GGLWEInfos, GGSWCiphertext, GGSWInfos, GLWECiphertext, GLWEInfos, LWEInfos, - prepared::{GGLWESwitchingKeyPrepared, GGLWETensorKeyPrepared}, - }, - operations::GLWEOperations, -}; - -impl GGSWCiphertext> { - pub(crate) fn expand_row_scratch_space(module: &Module, out_infos: &OUT, tsk_infos: &TSK) -> usize - where - OUT: GGSWInfos, - TSK: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes, - { - let tsk_size: usize = tsk_infos.k().div_ceil(tsk_infos.base2k()) as usize; - let size_in: usize = out_infos - .k() - .div_ceil(tsk_infos.base2k()) - .div_ceil(tsk_infos.dsize().into()) as usize; - - let tmp_dft_i: usize = module.vec_znx_dft_alloc_bytes((tsk_infos.rank_out() + 1).into(), tsk_size); - let tmp_a: usize = module.vec_znx_dft_alloc_bytes(1, size_in); - let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes( - tsk_size, - size_in, - size_in, - (tsk_infos.rank_in()).into(), // Verify if rank+1 - (tsk_infos.rank_out()).into(), // Verify if rank+1 - tsk_size, - ); - let tmp_idft: usize = module.vec_znx_big_alloc_bytes(1, tsk_size); - let norm: usize = module.vec_znx_normalize_tmp_bytes(); - - tmp_dft_i + ((tmp_a + vmp) | (tmp_idft + norm)) - } - - #[allow(clippy::too_many_arguments)] - pub fn keyswitch_scratch_space( - module: &Module, - out_infos: &OUT, - in_infos: &IN, - apply_infos: &KEY, - tsk_infos: &TSK, - ) -> usize - where - OUT: GGSWInfos, - IN: GGSWInfos, - KEY: GGLWEInfos, - TSK: GGLWEInfos, - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigAllocBytes - + VecZnxNormalizeTmpBytes - + VecZnxBigNormalizeTmpBytes, - { - #[cfg(debug_assertions)] - { - assert_eq!(apply_infos.rank_in(), apply_infos.rank_out()); - assert_eq!(tsk_infos.rank_in(), tsk_infos.rank_out()); - assert_eq!(apply_infos.rank_in(), tsk_infos.rank_in()); - } - - let rank: usize = apply_infos.rank_out().into(); - - let size_out: usize = out_infos.k().div_ceil(out_infos.base2k()) as usize; - let res_znx: usize = VecZnx::alloc_bytes(module.n(), rank + 1, size_out); - let ci_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, size_out); - let ks: usize = GLWECiphertext::keyswitch_scratch_space(module, out_infos, in_infos, apply_infos); - let expand_rows: usize = GGSWCiphertext::expand_row_scratch_space(module, out_infos, tsk_infos); - let res_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, size_out); - - if in_infos.base2k() == tsk_infos.base2k() { - res_znx + ci_dft + (ks | expand_rows | res_dft) - } else { - let a_conv: usize = VecZnx::alloc_bytes( - module.n(), - 1, - out_infos.k().div_ceil(tsk_infos.base2k()) as usize, - ) + module.vec_znx_normalize_tmp_bytes(); - res_znx + ci_dft + (a_conv | ks | expand_rows | res_dft) - } - } - - #[allow(clippy::too_many_arguments)] - pub fn keyswitch_inplace_scratch_space( - module: &Module, - out_infos: &OUT, - apply_infos: &KEY, - tsk_infos: &TSK, - ) -> usize - where - OUT: GGSWInfos, - KEY: GGLWEInfos, - TSK: GGLWEInfos, - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigAllocBytes - + VecZnxNormalizeTmpBytes - + VecZnxBigNormalizeTmpBytes, - { - GGSWCiphertext::keyswitch_scratch_space(module, out_infos, out_infos, apply_infos, tsk_infos) - } -} - -impl GGSWCiphertext { - pub fn from_gglwe( - &mut self, - module: &Module, - a: &GGLWECiphertext, - tsk: &GGLWETensorKeyPrepared, - scratch: &mut Scratch, - ) where - DataA: DataRef, - DataTsk: DataRef, - Module: VecZnxCopy - + VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigAllocBytes - + VecZnxNormalizeTmpBytes - + VecZnxDftApply - + VecZnxDftCopy - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftAddInplace - + VecZnxBigNormalize - + VecZnxIdftApplyTmpA - + VecZnxNormalize, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnxBig + TakeVecZnx, - { - #[cfg(debug_assertions)] - { - use crate::layouts::{GLWEInfos, LWEInfos}; - - assert_eq!(self.rank(), a.rank_out()); - assert_eq!(self.dnum(), a.dnum()); - assert_eq!(self.n(), module.n() as u32); - assert_eq!(a.n(), module.n() as u32); - assert_eq!(tsk.n(), module.n() as u32); - } - (0..self.dnum().into()).for_each(|row_i| { - self.at_mut(row_i, 0).copy(module, &a.at(row_i, 0)); - }); - self.expand_row(module, tsk, scratch); - } - - pub fn keyswitch( - &mut self, - module: &Module, - lhs: &GGSWCiphertext, - ksk: &GGLWESwitchingKeyPrepared, - tsk: &GGLWETensorKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxDftAllocBytes - + VecZnxBigAllocBytes - + VecZnxNormalizeTmpBytes - + VecZnxDftCopy - + VecZnxDftAddInplace - + VecZnxIdftApplyTmpA - + VecZnxNormalize, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnxBig + TakeVecZnx, - { - (0..lhs.dnum().into()).for_each(|row_i| { - // Key-switch column 0, i.e. - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) - self.at_mut(row_i, 0) - .keyswitch(module, &lhs.at(row_i, 0), ksk, scratch); - }); - self.expand_row(module, tsk, scratch); - } - - pub fn keyswitch_inplace( - &mut self, - module: &Module, - ksk: &GGLWESwitchingKeyPrepared, - tsk: &GGLWETensorKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxDftAllocBytes - + VecZnxBigAllocBytes - + VecZnxNormalizeTmpBytes - + VecZnxDftCopy - + VecZnxDftAddInplace - + VecZnxIdftApplyTmpA - + VecZnxNormalize, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnxBig + TakeVecZnx, - { - (0..self.dnum().into()).for_each(|row_i| { - // Key-switch column 0, i.e. - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) - self.at_mut(row_i, 0) - .keyswitch_inplace(module, ksk, scratch); - }); - self.expand_row(module, tsk, scratch); - } - - pub fn expand_row( - &mut self, - module: &Module, - tsk: &GGLWETensorKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigAllocBytes - + VecZnxNormalizeTmpBytes - + VecZnxDftApply - + VecZnxDftCopy - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftAddInplace - + VecZnxBigNormalize - + VecZnxIdftApplyTmpA - + VecZnxNormalize, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnxBig + TakeVecZnx, - { - let basek_in: usize = self.base2k().into(); - let basek_tsk: usize = tsk.base2k().into(); - - assert!(scratch.available() >= GGSWCiphertext::expand_row_scratch_space(module, self, tsk)); - - let n: usize = self.n().into(); - let rank: usize = self.rank().into(); - let cols: usize = rank + 1; - - let a_size: usize = (self.size() * basek_in).div_ceil(basek_tsk); - - // Keyswitch the j-th row of the col 0 - for row_i in 0..self.dnum().into() { - let a = &self.at(row_i, 0).data; - - // Pre-compute DFT of (a0, a1, a2) - let (mut ci_dft, scratch_1) = scratch.take_vec_znx_dft(n, cols, a_size); - - if basek_in == basek_tsk { - for i in 0..cols { - module.vec_znx_dft_apply(1, 0, &mut ci_dft, i, a, i); - } - } else { - let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(n, 1, a_size); - for i in 0..cols { - module.vec_znx_normalize(basek_tsk, &mut a_conv, 0, basek_in, a, i, scratch_2); - module.vec_znx_dft_apply(1, 0, &mut ci_dft, i, &a_conv, 0); - } - } - - for col_j in 1..cols { - // Example for rank 3: - // - // Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is - // actually composed of that many dnum and we focus on a specific row here - // implicitely given ci_dft. - // - // # Input - // - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) - // col 1: (0, 0, 0, 0) - // col 2: (0, 0, 0, 0) - // col 3: (0, 0, 0, 0) - // - // # Output - // - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) - // col 1: (-(b0s0 + b1s1 + b2s2) , b0 + M[i], b1 , b2 ) - // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M[i], c2 ) - // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M[i]) - - let dsize: usize = tsk.dsize().into(); - - let (mut tmp_dft_i, scratch_2) = scratch_1.take_vec_znx_dft(n, cols, tsk.size()); - let (mut tmp_a, scratch_3) = scratch_2.take_vec_znx_dft(n, 1, ci_dft.size().div_ceil(dsize)); - - { - // Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2 - // - // # Example for col=1 - // - // a0 * (-(f0s0 + f1s1 + f1s2) + s0^2, f0, f1, f2) = (-(a0f0s0 + a0f1s1 + a0f1s2) + a0s0^2, a0f0, a0f1, a0f2) - // + - // a1 * (-(g0s0 + g1s1 + g1s2) + s0s1, g0, g1, g2) = (-(a1g0s0 + a1g1s1 + a1g1s2) + a1s0s1, a1g0, a1g1, a1g2) - // + - // a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2) - // = - // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2) - for col_i in 1..cols { - let pmat: &VmpPMat = &tsk.at(col_i - 1, col_j - 1).key.data; // Selects Enc(s[i]s[j]) - - // Extracts a[i] and multipies with Enc(s[i]s[j]) - for di in 0..dsize { - tmp_a.set_size((ci_dft.size() + di) / dsize); - - // Small optimization for dsize > 2 - // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then - // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}. - // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last dsize-1 limbs, but this introduce - // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same - // noise is kept with respect to the ideal functionality. - tmp_dft_i.set_size(tsk.size() - ((dsize - di) as isize - 2).max(0) as usize); - - module.vec_znx_dft_copy(dsize, dsize - 1 - di, &mut tmp_a, 0, &ci_dft, col_i); - if di == 0 && col_i == 1 { - module.vmp_apply_dft_to_dft(&mut tmp_dft_i, &tmp_a, pmat, scratch_3); - } else { - module.vmp_apply_dft_to_dft_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch_3); - } - } - } - } - - // Adds -(sum a[i] * s[i]) + m) on the i-th column of tmp_idft_i - // - // (-(x0s0 + x1s1 + x2s2) + a0s0s0 + a1s0s1 + a2s0s2, x0, x1, x2) - // + - // (0, -(a0s0 + a1s1 + a2s2) + M[i], 0, 0) - // = - // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0 -(a0s0 + a1s1 + a2s2) + M[i], x1, x2) - // = - // (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2) - module.vec_znx_dft_add_inplace(&mut tmp_dft_i, col_j, &ci_dft, 0); - let (mut tmp_idft, scratch_3) = scratch_2.take_vec_znx_big(n, 1, tsk.size()); - for i in 0..cols { - module.vec_znx_idft_apply_tmpa(&mut tmp_idft, 0, &mut tmp_dft_i, i); - module.vec_znx_big_normalize( - basek_in, - &mut self.at_mut(row_i, col_j).data, - i, - basek_tsk, - &tmp_idft, - 0, - scratch_3, - ); - } - } - } - } -} diff --git a/poulpy-core/src/keyswitching/glwe.rs b/poulpy-core/src/keyswitching/glwe.rs new file mode 100644 index 0000000..a021777 --- /dev/null +++ b/poulpy-core/src/keyswitching/glwe.rs @@ -0,0 +1,370 @@ +use poulpy_hal::{ + api::{ + ModuleN, ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, + VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + }, + layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, ZnxInfos}, +}; + +use crate::{ + ScratchTakeCore, + layouts::{GGLWEInfos, GGLWEPrepared, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos}, +}; + +impl GLWE> { + pub fn keyswitch_tmp_bytes(module: &M, res_infos: &R, a_infos: &A, key_infos: &B) -> usize + where + R: GLWEInfos, + A: GLWEInfos, + B: GGLWEInfos, + M: GLWEKeyswitch, + { + module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) + } +} + +impl GLWE { + pub fn keyswitch(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch) + where + A: GLWEToRef, + B: GGLWEPreparedToRef, + M: GLWEKeyswitch, + Scratch: ScratchTakeCore, + { + module.glwe_keyswitch(self, a, b, scratch); + } + + pub fn keyswitch_inplace(&mut self, module: &M, a: &A, scratch: &mut Scratch) + where + A: GGLWEPreparedToRef, + M: GLWEKeyswitch, + Scratch: ScratchTakeCore, + { + module.glwe_keyswitch_inplace(self, a, scratch); + } +} + +impl GLWEKeyswitch for Module where + Self: Sized + + ModuleN + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VecZnxBigNormalizeTmpBytes + + VecZnxNormalizeTmpBytes + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VecZnxBigNormalizeTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxDftApply + + VecZnxIdftApplyConsume + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize + + VecZnxNormalize + + VecZnxNormalizeTmpBytes +{ +} + +pub trait GLWEKeyswitch +where + Self: Sized + + ModuleN + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VecZnxBigNormalizeTmpBytes + + VecZnxNormalizeTmpBytes + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VecZnxBigNormalizeTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxDftApply + + VecZnxIdftApplyConsume + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize + + VecZnxNormalize + + VecZnxNormalizeTmpBytes, +{ + fn glwe_keyswitch_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &B) -> usize + where + R: GLWEInfos, + A: GLWEInfos, + B: GGLWEInfos, + { + let in_size: usize = a_infos + .k() + .div_ceil(key_infos.base2k()) + .div_ceil(key_infos.dsize().into()) as usize; + let out_size: usize = res_infos.size(); + let ksk_size: usize = key_infos.size(); + let res_dft: usize = self.bytes_of_vec_znx_dft((key_infos.rank_out() + 1).into(), ksk_size); // TODO OPTIMIZE + let ai_dft: usize = self.bytes_of_vec_znx_dft((key_infos.rank_in()).into(), in_size); + let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes( + out_size, + in_size, + in_size, + (key_infos.rank_in()).into(), + (key_infos.rank_out() + 1).into(), + ksk_size, + ) + self.bytes_of_vec_znx_dft((key_infos.rank_in()).into(), in_size); + let normalize_big: usize = self.vec_znx_big_normalize_tmp_bytes(); + if a_infos.base2k() == key_infos.base2k() { + res_dft + ((ai_dft + vmp) | normalize_big) + } else if key_infos.dsize() == 1 { + // In this case, we only need one column, temporary, that we can drop once a_dft is computed. + let normalize_conv: usize = VecZnx::bytes_of(self.n(), 1, in_size) + self.vec_znx_normalize_tmp_bytes(); + res_dft + (((ai_dft + normalize_conv) | vmp) | normalize_big) + } else { + // Since we stride over a to get a_dft when dsize > 1, we need to store the full columns of a with in the base conversion. + let normalize_conv: usize = VecZnx::bytes_of(self.n(), (key_infos.rank_in()).into(), in_size); + res_dft + ((ai_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big) + } + } + + fn glwe_keyswitch(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + A: GLWEToRef, + K: GGLWEPreparedToRef, + Scratch: ScratchTakeCore, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); + let b: &GGLWEPrepared<&[u8], BE> = &key.to_ref(); + + assert_eq!( + a.rank(), + b.rank_in(), + "a.rank(): {} != b.rank_in(): {}", + a.rank(), + b.rank_in() + ); + assert_eq!( + res.rank(), + b.rank_out(), + "res.rank(): {} != b.rank_out(): {}", + res.rank(), + b.rank_out() + ); + + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.n(), self.n() as u32); + assert_eq!(b.n(), self.n() as u32); + + let scrach_needed: usize = self.glwe_keyswitch_tmp_bytes(res, a, b); + + assert!( + scratch.available() >= scrach_needed, + "scratch.available()={} < glwe_keyswitch_tmp_bytes={scrach_needed}", + scratch.available(), + ); + + let basek_out: usize = res.base2k().into(); + let base2k_out: usize = b.base2k().into(); + + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), b.size()); // Todo optimise + let res_big: VecZnxBig<&mut [u8], BE> = keyswitch_internal(self, res_dft, a, b, scratch_1); + (0..(res.rank() + 1).into()).for_each(|i| { + self.vec_znx_big_normalize( + basek_out, + &mut res.data, + i, + base2k_out, + &res_big, + i, + scratch_1, + ); + }) + } + + fn glwe_keyswitch_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + K: GGLWEPreparedToRef, + Scratch: ScratchTakeCore, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GGLWEPrepared<&[u8], BE> = &key.to_ref(); + + assert_eq!( + res.rank(), + a.rank_in(), + "res.rank(): {} != a.rank_in(): {}", + res.rank(), + a.rank_in() + ); + assert_eq!( + res.rank(), + a.rank_out(), + "res.rank(): {} != b.rank_out(): {}", + res.rank(), + a.rank_out() + ); + + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.n(), self.n() as u32); + + let scrach_needed: usize = self.glwe_keyswitch_tmp_bytes(res, res, a); + + assert!( + scratch.available() >= scrach_needed, + "scratch.available()={} < glwe_keyswitch_tmp_bytes={scrach_needed}", + scratch.available(), + ); + + let base2k_in: usize = res.base2k().into(); + let base2k_out: usize = a.base2k().into(); + + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), a.size()); // Todo optimise + let res_big: VecZnxBig<&mut [u8], BE> = keyswitch_internal(self, res_dft, res, a, scratch_1); + (0..(res.rank() + 1).into()).for_each(|i| { + self.vec_znx_big_normalize( + base2k_in, + &mut res.data, + i, + base2k_out, + &res_big, + i, + scratch_1, + ); + }) + } +} + +impl GLWE> {} + +impl GLWE {} + +pub(crate) fn keyswitch_internal( + module: &M, + mut res: VecZnxDft, + a: &A, + key: &K, + scratch: &mut Scratch, +) -> VecZnxBig +where + DR: DataMut, + A: GLWEToRef, + K: GGLWEPreparedToRef, + M: ModuleN + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VecZnxBigNormalizeTmpBytes + + VmpApplyDftToDftTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxDftApply + + VecZnxIdftApplyConsume + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize + + VecZnxNormalize, + Scratch: ScratchTakeCore, +{ + let a: &GLWE<&[u8]> = &a.to_ref(); + let key: &GGLWEPrepared<&[u8], BE> = &key.to_ref(); + + let base2k_in: usize = a.base2k().into(); + let base2k_out: usize = key.base2k().into(); + let cols: usize = (a.rank() + 1).into(); + let a_size: usize = (a.size() * base2k_in).div_ceil(base2k_out); + let pmat: &VmpPMat<&[u8], BE> = &key.data; + + if key.dsize() == 1 { + let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols - 1, a.size()); + + if base2k_in == base2k_out { + (0..cols - 1).for_each(|col_i| { + module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, a.data(), col_i + 1); + }); + } else { + let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(module.n(), 1, a_size); + (0..cols - 1).for_each(|col_i| { + module.vec_znx_normalize( + base2k_out, + &mut a_conv, + 0, + base2k_in, + a.data(), + col_i + 1, + scratch_2, + ); + module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, &a_conv, 0); + }); + } + + module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_1); + } else { + let dsize: usize = key.dsize().into(); + + let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols - 1, a_size.div_ceil(dsize)); + ai_dft.data_mut().fill(0); + + if base2k_in == base2k_out { + for di in 0..dsize { + ai_dft.set_size((a_size + di) / dsize); + + // Small optimization for dsize > 2 + // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then + // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}. + // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. + // It is possible to further ignore the last dsize-1 limbs, but this introduce + // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same + // noise is kept with respect to the ideal functionality. + res.set_size(pmat.size() - ((dsize - di) as isize - 2).max(0) as usize); + + for j in 0..cols - 1 { + module.vec_znx_dft_apply(dsize, dsize - di - 1, &mut ai_dft, j, a.data(), j + 1); + } + + if di == 0 { + module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_1); + } else { + module.vmp_apply_dft_to_dft_add(&mut res, &ai_dft, pmat, di, scratch_1); + } + } + } else { + let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(module.n(), cols - 1, a_size); + for j in 0..cols - 1 { + module.vec_znx_normalize( + base2k_out, + &mut a_conv, + j, + base2k_in, + a.data(), + j + 1, + scratch_2, + ); + } + + for di in 0..dsize { + ai_dft.set_size((a_size + di) / dsize); + + // Small optimization for dsize > 2 + // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then + // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}. + // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. + // It is possible to further ignore the last dsize-1 limbs, but this introduce + // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same + // noise is kept with respect to the ideal functionality. + res.set_size(pmat.size() - ((dsize - di) as isize - 2).max(0) as usize); + + for j in 0..cols - 1 { + module.vec_znx_dft_apply(dsize, dsize - di - 1, &mut ai_dft, j, &a_conv, j); + } + + if di == 0 { + module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_2); + } else { + module.vmp_apply_dft_to_dft_add(&mut res, &ai_dft, pmat, di, scratch_2); + } + } + } + + res.set_size(res.max_size()); + } + + let mut res_big: VecZnxBig = module.vec_znx_idft_apply_consume(res); + module.vec_znx_big_add_small_inplace(&mut res_big, 0, a.data(), 0); + res_big +} diff --git a/poulpy-core/src/keyswitching/glwe_ct.rs b/poulpy-core/src/keyswitching/glwe_ct.rs deleted file mode 100644 index 07d95e9..0000000 --- a/poulpy-core/src/keyswitching/glwe_ct.rs +++ /dev/null @@ -1,414 +0,0 @@ -use poulpy_hal::{ - api::{ - ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, - VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, - }, - layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, ZnxInfos}, -}; - -use crate::layouts::{GGLWEInfos, GLWECiphertext, GLWEInfos, LWEInfos, prepared::GGLWESwitchingKeyPrepared}; - -impl GLWECiphertext> { - pub fn keyswitch_scratch_space( - module: &Module, - out_infos: &OUT, - in_infos: &IN, - key_apply: &KEY, - ) -> usize - where - OUT: GLWEInfos, - IN: GLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, - { - let in_size: usize = in_infos - .k() - .div_ceil(key_apply.base2k()) - .div_ceil(key_apply.dsize().into()) as usize; - let out_size: usize = out_infos.size(); - let ksk_size: usize = key_apply.size(); - let res_dft: usize = module.vec_znx_dft_alloc_bytes((key_apply.rank_out() + 1).into(), ksk_size); // TODO OPTIMIZE - let ai_dft: usize = module.vec_znx_dft_alloc_bytes((key_apply.rank_in()).into(), in_size); - let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes( - out_size, - in_size, - in_size, - (key_apply.rank_in()).into(), - (key_apply.rank_out() + 1).into(), - ksk_size, - ) + module.vec_znx_dft_alloc_bytes((key_apply.rank_in()).into(), in_size); - let normalize_big: usize = module.vec_znx_big_normalize_tmp_bytes(); - if in_infos.base2k() == key_apply.base2k() { - res_dft + ((ai_dft + vmp) | normalize_big) - } else if key_apply.dsize() == 1 { - // In this case, we only need one column, temporary, that we can drop once a_dft is computed. - let normalize_conv: usize = VecZnx::alloc_bytes(module.n(), 1, in_size) + module.vec_znx_normalize_tmp_bytes(); - res_dft + (((ai_dft + normalize_conv) | vmp) | normalize_big) - } else { - // Since we stride over a to get a_dft when dsize > 1, we need to store the full columns of a with in the base conversion. - let normalize_conv: usize = VecZnx::alloc_bytes(module.n(), (key_apply.rank_in()).into(), in_size); - res_dft + ((ai_dft + normalize_conv + (module.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big) - } - } - - pub fn keyswitch_inplace_scratch_space(module: &Module, out_infos: &OUT, key_apply: &KEY) -> usize - where - OUT: GLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, - { - Self::keyswitch_scratch_space(module, out_infos, out_infos, key_apply) - } -} - -impl GLWECiphertext { - #[allow(dead_code)] - pub(crate) fn assert_keyswitch( - &self, - module: &Module, - lhs: &GLWECiphertext, - rhs: &GGLWESwitchingKeyPrepared, - scratch: &Scratch, - ) where - DataLhs: DataRef, - DataRhs: DataRef, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable, - { - assert_eq!( - lhs.rank(), - rhs.rank_in(), - "lhs.rank(): {} != rhs.rank_in(): {}", - lhs.rank(), - rhs.rank_in() - ); - assert_eq!( - self.rank(), - rhs.rank_out(), - "self.rank(): {} != rhs.rank_out(): {}", - self.rank(), - rhs.rank_out() - ); - assert_eq!(rhs.n(), self.n()); - assert_eq!(lhs.n(), self.n()); - - let scrach_needed: usize = GLWECiphertext::keyswitch_scratch_space(module, self, lhs, rhs); - - assert!( - scratch.available() >= scrach_needed, - "scratch.available()={} < GLWECiphertext::keyswitch_scratch_space( - module, - self.base2k(), - self.k(), - lhs.base2k(), - lhs.k(), - rhs.base2k(), - rhs.k(), - rhs.dsize(), - rhs.rank_in(), - rhs.rank_out(), - )={scrach_needed}", - scratch.available(), - ); - } - - #[allow(dead_code)] - pub(crate) fn assert_keyswitch_inplace( - &self, - module: &Module, - rhs: &GGLWESwitchingKeyPrepared, - scratch: &Scratch, - ) where - DataRhs: DataRef, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable, - { - assert_eq!( - self.rank(), - rhs.rank_out(), - "self.rank(): {} != rhs.rank_out(): {}", - self.rank(), - rhs.rank_out() - ); - - assert_eq!(rhs.n(), self.n()); - - let scrach_needed: usize = GLWECiphertext::keyswitch_inplace_scratch_space(module, self, rhs); - - assert!( - scratch.available() >= scrach_needed, - "scratch.available()={} < GLWECiphertext::keyswitch_scratch_space()={scrach_needed}", - scratch.available(), - ); - } -} - -impl GLWECiphertext { - pub fn keyswitch( - &mut self, - module: &Module, - glwe_in: &GLWECiphertext, - rhs: &GGLWESwitchingKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnx, - { - #[cfg(debug_assertions)] - { - self.assert_keyswitch(module, glwe_in, rhs, scratch); - } - - let basek_out: usize = self.base2k().into(); - let basek_ksk: usize = rhs.base2k().into(); - - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // Todo optimise - let res_big: VecZnxBig<_, B> = glwe_in.keyswitch_internal(module, res_dft, rhs, scratch_1); - (0..(self.rank() + 1).into()).for_each(|i| { - module.vec_znx_big_normalize( - basek_out, - &mut self.data, - i, - basek_ksk, - &res_big, - i, - scratch_1, - ); - }) - } - - pub fn keyswitch_inplace( - &mut self, - module: &Module, - rhs: &GGLWESwitchingKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnx, - { - #[cfg(debug_assertions)] - { - self.assert_keyswitch_inplace(module, rhs, scratch); - } - - let basek_in: usize = self.base2k().into(); - let basek_ksk: usize = rhs.base2k().into(); - - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // Todo optimise - let res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, rhs, scratch_1); - (0..(self.rank() + 1).into()).for_each(|i| { - module.vec_znx_big_normalize( - basek_in, - &mut self.data, - i, - basek_ksk, - &res_big, - i, - scratch_1, - ); - }) - } -} - -impl GLWECiphertext { - pub(crate) fn keyswitch_internal( - &self, - module: &Module, - res_dft: VecZnxDft, - rhs: &GGLWESwitchingKeyPrepared, - scratch: &mut Scratch, - ) -> VecZnxBig - where - DataRes: DataMut, - DataKey: DataRef, - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalize, - Scratch: TakeVecZnxDft + TakeVecZnx, - { - if rhs.dsize() == 1 { - return keyswitch_vmp_one_digit( - module, - self.base2k().into(), - rhs.base2k().into(), - res_dft, - &self.data, - &rhs.key.data, - scratch, - ); - } - - keyswitch_vmp_multiple_digits( - module, - self.base2k().into(), - rhs.base2k().into(), - res_dft, - &self.data, - &rhs.key.data, - rhs.dsize().into(), - scratch, - ) - } -} - -fn keyswitch_vmp_one_digit( - module: &Module, - basek_in: usize, - basek_ksk: usize, - mut res_dft: VecZnxDft, - a: &VecZnx, - mat: &VmpPMat, - scratch: &mut Scratch, -) -> VecZnxBig -where - DataRes: DataMut, - DataIn: DataRef, - DataVmp: DataRef, - Module: VecZnxDftAllocBytes - + VecZnxDftApply - + VmpApplyDftToDft - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxNormalize, - Scratch: TakeVecZnxDft + TakeVecZnx, -{ - let cols: usize = a.cols(); - - let a_size: usize = (a.size() * basek_in).div_ceil(basek_ksk); - let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(a.n(), cols - 1, a.size()); - - if basek_in == basek_ksk { - (0..cols - 1).for_each(|col_i| { - module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, a, col_i + 1); - }); - } else { - let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(a.n(), 1, a_size); - (0..cols - 1).for_each(|col_i| { - module.vec_znx_normalize(basek_ksk, &mut a_conv, 0, basek_in, a, col_i + 1, scratch_2); - module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, &a_conv, 0); - }); - } - - module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch_1); - let mut res_big: VecZnxBig = module.vec_znx_idft_apply_consume(res_dft); - module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0); - res_big -} - -#[allow(clippy::too_many_arguments)] -fn keyswitch_vmp_multiple_digits( - module: &Module, - basek_in: usize, - basek_ksk: usize, - mut res_dft: VecZnxDft, - a: &VecZnx, - mat: &VmpPMat, - dsize: usize, - scratch: &mut Scratch, -) -> VecZnxBig -where - DataRes: DataMut, - DataIn: DataRef, - DataVmp: DataRef, - Module: VecZnxDftAllocBytes - + VecZnxDftApply - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxNormalize, - Scratch: TakeVecZnxDft + TakeVecZnx, -{ - let cols: usize = a.cols(); - let a_size: usize = (a.size() * basek_in).div_ceil(basek_ksk); - let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(a.n(), cols - 1, a_size.div_ceil(dsize)); - ai_dft.data_mut().fill(0); - - if basek_in == basek_ksk { - for di in 0..dsize { - ai_dft.set_size((a_size + di) / dsize); - - // Small optimization for dsize > 2 - // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then - // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}. - // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last dsize-1 limbs, but this introduce - // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same - // noise is kept with respect to the ideal functionality. - res_dft.set_size(mat.size() - ((dsize - di) as isize - 2).max(0) as usize); - - for j in 0..cols - 1 { - module.vec_znx_dft_apply(dsize, dsize - di - 1, &mut ai_dft, j, a, j + 1); - } - - if di == 0 { - module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch_1); - } else { - module.vmp_apply_dft_to_dft_add(&mut res_dft, &ai_dft, mat, di, scratch_1); - } - } - } else { - let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(a.n(), cols - 1, a_size); - for j in 0..cols - 1 { - module.vec_znx_normalize(basek_ksk, &mut a_conv, j, basek_in, a, j + 1, scratch_2); - } - - for di in 0..dsize { - ai_dft.set_size((a_size + di) / dsize); - - // Small optimization for dsize > 2 - // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then - // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}. - // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last dsize-1 limbs, but this introduce - // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same - // noise is kept with respect to the ideal functionality. - res_dft.set_size(mat.size() - ((dsize - di) as isize - 2).max(0) as usize); - - for j in 0..cols - 1 { - module.vec_znx_dft_apply(dsize, dsize - di - 1, &mut ai_dft, j, &a_conv, j); - } - - if di == 0 { - module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch_2); - } else { - module.vmp_apply_dft_to_dft_add(&mut res_dft, &ai_dft, mat, di, scratch_2); - } - } - } - - res_dft.set_size(res_dft.max_size()); - let mut res_big: VecZnxBig = module.vec_znx_idft_apply_consume(res_dft); - module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0); - res_big -} diff --git a/poulpy-core/src/keyswitching/lwe.rs b/poulpy-core/src/keyswitching/lwe.rs new file mode 100644 index 0000000..bf5abf5 --- /dev/null +++ b/poulpy-core/src/keyswitching/lwe.rs @@ -0,0 +1,116 @@ +use poulpy_hal::{ + api::ScratchAvailable, + layouts::{Backend, DataMut, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero}, +}; + +use crate::{ + LWESampleExtract, ScratchTakeCore, + keyswitching::GLWEKeyswitch, + layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWELayout, LWE, LWEInfos, LWEToMut, LWEToRef, Rank, TorusPrecision}, +}; + +impl LWE> { + pub fn keyswitch_tmp_bytes(module: &M, res_infos: &R, a_infos: &A, key_infos: &K) -> usize + where + R: LWEInfos, + A: LWEInfos, + K: GGLWEInfos, + M: LWEKeySwitch, + { + module.lwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) + } +} + +impl LWE { + pub fn keyswitch(&mut self, module: &M, a: &A, ksk: &K, scratch: &mut Scratch) + where + A: LWEToRef, + K: GGLWEPreparedToRef + GGLWEInfos, + Scratch: ScratchTakeCore, + M: LWEKeySwitch, + { + module.lwe_keyswitch(self, a, ksk, scratch); + } +} + +impl LWEKeySwitch for Module where Self: GLWEKeyswitch + LWESampleExtract {} + +pub trait LWEKeySwitch +where + Self: GLWEKeyswitch + LWESampleExtract, +{ + fn lwe_keyswitch_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize + where + R: LWEInfos, + A: LWEInfos, + K: GGLWEInfos, + { + let max_k: TorusPrecision = a_infos.k().max(res_infos.k()); + + let glwe_a_infos: GLWELayout = GLWELayout { + n: self.n().into(), + base2k: a_infos.base2k(), + k: max_k, + rank: Rank(1), + }; + + let glwe_res_infos: GLWELayout = GLWELayout { + n: self.n().into(), + base2k: res_infos.base2k(), + k: max_k, + rank: Rank(1), + }; + + let glwe_in: usize = GLWE::bytes_of_from_infos(&glwe_a_infos); + let glwe_out: usize = GLWE::bytes_of_from_infos(&glwe_res_infos); + let ks: usize = self.glwe_keyswitch_tmp_bytes(&glwe_res_infos, &glwe_a_infos, key_infos); + + glwe_in + glwe_out + ks + } + + fn lwe_keyswitch(&self, res: &mut R, a: &A, ksk: &K, scratch: &mut Scratch) + where + R: LWEToMut, + A: LWEToRef, + K: GGLWEPreparedToRef + GGLWEInfos, + Scratch: ScratchTakeCore, + { + let res: &mut LWE<&mut [u8]> = &mut res.to_mut(); + let a: &LWE<&[u8]> = &a.to_ref(); + + assert!(res.n().as_usize() <= self.n()); + assert!(a.n().as_usize() <= self.n()); + assert_eq!(ksk.n(), self.n() as u32); + assert!(scratch.available() >= self.lwe_keyswitch_tmp_bytes(res, a, ksk)); + + let max_k: TorusPrecision = res.k().max(a.k()); + + let a_size: usize = a.k().div_ceil(ksk.base2k()) as usize; + + let (mut glwe_in, scratch_1) = scratch.take_glwe(&GLWELayout { + n: ksk.n(), + base2k: a.base2k(), + k: max_k, + rank: Rank(1), + }); + glwe_in.data.zero(); + + let (mut glwe_out, scratch_1) = scratch_1.take_glwe(&GLWELayout { + n: ksk.n(), + base2k: res.base2k(), + k: max_k, + rank: Rank(1), + }); + + let n_lwe: usize = a.n().into(); + + for i in 0..a_size { + let data_lwe: &[i64] = a.data.at(0, i); + glwe_in.data.at_mut(0, i)[0] = data_lwe[0]; + glwe_in.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]); + } + + self.glwe_keyswitch(&mut glwe_out, &glwe_in, ksk, scratch_1); + self.lwe_sample_extract(res, &glwe_out); + } +} diff --git a/poulpy-core/src/keyswitching/lwe_ct.rs b/poulpy-core/src/keyswitching/lwe_ct.rs deleted file mode 100644 index 7d9e08e..0000000 --- a/poulpy-core/src/keyswitching/lwe_ct.rs +++ /dev/null @@ -1,126 +0,0 @@ -use poulpy_hal::{ - api::{ - ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, - VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, - }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero}, -}; - -use crate::{ - TakeGLWECt, - layouts::{ - GGLWEInfos, GLWECiphertext, GLWECiphertextLayout, LWECiphertext, LWEInfos, Rank, TorusPrecision, - prepared::LWESwitchingKeyPrepared, - }, -}; - -impl LWECiphertext> { - pub fn keyswitch_scratch_space( - module: &Module, - out_infos: &OUT, - in_infos: &IN, - key_infos: &KEY, - ) -> usize - where - OUT: LWEInfos, - IN: LWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalizeTmpBytes, - { - let max_k: TorusPrecision = in_infos.k().max(out_infos.k()); - - let glwe_in_infos: GLWECiphertextLayout = GLWECiphertextLayout { - n: module.n().into(), - base2k: in_infos.base2k(), - k: max_k, - rank: Rank(1), - }; - - let glwe_out_infos: GLWECiphertextLayout = GLWECiphertextLayout { - n: module.n().into(), - base2k: out_infos.base2k(), - k: max_k, - rank: Rank(1), - }; - - let glwe_in: usize = GLWECiphertext::alloc_bytes(&glwe_in_infos); - let glwe_out: usize = GLWECiphertext::alloc_bytes(&glwe_out_infos); - let ks: usize = GLWECiphertext::keyswitch_scratch_space(module, &glwe_out_infos, &glwe_in_infos, key_infos); - - glwe_in + glwe_out + ks - } -} - -impl LWECiphertext { - pub fn keyswitch( - &mut self, - module: &Module, - a: &LWECiphertext, - ksk: &LWESwitchingKeyPrepared, - scratch: &mut Scratch, - ) where - A: DataRef, - DKs: DataRef, - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalize - + VecZnxNormalizeTmpBytes - + VecZnxCopy, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, - { - #[cfg(debug_assertions)] - { - assert!(self.n() <= module.n() as u32); - assert!(a.n() <= module.n() as u32); - assert!(scratch.available() >= LWECiphertext::keyswitch_scratch_space(module, self, a, ksk)); - } - - let max_k: TorusPrecision = self.k().max(a.k()); - - let a_size: usize = a.k().div_ceil(ksk.base2k()) as usize; - - let (mut glwe_in, scratch_1) = scratch.take_glwe_ct(&GLWECiphertextLayout { - n: ksk.n(), - base2k: a.base2k(), - k: max_k, - rank: Rank(1), - }); - glwe_in.data.zero(); - - let (mut glwe_out, scratch_1) = scratch_1.take_glwe_ct(&GLWECiphertextLayout { - n: ksk.n(), - base2k: self.base2k(), - k: max_k, - rank: Rank(1), - }); - - let n_lwe: usize = a.n().into(); - - for i in 0..a_size { - let data_lwe: &[i64] = a.data.at(0, i); - glwe_in.data.at_mut(0, i)[0] = data_lwe[0]; - glwe_in.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]); - } - - glwe_out.keyswitch(module, &glwe_in, &ksk.0, scratch_1); - self.sample_extract(&glwe_out); - } -} diff --git a/poulpy-core/src/keyswitching/mod.rs b/poulpy-core/src/keyswitching/mod.rs index c6a3610..fab0f5c 100644 --- a/poulpy-core/src/keyswitching/mod.rs +++ b/poulpy-core/src/keyswitching/mod.rs @@ -1,4 +1,9 @@ -mod gglwe_ct; -mod ggsw_ct; -mod glwe_ct; -mod lwe_ct; +mod gglwe; +mod ggsw; +mod glwe; +mod lwe; + +pub use gglwe::*; +pub use ggsw::*; +pub use glwe::*; +pub use lwe::*; diff --git a/poulpy-core/src/layouts/compressed/gglwe_ct.rs b/poulpy-core/src/layouts/compressed/gglwe.rs similarity index 56% rename from poulpy-core/src/layouts/compressed/gglwe_ct.rs rename to poulpy-core/src/layouts/compressed/gglwe.rs index f7a4df9..0fb0382 100644 --- a/poulpy-core/src/layouts/compressed/gglwe_ct.rs +++ b/poulpy-core/src/layouts/compressed/gglwe.rs @@ -1,18 +1,20 @@ use poulpy_hal::{ api::{VecZnxCopy, VecZnxFillUniform}, - layouts::{Backend, Data, DataMut, DataRef, FillUniform, MatZnx, Module, ReaderFrom, WriterTo, ZnxInfos}, + layouts::{ + Backend, Data, DataMut, DataRef, FillUniform, MatZnx, MatZnxToMut, MatZnxToRef, Module, ReaderFrom, WriterTo, ZnxInfos, + }, source::Source, }; use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWECiphertext, GGLWEInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision, - compressed::{Decompress, GLWECiphertextCompressed}, + Base2K, Degree, Dnum, Dsize, GGLWE, GGLWEInfos, GGLWEToMut, GLWEInfos, LWEInfos, Rank, TorusPrecision, + compressed::{GLWECompressed, GLWEDecompress}, }; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::fmt; #[derive(PartialEq, Eq, Clone)] -pub struct GGLWECiphertextCompressed { +pub struct GGLWECompressed { pub(crate) data: MatZnx, pub(crate) base2k: Base2K, pub(crate) k: TorusPrecision, @@ -21,7 +23,26 @@ pub struct GGLWECiphertextCompressed { pub(crate) seed: Vec<[u8; 32]>, } -impl LWEInfos for GGLWECiphertextCompressed { +pub trait GGLWECompressedSeedMut { + fn seed_mut(&mut self) -> &mut Vec<[u8; 32]>; +} + +impl GGLWECompressedSeedMut for GGLWECompressed { + fn seed_mut(&mut self) -> &mut Vec<[u8; 32]> { + &mut self.seed + } +} + +pub trait GGLWECompressedSeed { + fn seed(&self) -> &Vec<[u8; 32]>; +} + +impl GGLWECompressedSeed for GGLWECompressed { + fn seed(&self) -> &Vec<[u8; 32]> { + &self.seed + } +} +impl LWEInfos for GGLWECompressed { fn n(&self) -> Degree { Degree(self.data.n() as u32) } @@ -38,13 +59,13 @@ impl LWEInfos for GGLWECiphertextCompressed { self.data.size() } } -impl GLWEInfos for GGLWECiphertextCompressed { +impl GLWEInfos for GGLWECompressed { fn rank(&self) -> Rank { self.rank_out() } } -impl GGLWEInfos for GGLWECiphertextCompressed { +impl GGLWEInfos for GGLWECompressed { fn rank_in(&self) -> Rank { Rank(self.data.cols_in() as u32) } @@ -62,34 +83,34 @@ impl GGLWEInfos for GGLWECiphertextCompressed { } } -impl fmt::Debug for GGLWECiphertextCompressed { +impl fmt::Debug for GGLWECompressed { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{self}") } } -impl FillUniform for GGLWECiphertextCompressed { +impl FillUniform for GGLWECompressed { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.data.fill_uniform(log_bound, source); } } -impl fmt::Display for GGLWECiphertextCompressed { +impl fmt::Display for GGLWECompressed { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "(GGLWECiphertextCompressed: base2k={} k={} dsize={}) {}", + "(GGLWECompressed: base2k={} k={} dsize={}) {}", self.base2k.0, self.k.0, self.dsize.0, self.data ) } } -impl GGLWECiphertextCompressed> { - pub fn alloc(infos: &A) -> Self +impl GGLWECompressed> { + pub fn alloc_from_infos(infos: &A) -> Self where A: GGLWEInfos, { - Self::alloc_with( + Self::alloc( infos.n(), infos.base2k(), infos.k(), @@ -100,15 +121,7 @@ impl GGLWECiphertextCompressed> { ) } - pub fn alloc_with( - n: Degree, - base2k: Base2K, - k: TorusPrecision, - rank_in: Rank, - rank_out: Rank, - dnum: Dnum, - dsize: Dsize, - ) -> Self { + pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank_in: Rank, rank_out: Rank, dnum: Dnum, dsize: Dsize) -> Self { let size: usize = k.0.div_ceil(base2k.0) as usize; debug_assert!( size as u32 > dsize.0, @@ -123,7 +136,7 @@ impl GGLWECiphertextCompressed> { dsize.0, ); - Self { + GGLWECompressed { data: MatZnx::alloc( n.into(), dnum.into(), @@ -139,11 +152,11 @@ impl GGLWECiphertextCompressed> { } } - pub fn alloc_bytes(infos: &A) -> usize + pub fn bytes_of_from_infos(infos: &A) -> usize where A: GGLWEInfos, { - Self::alloc_bytes_with( + Self::bytes_of( infos.n(), infos.base2k(), infos.k(), @@ -153,7 +166,7 @@ impl GGLWECiphertextCompressed> { ) } - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum, dsize: Dsize) -> usize { + pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum, dsize: Dsize) -> usize { let size: usize = k.0.div_ceil(base2k.0) as usize; debug_assert!( size as u32 > dsize.0, @@ -168,7 +181,7 @@ impl GGLWECiphertextCompressed> { dsize.0, ); - MatZnx::alloc_bytes( + MatZnx::bytes_of( n.into(), dnum.into(), rank_in.into(), @@ -178,10 +191,10 @@ impl GGLWECiphertextCompressed> { } } -impl GGLWECiphertextCompressed { - pub(crate) fn at(&self, row: usize, col: usize) -> GLWECiphertextCompressed<&[u8]> { +impl GGLWECompressed { + pub(crate) fn at(&self, row: usize, col: usize) -> GLWECompressed<&[u8]> { let rank_in: usize = self.rank_in().into(); - GLWECiphertextCompressed { + GLWECompressed { data: self.data.at(row, col), k: self.k, base2k: self.base2k, @@ -191,10 +204,10 @@ impl GGLWECiphertextCompressed { } } -impl GGLWECiphertextCompressed { - pub(crate) fn at_mut(&mut self, row: usize, col: usize) -> GLWECiphertextCompressed<&mut [u8]> { +impl GGLWECompressed { + pub(crate) fn at_mut(&mut self, row: usize, col: usize) -> GLWECompressed<&mut [u8]> { let rank_in: usize = self.rank_in().into(); - GLWECiphertextCompressed { + GLWECompressed { k: self.k, base2k: self.base2k, rank: self.rank_out, @@ -204,7 +217,7 @@ impl GGLWECiphertextCompressed { } } -impl ReaderFrom for GGLWECiphertextCompressed { +impl ReaderFrom for GGLWECompressed { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.k = TorusPrecision(reader.read_u32::()?); self.base2k = Base2K(reader.read_u32::()?); @@ -219,7 +232,7 @@ impl ReaderFrom for GGLWECiphertextCompressed { } } -impl WriterTo for GGLWECiphertextCompressed { +impl WriterTo for GGLWECompressed { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { writer.write_u32::(self.k.into())?; writer.write_u32::(self.base2k.into())?; @@ -233,59 +246,74 @@ impl WriterTo for GGLWECiphertextCompressed { } } -impl Decompress> for GGLWECiphertext +pub trait GGLWEDecompress where - Module: VecZnxFillUniform + VecZnxCopy, + Self: GLWEDecompress, { - fn decompress(&mut self, module: &Module, other: &GGLWECiphertextCompressed) { - #[cfg(debug_assertions)] - { - assert_eq!( - self.n(), - other.n(), - "invalid receiver: self.n()={} != other.n()={}", - self.n(), - other.n() - ); - assert_eq!( - self.size(), - other.size(), - "invalid receiver: self.size()={} != other.size()={}", - self.size(), - other.size() - ); - assert_eq!( - self.rank_in(), - other.rank_in(), - "invalid receiver: self.rank_in()={} != other.rank_in()={}", - self.rank_in(), - other.rank_in() - ); - assert_eq!( - self.rank_out(), - other.rank_out(), - "invalid receiver: self.rank_out()={} != other.rank_out()={}", - self.rank_out(), - other.rank_out() - ); + fn decompress_gglwe(&self, res: &mut R, other: &O) + where + R: GGLWEToMut, + O: GGLWECompressedToRef, + { + let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut(); + let other: &GGLWECompressed<&[u8]> = &other.to_ref(); - assert_eq!( - self.dnum(), - other.dnum(), - "invalid receiver: self.dnum()={} != other.dnum()={}", - self.dnum(), - other.dnum() - ); + assert_eq!(res.dsize(), other.dsize()); + assert!(res.dnum() <= other.dnum()); + + let rank_in: usize = res.rank_in().into(); + let dnum: usize = res.dnum().into(); + + for row_i in 0..dnum { + for col_i in 0..rank_in { + self.decompress_glwe(&mut res.at_mut(row_i, col_i), &other.at(row_i, col_i)); + } + } + } +} + +impl GGLWEDecompress for Module where Self: VecZnxFillUniform + VecZnxCopy {} + +impl GGLWE { + pub fn decompress(&mut self, module: &M, other: &O) + where + O: GGLWECompressedToRef, + M: GGLWEDecompress, + { + module.decompress_gglwe(self, other); + } +} + +pub trait GGLWECompressedToMut { + fn to_mut(&mut self) -> GGLWECompressed<&mut [u8]>; +} + +impl GGLWECompressedToMut for GGLWECompressed { + fn to_mut(&mut self) -> GGLWECompressed<&mut [u8]> { + GGLWECompressed { + k: self.k(), + base2k: self.base2k(), + dsize: self.dsize(), + seed: self.seed.clone(), + rank_out: self.rank_out, + data: self.data.to_mut(), + } + } +} + +pub trait GGLWECompressedToRef { + fn to_ref(&self) -> GGLWECompressed<&[u8]>; +} + +impl GGLWECompressedToRef for GGLWECompressed { + fn to_ref(&self) -> GGLWECompressed<&[u8]> { + GGLWECompressed { + k: self.k(), + base2k: self.base2k(), + dsize: self.dsize(), + seed: self.seed.clone(), + rank_out: self.rank_out, + data: self.data.to_ref(), } - - let rank_in: usize = self.rank_in().into(); - let dnum: usize = self.dnum().into(); - - (0..rank_in).for_each(|col_i| { - (0..dnum).for_each(|row_i| { - self.at_mut(row_i, col_i) - .decompress(module, &other.at(row_i, col_i)); - }); - }); } } diff --git a/poulpy-core/src/layouts/compressed/gglwe_atk.rs b/poulpy-core/src/layouts/compressed/gglwe_atk.rs deleted file mode 100644 index 2a10765..0000000 --- a/poulpy-core/src/layouts/compressed/gglwe_atk.rs +++ /dev/null @@ -1,133 +0,0 @@ -use poulpy_hal::{ - api::{VecZnxCopy, VecZnxFillUniform}, - layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo}, - source::Source, -}; - -use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEAutomorphismKey, GGLWEInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision, - compressed::{Decompress, GGLWESwitchingKeyCompressed}, -}; -use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; -use std::fmt; - -#[derive(PartialEq, Eq, Clone)] -pub struct GGLWEAutomorphismKeyCompressed { - pub(crate) key: GGLWESwitchingKeyCompressed, - pub(crate) p: i64, -} - -impl LWEInfos for GGLWEAutomorphismKeyCompressed { - fn n(&self) -> Degree { - self.key.n() - } - - fn base2k(&self) -> Base2K { - self.key.base2k() - } - - fn k(&self) -> TorusPrecision { - self.key.k() - } - - fn size(&self) -> usize { - self.key.size() - } -} -impl GLWEInfos for GGLWEAutomorphismKeyCompressed { - fn rank(&self) -> Rank { - self.rank_out() - } -} - -impl GGLWEInfos for GGLWEAutomorphismKeyCompressed { - fn rank_in(&self) -> Rank { - self.key.rank_in() - } - - fn rank_out(&self) -> Rank { - self.key.rank_out() - } - - fn dsize(&self) -> Dsize { - self.key.dsize() - } - - fn dnum(&self) -> Dnum { - self.key.dnum() - } -} - -impl fmt::Debug for GGLWEAutomorphismKeyCompressed { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{self}") - } -} - -impl FillUniform for GGLWEAutomorphismKeyCompressed { - fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { - self.key.fill_uniform(log_bound, source); - } -} - -impl fmt::Display for GGLWEAutomorphismKeyCompressed { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "(AutomorphismKeyCompressed: p={}) {}", self.p, self.key) - } -} - -impl GGLWEAutomorphismKeyCompressed> { - pub fn alloc(infos: &A) -> Self - where - A: GGLWEInfos, - { - debug_assert_eq!(infos.rank_in(), infos.rank_out()); - Self { - key: GGLWESwitchingKeyCompressed::alloc(infos), - p: 0, - } - } - - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self { - Self { - key: GGLWESwitchingKeyCompressed::alloc_with(n, base2k, k, rank, rank, dnum, dsize), - p: 0, - } - } - - pub fn alloc_bytes(infos: &A) -> usize - where - A: GGLWEInfos, - { - debug_assert_eq!(infos.rank_in(), infos.rank_out()); - GGLWESwitchingKeyCompressed::alloc_bytes(infos) - } - - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { - GGLWESwitchingKeyCompressed::alloc_bytes_with(n, base2k, k, rank, dnum, dsize) - } -} - -impl ReaderFrom for GGLWEAutomorphismKeyCompressed { - fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { - self.p = reader.read_u64::()? as i64; - self.key.read_from(reader) - } -} - -impl WriterTo for GGLWEAutomorphismKeyCompressed { - fn write_to(&self, writer: &mut W) -> std::io::Result<()> { - writer.write_u64::(self.p as u64)?; - self.key.write_to(writer) - } -} - -impl Decompress> for GGLWEAutomorphismKey -where - Module: VecZnxFillUniform + VecZnxCopy, -{ - fn decompress(&mut self, module: &Module, other: &GGLWEAutomorphismKeyCompressed) { - self.key.decompress(module, &other.key); - self.p = other.p; - } -} diff --git a/poulpy-core/src/layouts/compressed/gglwe_ksk.rs b/poulpy-core/src/layouts/compressed/gglwe_ksk.rs deleted file mode 100644 index 60d9316..0000000 --- a/poulpy-core/src/layouts/compressed/gglwe_ksk.rs +++ /dev/null @@ -1,149 +0,0 @@ -use poulpy_hal::{ - api::{VecZnxCopy, VecZnxFillUniform}, - layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo}, - source::Source, -}; - -use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWESwitchingKey, GLWEInfos, LWEInfos, Rank, TorusPrecision, - compressed::{Decompress, GGLWECiphertextCompressed}, -}; -use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; -use std::fmt; - -#[derive(PartialEq, Eq, Clone)] -pub struct GGLWESwitchingKeyCompressed { - pub(crate) key: GGLWECiphertextCompressed, - pub(crate) sk_in_n: usize, // Degree of sk_in - pub(crate) sk_out_n: usize, // Degree of sk_out -} - -impl LWEInfos for GGLWESwitchingKeyCompressed { - fn n(&self) -> Degree { - self.key.n() - } - - fn base2k(&self) -> Base2K { - self.key.base2k() - } - - fn k(&self) -> TorusPrecision { - self.key.k() - } - - fn size(&self) -> usize { - self.key.size() - } -} -impl GLWEInfos for GGLWESwitchingKeyCompressed { - fn rank(&self) -> Rank { - self.rank_out() - } -} - -impl GGLWEInfos for GGLWESwitchingKeyCompressed { - fn rank_in(&self) -> Rank { - self.key.rank_in() - } - - fn rank_out(&self) -> Rank { - self.key.rank_out() - } - - fn dsize(&self) -> Dsize { - self.key.dsize() - } - - fn dnum(&self) -> Dnum { - self.key.dnum() - } -} - -impl fmt::Debug for GGLWESwitchingKeyCompressed { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{self}") - } -} - -impl FillUniform for GGLWESwitchingKeyCompressed { - fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { - self.key.fill_uniform(log_bound, source); - } -} - -impl fmt::Display for GGLWESwitchingKeyCompressed { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "(GLWESwitchingKeyCompressed: sk_in_n={} sk_out_n={}) {}", - self.sk_in_n, self.sk_out_n, self.key.data - ) - } -} - -impl GGLWESwitchingKeyCompressed> { - pub fn alloc(infos: &A) -> Self - where - A: GGLWEInfos, - { - GGLWESwitchingKeyCompressed { - key: GGLWECiphertextCompressed::alloc(infos), - sk_in_n: 0, - sk_out_n: 0, - } - } - - pub fn alloc_with( - n: Degree, - base2k: Base2K, - k: TorusPrecision, - rank_in: Rank, - rank_out: Rank, - dnum: Dnum, - dsize: Dsize, - ) -> Self { - GGLWESwitchingKeyCompressed { - key: GGLWECiphertextCompressed::alloc_with(n, base2k, k, rank_in, rank_out, dnum, dsize), - sk_in_n: 0, - sk_out_n: 0, - } - } - - pub fn alloc_bytes(infos: &A) -> usize - where - A: GGLWEInfos, - { - GGLWECiphertextCompressed::alloc_bytes(infos) - } - - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum, dsize: Dsize) -> usize { - GGLWECiphertextCompressed::alloc_bytes_with(n, base2k, k, rank_in, dnum, dsize) - } -} - -impl ReaderFrom for GGLWESwitchingKeyCompressed { - fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { - self.sk_in_n = reader.read_u64::()? as usize; - self.sk_out_n = reader.read_u64::()? as usize; - self.key.read_from(reader) - } -} - -impl WriterTo for GGLWESwitchingKeyCompressed { - fn write_to(&self, writer: &mut W) -> std::io::Result<()> { - writer.write_u64::(self.sk_in_n as u64)?; - writer.write_u64::(self.sk_out_n as u64)?; - self.key.write_to(writer) - } -} - -impl Decompress> for GGLWESwitchingKey -where - Module: VecZnxFillUniform + VecZnxCopy, -{ - fn decompress(&mut self, module: &Module, other: &GGLWESwitchingKeyCompressed) { - self.key.decompress(module, &other.key); - self.sk_in_n = other.sk_in_n; - self.sk_out_n = other.sk_out_n; - } -} diff --git a/poulpy-core/src/layouts/compressed/gglwe_tsk.rs b/poulpy-core/src/layouts/compressed/gglwe_tsk.rs deleted file mode 100644 index fef4647..0000000 --- a/poulpy-core/src/layouts/compressed/gglwe_tsk.rs +++ /dev/null @@ -1,207 +0,0 @@ -use poulpy_hal::{ - api::{VecZnxCopy, VecZnxFillUniform}, - layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo}, - source::Source, -}; - -use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWETensorKey, GLWEInfos, LWEInfos, Rank, TorusPrecision, - compressed::{Decompress, GGLWESwitchingKeyCompressed}, -}; -use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; -use std::fmt; - -#[derive(PartialEq, Eq, Clone)] -pub struct GGLWETensorKeyCompressed { - pub(crate) keys: Vec>, -} - -impl LWEInfos for GGLWETensorKeyCompressed { - fn n(&self) -> Degree { - self.keys[0].n() - } - - fn base2k(&self) -> Base2K { - self.keys[0].base2k() - } - - fn k(&self) -> TorusPrecision { - self.keys[0].k() - } - fn size(&self) -> usize { - self.keys[0].size() - } -} -impl GLWEInfos for GGLWETensorKeyCompressed { - fn rank(&self) -> Rank { - self.rank_out() - } -} - -impl GGLWEInfos for GGLWETensorKeyCompressed { - fn rank_in(&self) -> Rank { - self.rank_out() - } - - fn rank_out(&self) -> Rank { - self.keys[0].rank_out() - } - - fn dsize(&self) -> Dsize { - self.keys[0].dsize() - } - - fn dnum(&self) -> Dnum { - self.keys[0].dnum() - } -} - -impl fmt::Debug for GGLWETensorKeyCompressed { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{self}") - } -} - -impl FillUniform for GGLWETensorKeyCompressed { - fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { - self.keys - .iter_mut() - .for_each(|key: &mut GGLWESwitchingKeyCompressed| key.fill_uniform(log_bound, source)) - } -} - -impl fmt::Display for GGLWETensorKeyCompressed { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!(f, "(GLWETensorKeyCompressed)",)?; - for (i, key) in self.keys.iter().enumerate() { - write!(f, "{i}: {key}")?; - } - Ok(()) - } -} - -impl GGLWETensorKeyCompressed> { - pub fn alloc(infos: &A) -> Self - where - A: GGLWEInfos, - { - assert_eq!( - infos.rank_in(), - infos.rank_out(), - "rank_in != rank_out is not supported for GGLWETensorKeyCompressed" - ); - Self::alloc_with( - infos.n(), - infos.base2k(), - infos.k(), - infos.rank_out(), - infos.dnum(), - infos.dsize(), - ) - } - - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self { - let mut keys: Vec>> = Vec::new(); - let pairs: u32 = (((rank.0 + 1) * rank.0) >> 1).max(1); - (0..pairs).for_each(|_| { - keys.push(GGLWESwitchingKeyCompressed::alloc_with( - n, - base2k, - k, - Rank(1), - rank, - dnum, - dsize, - )); - }); - Self { keys } - } - - pub fn alloc_bytes(infos: &A) -> usize - where - A: GGLWEInfos, - { - assert_eq!( - infos.rank_in(), - infos.rank_out(), - "rank_in != rank_out is not supported for GGLWETensorKeyCompressed" - ); - let rank_out: usize = infos.rank_out().into(); - let pairs: usize = (((rank_out + 1) * rank_out) >> 1).max(1); - pairs - * GGLWESwitchingKeyCompressed::alloc_bytes_with( - infos.n(), - infos.base2k(), - infos.k(), - Rank(1), - infos.dnum(), - infos.dsize(), - ) - } - - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { - let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize; - pairs * GGLWESwitchingKeyCompressed::alloc_bytes_with(n, base2k, k, Rank(1), dnum, dsize) - } -} - -impl ReaderFrom for GGLWETensorKeyCompressed { - fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { - let len: usize = reader.read_u64::()? as usize; - if self.keys.len() != len { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - format!("self.keys.len()={} != read len={}", self.keys.len(), len), - )); - } - for key in &mut self.keys { - key.read_from(reader)?; - } - Ok(()) - } -} - -impl WriterTo for GGLWETensorKeyCompressed { - fn write_to(&self, writer: &mut W) -> std::io::Result<()> { - writer.write_u64::(self.keys.len() as u64)?; - for key in &self.keys { - key.write_to(writer)?; - } - Ok(()) - } -} - -impl GGLWETensorKeyCompressed { - pub(crate) fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GGLWESwitchingKeyCompressed { - if i > j { - std::mem::swap(&mut i, &mut j); - }; - let rank: usize = self.rank_out().into(); - &mut self.keys[i * rank + j - (i * (i + 1) / 2)] - } -} - -impl Decompress> for GGLWETensorKey -where - Module: VecZnxFillUniform + VecZnxCopy, -{ - fn decompress(&mut self, module: &Module, other: &GGLWETensorKeyCompressed) { - #[cfg(debug_assertions)] - { - assert_eq!( - self.keys.len(), - other.keys.len(), - "invalid receiver: self.keys.len()={} != other.keys.len()={}", - self.keys.len(), - other.keys.len() - ); - } - - self.keys - .iter_mut() - .zip(other.keys.iter()) - .for_each(|(a, b)| { - a.decompress(module, b); - }); - } -} diff --git a/poulpy-core/src/layouts/compressed/ggsw_ct.rs b/poulpy-core/src/layouts/compressed/ggsw.rs similarity index 53% rename from poulpy-core/src/layouts/compressed/ggsw_ct.rs rename to poulpy-core/src/layouts/compressed/ggsw.rs index f0a62cc..14fdd7a 100644 --- a/poulpy-core/src/layouts/compressed/ggsw_ct.rs +++ b/poulpy-core/src/layouts/compressed/ggsw.rs @@ -1,18 +1,19 @@ use poulpy_hal::{ - api::{VecZnxCopy, VecZnxFillUniform}, - layouts::{Backend, Data, DataMut, DataRef, FillUniform, MatZnx, Module, ReaderFrom, WriterTo, ZnxInfos}, + layouts::{ + Backend, Data, DataMut, DataRef, FillUniform, MatZnx, MatZnxToMut, MatZnxToRef, Module, ReaderFrom, WriterTo, ZnxInfos, + }, source::Source, }; use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGSWCiphertext, GGSWInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision, - compressed::{Decompress, GLWECiphertextCompressed}, + Base2K, Degree, Dnum, Dsize, GGSW, GGSWInfos, GGSWToMut, GLWEInfos, LWEInfos, Rank, TorusPrecision, + compressed::{GLWECompressed, GLWEDecompress}, }; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::fmt; #[derive(PartialEq, Eq, Clone)] -pub struct GGSWCiphertextCompressed { +pub struct GGSWCompressed { pub(crate) data: MatZnx, pub(crate) k: TorusPrecision, pub(crate) base2k: Base2K, @@ -21,7 +22,27 @@ pub struct GGSWCiphertextCompressed { pub(crate) seed: Vec<[u8; 32]>, } -impl LWEInfos for GGSWCiphertextCompressed { +pub trait GGSWCompressedSeedMut { + fn seed_mut(&mut self) -> &mut Vec<[u8; 32]>; +} + +impl GGSWCompressedSeedMut for GGSWCompressed { + fn seed_mut(&mut self) -> &mut Vec<[u8; 32]> { + &mut self.seed + } +} + +pub trait GGSWCompressedSeed { + fn seed(&self) -> &Vec<[u8; 32]>; +} + +impl GGSWCompressedSeed for GGSWCompressed { + fn seed(&self) -> &Vec<[u8; 32]> { + &self.seed + } +} + +impl LWEInfos for GGSWCompressed { fn n(&self) -> Degree { Degree(self.data.n() as u32) } @@ -37,13 +58,13 @@ impl LWEInfos for GGSWCiphertextCompressed { self.data.size() } } -impl GLWEInfos for GGSWCiphertextCompressed { +impl GLWEInfos for GGSWCompressed { fn rank(&self) -> Rank { self.rank } } -impl GGSWInfos for GGSWCiphertextCompressed { +impl GGSWInfos for GGSWCompressed { fn dsize(&self) -> Dsize { self.dsize } @@ -53,34 +74,34 @@ impl GGSWInfos for GGSWCiphertextCompressed { } } -impl fmt::Debug for GGSWCiphertextCompressed { +impl fmt::Debug for GGSWCompressed { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", self.data) } } -impl fmt::Display for GGSWCiphertextCompressed { +impl fmt::Display for GGSWCompressed { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "(GGSWCiphertextCompressed: base2k={} k={} dsize={}) {}", + "(GGSWCompressed: base2k={} k={} dsize={}) {}", self.base2k, self.k, self.dsize, self.data ) } } -impl FillUniform for GGSWCiphertextCompressed { +impl FillUniform for GGSWCompressed { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.data.fill_uniform(log_bound, source); } } -impl GGSWCiphertextCompressed> { - pub fn alloc(infos: &A) -> Self +impl GGSWCompressed> { + pub fn alloc_from_infos(infos: &A) -> Self where A: GGSWInfos, { - Self::alloc_with( + Self::alloc( infos.n(), infos.base2k(), infos.k(), @@ -90,9 +111,9 @@ impl GGSWCiphertextCompressed> { ) } - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self { + pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self { let size: usize = k.0.div_ceil(base2k.0) as usize; - debug_assert!( + assert!( size as u32 > dsize.0, "invalid ggsw: ceil(k/base2k): {size} <= dsize: {}", dsize.0 @@ -105,7 +126,7 @@ impl GGSWCiphertextCompressed> { dsize.0, ); - Self { + GGSWCompressed { data: MatZnx::alloc( n.into(), dnum.into(), @@ -117,15 +138,15 @@ impl GGSWCiphertextCompressed> { base2k, dsize, rank, - seed: Vec::new(), + seed: vec![[0u8; 32]; dnum.as_usize() * (rank.as_usize() + 1)], } } - pub fn alloc_bytes(infos: &A) -> usize + pub fn bytes_of_from_infos(infos: &A) -> usize where A: GGSWInfos, { - Self::alloc_bytes_with( + Self::bytes_of( infos.n(), infos.base2k(), infos.k(), @@ -135,9 +156,9 @@ impl GGSWCiphertextCompressed> { ) } - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { + pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { let size: usize = k.0.div_ceil(base2k.0) as usize; - debug_assert!( + assert!( size as u32 > dsize.0, "invalid ggsw: ceil(k/base2k): {size} <= dsize: {}", dsize.0 @@ -150,7 +171,7 @@ impl GGSWCiphertextCompressed> { dsize.0, ); - MatZnx::alloc_bytes( + MatZnx::bytes_of( n.into(), dnum.into(), (rank + 1).into(), @@ -160,10 +181,10 @@ impl GGSWCiphertextCompressed> { } } -impl GGSWCiphertextCompressed { - pub fn at(&self, row: usize, col: usize) -> GLWECiphertextCompressed<&[u8]> { +impl GGSWCompressed { + pub fn at(&self, row: usize, col: usize) -> GLWECompressed<&[u8]> { let rank: usize = self.rank().into(); - GLWECiphertextCompressed { + GLWECompressed { data: self.data.at(row, col), k: self.k, base2k: self.base2k, @@ -173,10 +194,10 @@ impl GGSWCiphertextCompressed { } } -impl GGSWCiphertextCompressed { - pub fn at_mut(&mut self, row: usize, col: usize) -> GLWECiphertextCompressed<&mut [u8]> { +impl GGSWCompressed { + pub fn at_mut(&mut self, row: usize, col: usize) -> GLWECompressed<&mut [u8]> { let rank: usize = self.rank().into(); - GLWECiphertextCompressed { + GLWECompressed { data: self.data.at_mut(row, col), k: self.k, base2k: self.base2k, @@ -186,7 +207,7 @@ impl GGSWCiphertextCompressed { } } -impl ReaderFrom for GGSWCiphertextCompressed { +impl ReaderFrom for GGSWCompressed { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.k = TorusPrecision(reader.read_u32::()?); self.base2k = Base2K(reader.read_u32::()?); @@ -201,7 +222,7 @@ impl ReaderFrom for GGSWCiphertextCompressed { } } -impl WriterTo for GGSWCiphertextCompressed { +impl WriterTo for GGSWCompressed { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { writer.write_u32::(self.k.into())?; writer.write_u32::(self.base2k.into())?; @@ -215,23 +236,72 @@ impl WriterTo for GGSWCiphertextCompressed { } } -impl Decompress> for GGSWCiphertext +pub trait GGSWDecompress where - Module: VecZnxFillUniform + VecZnxCopy, + Self: GLWEDecompress, { - fn decompress(&mut self, module: &Module, other: &GGSWCiphertextCompressed) { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank(), other.rank()) - } + fn decompress_ggsw(&self, res: &mut R, other: &O) + where + R: GGSWToMut, + O: GGSWCompressedToRef, + { + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + let other: &GGSWCompressed<&[u8]> = &other.to_ref(); - let dnum: usize = self.dnum().into(); - let rank: usize = self.rank().into(); - (0..dnum).for_each(|row_i| { - (0..rank + 1).for_each(|col_j| { - self.at_mut(row_i, col_j) - .decompress(module, &other.at(row_i, col_j)); - }); - }); + assert_eq!(res.rank(), other.rank()); + let dnum: usize = res.dnum().into(); + let rank: usize = res.rank().into(); + + for row_i in 0..dnum { + for col_j in 0..rank + 1 { + self.decompress_glwe(&mut res.at_mut(row_i, col_j), &other.at(row_i, col_j)); + } + } + } +} + +impl GGSWDecompress for Module where Self: GLWEDecompress {} + +impl GGSW { + pub fn decompress(&mut self, module: &M, other: &O) + where + O: GGSWCompressedToRef, + M: GGSWDecompress, + { + module.decompress_ggsw(self, other); + } +} + +pub trait GGSWCompressedToMut { + fn to_mut(&mut self) -> GGSWCompressed<&mut [u8]>; +} + +impl GGSWCompressedToMut for GGSWCompressed { + fn to_mut(&mut self) -> GGSWCompressed<&mut [u8]> { + GGSWCompressed { + k: self.k(), + base2k: self.base2k(), + dsize: self.dsize(), + rank: self.rank(), + seed: self.seed.clone(), + data: self.data.to_mut(), + } + } +} + +pub trait GGSWCompressedToRef { + fn to_ref(&self) -> GGSWCompressed<&[u8]>; +} + +impl GGSWCompressedToRef for GGSWCompressed { + fn to_ref(&self) -> GGSWCompressed<&[u8]> { + GGSWCompressed { + k: self.k(), + base2k: self.base2k(), + dsize: self.dsize(), + rank: self.rank(), + seed: self.seed.clone(), + data: self.data.to_ref(), + } } } diff --git a/poulpy-core/src/layouts/compressed/glwe.rs b/poulpy-core/src/layouts/compressed/glwe.rs new file mode 100644 index 0000000..eda1e8f --- /dev/null +++ b/poulpy-core/src/layouts/compressed/glwe.rs @@ -0,0 +1,218 @@ +use poulpy_hal::{ + api::{VecZnxCopy, VecZnxFillUniform}, + layouts::{ + Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, VecZnx, VecZnxToMut, VecZnxToRef, WriterTo, ZnxInfos, + }, + source::Source, +}; + +use crate::layouts::{Base2K, Degree, GLWE, GLWEInfos, GLWEToMut, GetDegree, LWEInfos, Rank, SetGLWEInfos, TorusPrecision}; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; +use std::fmt; + +#[derive(PartialEq, Eq, Clone)] +pub struct GLWECompressed { + pub(crate) data: VecZnx, + pub(crate) base2k: Base2K, + pub(crate) k: TorusPrecision, + pub(crate) rank: Rank, + pub(crate) seed: [u8; 32], +} + +pub trait GLWECompressedSeedMut { + fn seed_mut(&mut self) -> &mut [u8; 32]; +} + +impl GLWECompressedSeedMut for GLWECompressed { + fn seed_mut(&mut self) -> &mut [u8; 32] { + &mut self.seed + } +} + +pub trait GLWECompressedSeed { + fn seed(&self) -> &[u8; 32]; +} + +impl GLWECompressedSeed for GLWECompressed { + fn seed(&self) -> &[u8; 32] { + &self.seed + } +} + +impl LWEInfos for GLWECompressed { + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn size(&self) -> usize { + self.data.size() + } + + fn n(&self) -> Degree { + Degree(self.data.n() as u32) + } +} +impl GLWEInfos for GLWECompressed { + fn rank(&self) -> Rank { + self.rank + } +} + +impl fmt::Debug for GLWECompressed { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{self}") + } +} + +impl fmt::Display for GLWECompressed { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "GLWECompressed: base2k={} k={} rank={} seed={:?}: {}", + self.base2k(), + self.k(), + self.rank(), + self.seed, + self.data + ) + } +} + +impl FillUniform for GLWECompressed { + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + self.data.fill_uniform(log_bound, source); + } +} + +impl GLWECompressed> { + pub fn alloc_from_infos(infos: &A) -> Self + where + A: GLWEInfos, + { + Self::alloc(infos.n(), infos.base2k(), infos.k(), infos.rank()) + } + + pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self { + GLWECompressed { + data: VecZnx::alloc(n.into(), 1, k.0.div_ceil(base2k.0) as usize), + base2k, + k, + rank, + seed: [0u8; 32], + } + } + + pub fn bytes_of_from_infos(infos: &A) -> usize + where + A: GLWEInfos, + { + Self::bytes_of(infos.n(), infos.base2k(), infos.k()) + } + + pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision) -> usize { + VecZnx::bytes_of(n.into(), 1, k.0.div_ceil(base2k.0) as usize) + } +} + +impl ReaderFrom for GLWECompressed { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + self.k = TorusPrecision(reader.read_u32::()?); + self.base2k = Base2K(reader.read_u32::()?); + self.rank = Rank(reader.read_u32::()?); + reader.read_exact(&mut self.seed)?; + self.data.read_from(reader) + } +} + +impl WriterTo for GLWECompressed { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + writer.write_u32::(self.k.into())?; + writer.write_u32::(self.base2k.into())?; + writer.write_u32::(self.rank.into())?; + writer.write_all(&self.seed)?; + self.data.write_to(writer) + } +} + +pub trait GLWEDecompress +where + Self: GetDegree + VecZnxFillUniform + VecZnxCopy, +{ + fn decompress_glwe(&self, res: &mut R, other: &O) + where + R: GLWEToMut + SetGLWEInfos, + O: GLWECompressedToRef + GLWEInfos, + { + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let other: &GLWECompressed<&[u8]> = &other.to_ref(); + assert_eq!( + res.n(), + self.ring_degree(), + "invalid receiver: res.n()={} != other.n()={}", + res.n(), + self.ring_degree() + ); + + assert_eq!(res.glwe_layout(), other.glwe_layout()); + + let mut source: Source = Source::new(other.seed); + + self.vec_znx_copy(&mut res.data, 0, &other.data, 0); + (1..(other.rank() + 1).into()).for_each(|i| { + self.vec_znx_fill_uniform(other.base2k.into(), &mut res.data, i, &mut source); + }); + } + + res.set_base2k(other.base2k()); + res.set_k(other.k()); + } +} + +impl GLWEDecompress for Module where Self: GetDegree + VecZnxFillUniform + VecZnxCopy {} + +impl GLWE { + pub fn decompress(&mut self, module: &M, other: &O) + where + O: GLWECompressedToRef + GLWEInfos, + M: GLWEDecompress, + { + module.decompress_glwe(self, other); + } +} + +pub trait GLWECompressedToRef { + fn to_ref(&self) -> GLWECompressed<&[u8]>; +} + +impl GLWECompressedToRef for GLWECompressed { + fn to_ref(&self) -> GLWECompressed<&[u8]> { + GLWECompressed { + seed: self.seed, + base2k: self.base2k, + k: self.k, + rank: self.rank, + data: self.data.to_ref(), + } + } +} + +pub trait GLWECompressedToMut { + fn to_mut(&mut self) -> GLWECompressed<&mut [u8]>; +} + +impl GLWECompressedToMut for GLWECompressed { + fn to_mut(&mut self) -> GLWECompressed<&mut [u8]> { + GLWECompressed { + seed: self.seed, + base2k: self.base2k, + k: self.k, + rank: self.rank, + data: self.data.to_mut(), + } + } +} diff --git a/poulpy-core/src/layouts/compressed/glwe_automorphism_key.rs b/poulpy-core/src/layouts/compressed/glwe_automorphism_key.rs new file mode 100644 index 0000000..3e0a4b5 --- /dev/null +++ b/poulpy-core/src/layouts/compressed/glwe_automorphism_key.rs @@ -0,0 +1,192 @@ +use poulpy_hal::{ + layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo}, + source::Source, +}; + +use crate::layouts::{ + Base2K, Degree, Dnum, Dsize, GGLWECompressed, GGLWECompressedSeedMut, GGLWECompressedToMut, GGLWECompressedToRef, + GGLWEDecompress, GGLWEInfos, GGLWEToMut, GLWEAutomorphismKey, GLWEDecompress, GLWEInfos, GetGaloisElement, LWEInfos, Rank, + SetGaloisElement, TorusPrecision, +}; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; +use std::fmt; + +#[derive(PartialEq, Eq, Clone)] +pub struct GLWEAutomorphismKeyCompressed { + pub(crate) key: GGLWECompressed, + pub(crate) p: i64, +} + +impl GetGaloisElement for GLWEAutomorphismKeyCompressed { + fn p(&self) -> i64 { + self.p + } +} + +impl LWEInfos for GLWEAutomorphismKeyCompressed { + fn n(&self) -> Degree { + self.key.n() + } + + fn base2k(&self) -> Base2K { + self.key.base2k() + } + + fn k(&self) -> TorusPrecision { + self.key.k() + } + + fn size(&self) -> usize { + self.key.size() + } +} + +impl GLWEInfos for GLWEAutomorphismKeyCompressed { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWEInfos for GLWEAutomorphismKeyCompressed { + fn rank_in(&self) -> Rank { + self.key.rank_in() + } + + fn rank_out(&self) -> Rank { + self.key.rank_out() + } + + fn dsize(&self) -> Dsize { + self.key.dsize() + } + + fn dnum(&self) -> Dnum { + self.key.dnum() + } +} + +impl fmt::Debug for GLWEAutomorphismKeyCompressed { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{self}") + } +} + +impl FillUniform for GLWEAutomorphismKeyCompressed { + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + self.key.fill_uniform(log_bound, source); + } +} + +impl fmt::Display for GLWEAutomorphismKeyCompressed { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "(AutomorphismKeyCompressed: p={}) {}", self.p, self.key) + } +} + +impl GLWEAutomorphismKeyCompressed> { + pub fn alloc_from_infos(infos: &A) -> Self + where + A: GGLWEInfos, + { + Self::alloc( + infos.n(), + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) + } + + pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self { + GLWEAutomorphismKeyCompressed { + key: GGLWECompressed::alloc(n, base2k, k, rank, rank, dnum, dsize), + p: 0, + } + } + + pub fn bytes_of_from_infos(infos: &A) -> usize + where + A: GGLWEInfos, + { + Self::bytes_of( + infos.n(), + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) + } + + pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { + GGLWECompressed::bytes_of(n, base2k, k, rank, dnum, dsize) + } +} + +impl ReaderFrom for GLWEAutomorphismKeyCompressed { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + self.p = reader.read_u64::()? as i64; + self.key.read_from(reader) + } +} + +impl WriterTo for GLWEAutomorphismKeyCompressed { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + writer.write_u64::(self.p as u64)?; + self.key.write_to(writer) + } +} + +pub trait AutomorphismKeyDecompress +where + Self: GGLWEDecompress, +{ + fn decompress_automorphism_key(&self, res: &mut R, other: &O) + where + R: GGLWEToMut + SetGaloisElement, + O: GGLWECompressedToRef + GetGaloisElement, + { + self.decompress_gglwe(res, other); + res.set_p(other.p()); + } +} + +impl AutomorphismKeyDecompress for Module where Self: GLWEDecompress {} + +impl GLWEAutomorphismKey +where + Self: SetGaloisElement, +{ + pub fn decompress(&mut self, module: &M, other: &O) + where + O: GGLWECompressedToRef + GetGaloisElement, + M: AutomorphismKeyDecompress, + { + module.decompress_automorphism_key(self, other); + } +} + +impl GGLWECompressedToRef for GLWEAutomorphismKeyCompressed { + fn to_ref(&self) -> GGLWECompressed<&[u8]> { + self.key.to_ref() + } +} + +impl GGLWECompressedToMut for GLWEAutomorphismKeyCompressed { + fn to_mut(&mut self) -> GGLWECompressed<&mut [u8]> { + self.key.to_mut() + } +} + +impl GGLWECompressedSeedMut for GLWEAutomorphismKeyCompressed { + fn seed_mut(&mut self) -> &mut Vec<[u8; 32]> { + &mut self.key.seed + } +} + +impl SetGaloisElement for GLWEAutomorphismKeyCompressed { + fn set_p(&mut self, p: i64) { + self.p = p + } +} diff --git a/poulpy-core/src/layouts/compressed/glwe_ct.rs b/poulpy-core/src/layouts/compressed/glwe_ct.rs deleted file mode 100644 index 30a3733..0000000 --- a/poulpy-core/src/layouts/compressed/glwe_ct.rs +++ /dev/null @@ -1,178 +0,0 @@ -use poulpy_hal::{ - api::{VecZnxCopy, VecZnxFillUniform}, - layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, VecZnx, WriterTo, ZnxInfos}, - source::Source, -}; - -use crate::layouts::{Base2K, Degree, GLWECiphertext, GLWEInfos, LWEInfos, Rank, TorusPrecision, compressed::Decompress}; -use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; -use std::fmt; - -#[derive(PartialEq, Eq, Clone)] -pub struct GLWECiphertextCompressed { - pub(crate) data: VecZnx, - pub(crate) base2k: Base2K, - pub(crate) k: TorusPrecision, - pub(crate) rank: Rank, - pub(crate) seed: [u8; 32], -} - -impl LWEInfos for GLWECiphertextCompressed { - fn base2k(&self) -> Base2K { - self.base2k - } - - fn k(&self) -> TorusPrecision { - self.k - } - - fn size(&self) -> usize { - self.data.size() - } - - fn n(&self) -> Degree { - Degree(self.data.n() as u32) - } -} -impl GLWEInfos for GLWECiphertextCompressed { - fn rank(&self) -> Rank { - self.rank - } -} - -impl fmt::Debug for GLWECiphertextCompressed { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{self}") - } -} - -impl fmt::Display for GLWECiphertextCompressed { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "GLWECiphertextCompressed: base2k={} k={} rank={} seed={:?}: {}", - self.base2k(), - self.k(), - self.rank(), - self.seed, - self.data - ) - } -} - -impl FillUniform for GLWECiphertextCompressed { - fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { - self.data.fill_uniform(log_bound, source); - } -} - -impl GLWECiphertextCompressed> { - pub fn alloc(infos: &A) -> Self - where - A: GLWEInfos, - { - Self::alloc_with(infos.n(), infos.base2k(), infos.k(), infos.rank()) - } - - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self { - Self { - data: VecZnx::alloc(n.into(), 1, k.0.div_ceil(base2k.0) as usize), - base2k, - k, - rank, - seed: [0u8; 32], - } - } - - pub fn alloc_bytes(infos: &A) -> usize - where - A: GLWEInfos, - { - Self::alloc_bytes_with(infos.n(), infos.base2k(), infos.k()) - } - - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision) -> usize { - VecZnx::alloc_bytes(n.into(), 1, k.0.div_ceil(base2k.0) as usize) - } -} - -impl ReaderFrom for GLWECiphertextCompressed { - fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { - self.k = TorusPrecision(reader.read_u32::()?); - self.base2k = Base2K(reader.read_u32::()?); - self.rank = Rank(reader.read_u32::()?); - reader.read_exact(&mut self.seed)?; - self.data.read_from(reader) - } -} - -impl WriterTo for GLWECiphertextCompressed { - fn write_to(&self, writer: &mut W) -> std::io::Result<()> { - writer.write_u32::(self.k.into())?; - writer.write_u32::(self.base2k.into())?; - writer.write_u32::(self.rank.into())?; - writer.write_all(&self.seed)?; - self.data.write_to(writer) - } -} - -impl Decompress> for GLWECiphertext -where - Module: VecZnxFillUniform + VecZnxCopy, -{ - fn decompress(&mut self, module: &Module, other: &GLWECiphertextCompressed) { - #[cfg(debug_assertions)] - { - assert_eq!( - self.n(), - other.n(), - "invalid receiver: self.n()={} != other.n()={}", - self.n(), - other.n() - ); - assert_eq!( - self.size(), - other.size(), - "invalid receiver: self.size()={} != other.size()={}", - self.size(), - other.size() - ); - assert_eq!( - self.rank(), - other.rank(), - "invalid receiver: self.rank()={} != other.rank()={}", - self.rank(), - other.rank() - ); - } - - let mut source: Source = Source::new(other.seed); - self.decompress_internal(module, other, &mut source); - } -} - -impl GLWECiphertext { - pub(crate) fn decompress_internal( - &mut self, - module: &Module, - other: &GLWECiphertextCompressed, - source: &mut Source, - ) where - DataOther: DataRef, - Module: VecZnxCopy + VecZnxFillUniform, - { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank(), other.rank()); - debug_assert_eq!(self.size(), other.size()); - } - - module.vec_znx_copy(&mut self.data, 0, &other.data, 0); - (1..(other.rank() + 1).into()).for_each(|i| { - module.vec_znx_fill_uniform(other.base2k.into(), &mut self.data, i, source); - }); - - self.base2k = other.base2k; - self.k = other.k; - } -} diff --git a/poulpy-core/src/layouts/compressed/glwe_switching_key.rs b/poulpy-core/src/layouts/compressed/glwe_switching_key.rs new file mode 100644 index 0000000..6da1d32 --- /dev/null +++ b/poulpy-core/src/layouts/compressed/glwe_switching_key.rs @@ -0,0 +1,201 @@ +use poulpy_hal::{ + layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo}, + source::Source, +}; + +use crate::layouts::{ + Base2K, Degree, Dnum, Dsize, GGLWECompressedSeedMut, GGLWEInfos, GGLWEToMut, GLWEInfos, GLWESwitchingKey, + GLWESwitchingKeyDegrees, GLWESwitchingKeyDegreesMut, LWEInfos, Rank, TorusPrecision, + compressed::{GGLWECompressed, GGLWECompressedToMut, GGLWECompressedToRef, GGLWEDecompress}, +}; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; +use std::fmt; + +#[derive(PartialEq, Eq, Clone)] +pub struct GLWESwitchingKeyCompressed { + pub(crate) key: GGLWECompressed, + pub(crate) input_degree: Degree, // Degree of sk_in + pub(crate) output_degree: Degree, // Degree of sk_out +} + +impl GGLWECompressedSeedMut for GLWESwitchingKeyCompressed { + fn seed_mut(&mut self) -> &mut Vec<[u8; 32]> { + &mut self.key.seed + } +} + +impl GLWESwitchingKeyDegrees for GLWESwitchingKeyCompressed { + fn output_degree(&self) -> &Degree { + &self.output_degree + } + + fn input_degree(&self) -> &Degree { + &self.input_degree + } +} + +impl GLWESwitchingKeyDegreesMut for GLWESwitchingKeyCompressed { + fn output_degree(&mut self) -> &mut Degree { + &mut self.output_degree + } + + fn input_degree(&mut self) -> &mut Degree { + &mut self.input_degree + } +} + +impl LWEInfos for GLWESwitchingKeyCompressed { + fn n(&self) -> Degree { + self.key.n() + } + + fn base2k(&self) -> Base2K { + self.key.base2k() + } + + fn k(&self) -> TorusPrecision { + self.key.k() + } + + fn size(&self) -> usize { + self.key.size() + } +} +impl GLWEInfos for GLWESwitchingKeyCompressed { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWEInfos for GLWESwitchingKeyCompressed { + fn rank_in(&self) -> Rank { + self.key.rank_in() + } + + fn rank_out(&self) -> Rank { + self.key.rank_out() + } + + fn dsize(&self) -> Dsize { + self.key.dsize() + } + + fn dnum(&self) -> Dnum { + self.key.dnum() + } +} + +impl fmt::Debug for GLWESwitchingKeyCompressed { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{self}") + } +} + +impl FillUniform for GLWESwitchingKeyCompressed { + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + self.key.fill_uniform(log_bound, source); + } +} + +impl fmt::Display for GLWESwitchingKeyCompressed { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "(GLWESwitchingKeyCompressed: sk_in_n={} sk_out_n={}) {}", + self.input_degree, self.output_degree, self.key.data + ) + } +} + +impl GLWESwitchingKeyCompressed> { + pub fn alloc_from_infos(infos: &A) -> Self + where + A: GGLWEInfos, + { + Self::alloc( + infos.n(), + infos.base2k(), + infos.k(), + infos.rank_in(), + infos.rank_out(), + infos.dnum(), + infos.dsize(), + ) + } + + pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank_in: Rank, rank_out: Rank, dnum: Dnum, dsize: Dsize) -> Self { + GLWESwitchingKeyCompressed { + key: GGLWECompressed::alloc(n, base2k, k, rank_in, rank_out, dnum, dsize), + input_degree: Degree(0), + output_degree: Degree(0), + } + } + + pub fn bytes_of_from_infos(infos: &A) -> usize + where + A: GGLWEInfos, + { + GGLWECompressed::bytes_of_from_infos(infos) + } + + pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum, dsize: Dsize) -> usize +where { + GGLWECompressed::bytes_of(n, base2k, k, rank_in, dnum, dsize) + } +} + +impl ReaderFrom for GLWESwitchingKeyCompressed { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + self.input_degree = Degree(reader.read_u32::()?); + self.output_degree = Degree(reader.read_u32::()?); + self.key.read_from(reader) + } +} + +impl WriterTo for GLWESwitchingKeyCompressed { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + writer.write_u32::(self.input_degree.into())?; + writer.write_u32::(self.output_degree.into())?; + self.key.write_to(writer) + } +} + +pub trait GLWESwitchingKeyDecompress +where + Self: GGLWEDecompress, +{ + fn decompress_glwe_switching_key(&self, res: &mut R, other: &O) + where + R: GGLWEToMut + GLWESwitchingKeyDegreesMut, + O: GGLWECompressedToRef + GLWESwitchingKeyDegrees, + { + self.decompress_gglwe(res, other); + + *res.input_degree() = *other.input_degree(); + *res.output_degree() = *other.output_degree(); + } +} + +impl GLWESwitchingKeyDecompress for Module where Self: GGLWEDecompress {} + +impl GLWESwitchingKey { + pub fn decompress(&mut self, module: &M, other: &O) + where + O: GGLWECompressedToRef + GLWESwitchingKeyDegrees, + M: GLWESwitchingKeyDecompress, + { + module.decompress_glwe_switching_key(self, other); + } +} + +impl GGLWECompressedToMut for GLWESwitchingKeyCompressed { + fn to_mut(&mut self) -> GGLWECompressed<&mut [u8]> { + self.key.to_mut() + } +} + +impl GGLWECompressedToRef for GLWESwitchingKeyCompressed { + fn to_ref(&self) -> GGLWECompressed<&[u8]> { + self.key.to_ref() + } +} diff --git a/poulpy-core/src/layouts/compressed/glwe_tensor_key.rs b/poulpy-core/src/layouts/compressed/glwe_tensor_key.rs new file mode 100644 index 0000000..6939ff2 --- /dev/null +++ b/poulpy-core/src/layouts/compressed/glwe_tensor_key.rs @@ -0,0 +1,246 @@ +use poulpy_hal::{ + layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo}, + source::Source, +}; + +use crate::layouts::{ + Base2K, Degree, Dnum, Dsize, GGLWECompressed, GGLWECompressedToMut, GGLWECompressedToRef, GGLWEDecompress, GGLWEInfos, + GLWEInfos, GLWETensorKey, GLWETensorKeyToMut, LWEInfos, Rank, TorusPrecision, +}; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; +use std::fmt; + +#[derive(PartialEq, Eq, Clone)] +pub struct GLWETensorKeyCompressed { + pub(crate) keys: Vec>, +} + +impl LWEInfos for GLWETensorKeyCompressed { + fn n(&self) -> Degree { + self.keys[0].n() + } + + fn base2k(&self) -> Base2K { + self.keys[0].base2k() + } + + fn k(&self) -> TorusPrecision { + self.keys[0].k() + } + fn size(&self) -> usize { + self.keys[0].size() + } +} +impl GLWEInfos for GLWETensorKeyCompressed { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWEInfos for GLWETensorKeyCompressed { + fn rank_in(&self) -> Rank { + self.rank_out() + } + + fn rank_out(&self) -> Rank { + self.keys[0].rank_out() + } + + fn dsize(&self) -> Dsize { + self.keys[0].dsize() + } + + fn dnum(&self) -> Dnum { + self.keys[0].dnum() + } +} + +impl fmt::Debug for GLWETensorKeyCompressed { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{self}") + } +} + +impl FillUniform for GLWETensorKeyCompressed { + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + self.keys + .iter_mut() + .for_each(|key: &mut GGLWECompressed| key.fill_uniform(log_bound, source)) + } +} + +impl fmt::Display for GLWETensorKeyCompressed { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "(GLWETensorKeyCompressed)",)?; + for (i, key) in self.keys.iter().enumerate() { + write!(f, "{i}: {key}")?; + } + Ok(()) + } +} + +impl GLWETensorKeyCompressed> { + pub fn alloc_from_infos(infos: &A) -> Self + where + A: GGLWEInfos, + { + Self::alloc( + infos.n(), + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) + } + + pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self { + let pairs: u32 = (((rank.as_u32() + 1) * rank.as_u32()) >> 1).max(1); + GLWETensorKeyCompressed { + keys: (0..pairs) + .map(|_| GGLWECompressed::alloc(n, base2k, k, Rank(1), rank, dnum, dsize)) + .collect(), + } + } + + pub fn bytes_of_from_infos(infos: &A) -> usize + where + A: GGLWEInfos, + { + Self::bytes_of( + infos.n(), + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) + } + + pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { + let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize; + pairs * GGLWECompressed::bytes_of(n, base2k, k, Rank(1), dnum, dsize) + } +} + +impl ReaderFrom for GLWETensorKeyCompressed { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + let len: usize = reader.read_u64::()? as usize; + if self.keys.len() != len { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("self.keys.len()={} != read len={}", self.keys.len(), len), + )); + } + for key in &mut self.keys { + key.read_from(reader)?; + } + Ok(()) + } +} + +impl WriterTo for GLWETensorKeyCompressed { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + writer.write_u64::(self.keys.len() as u64)?; + for key in &self.keys { + key.write_to(writer)?; + } + Ok(()) + } +} + +pub trait GLWETensorKeyCompressedAtRef { + fn at(&self, i: usize, j: usize) -> &GGLWECompressed; +} + +impl GLWETensorKeyCompressedAtRef for GLWETensorKeyCompressed { + fn at(&self, mut i: usize, mut j: usize) -> &GGLWECompressed { + if i > j { + std::mem::swap(&mut i, &mut j); + }; + let rank: usize = self.rank_out().into(); + &self.keys[i * rank + j - (i * (i + 1) / 2)] + } +} + +pub trait GLWETensorKeyCompressedAtMut { + fn at_mut(&mut self, i: usize, j: usize) -> &mut GGLWECompressed; +} + +impl GLWETensorKeyCompressedAtMut for GLWETensorKeyCompressed { + fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GGLWECompressed { + if i > j { + std::mem::swap(&mut i, &mut j); + }; + let rank: usize = self.rank_out().into(); + &mut self.keys[i * rank + j - (i * (i + 1) / 2)] + } +} + +pub trait GLWETensorKeyDecompress +where + Self: GGLWEDecompress, +{ + fn decompress_tensor_key(&self, res: &mut R, other: &O) + where + R: GLWETensorKeyToMut, + O: GLWETensorKeyCompressedToRef, + { + let res: &mut GLWETensorKey<&mut [u8]> = &mut res.to_mut(); + let other: &GLWETensorKeyCompressed<&[u8]> = &other.to_ref(); + + assert_eq!( + res.keys.len(), + other.keys.len(), + "invalid receiver: res.keys.len()={} != other.keys.len()={}", + res.keys.len(), + other.keys.len() + ); + + for (a, b) in res.keys.iter_mut().zip(other.keys.iter()) { + self.decompress_gglwe(a, b); + } + } +} + +impl GLWETensorKeyDecompress for Module where Self: GGLWEDecompress {} + +impl GLWETensorKey { + pub fn decompress(&mut self, module: &M, other: &O) + where + O: GLWETensorKeyCompressedToRef, + M: GLWETensorKeyDecompress, + { + module.decompress_tensor_key(self, other); + } +} + +pub trait GLWETensorKeyCompressedToMut { + fn to_mut(&mut self) -> GLWETensorKeyCompressed<&mut [u8]>; +} + +impl GLWETensorKeyCompressedToMut for GLWETensorKeyCompressed +where + GGLWECompressed: GGLWECompressedToMut, +{ + fn to_mut(&mut self) -> GLWETensorKeyCompressed<&mut [u8]> { + GLWETensorKeyCompressed { + keys: self.keys.iter_mut().map(|c| c.to_mut()).collect(), + } + } +} + +pub trait GLWETensorKeyCompressedToRef { + fn to_ref(&self) -> GLWETensorKeyCompressed<&[u8]>; +} + +impl GLWETensorKeyCompressedToRef for GLWETensorKeyCompressed +where + GGLWECompressed: GGLWECompressedToRef, +{ + fn to_ref(&self) -> GLWETensorKeyCompressed<&[u8]> { + GLWETensorKeyCompressed { + keys: self.keys.iter().map(|c| c.to_ref()).collect(), + } + } +} diff --git a/poulpy-core/src/layouts/compressed/glwe_to_lwe_ksk.rs b/poulpy-core/src/layouts/compressed/glwe_to_lwe_switching_key.rs similarity index 54% rename from poulpy-core/src/layouts/compressed/glwe_to_lwe_ksk.rs rename to poulpy-core/src/layouts/compressed/glwe_to_lwe_switching_key.rs index 63933e8..6ac325c 100644 --- a/poulpy-core/src/layouts/compressed/glwe_to_lwe_ksk.rs +++ b/poulpy-core/src/layouts/compressed/glwe_to_lwe_switching_key.rs @@ -1,16 +1,18 @@ use std::fmt; use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo}, + layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo}, source::Source, }; use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision, compressed::GGLWESwitchingKeyCompressed, + Base2K, Degree, Dnum, Dsize, GGLWECompressed, GGLWECompressedToMut, GGLWECompressedToRef, GGLWEInfos, GGLWEToMut, GLWEInfos, + GLWESwitchingKeyDegrees, GLWESwitchingKeyDegreesMut, GLWEToLWESwitchingKey, LWEInfos, Rank, TorusPrecision, + compressed::{GLWESwitchingKeyCompressed, GLWESwitchingKeyDecompress}, }; #[derive(PartialEq, Eq, Clone)] -pub struct GLWEToLWESwitchingKeyCompressed(pub(crate) GGLWESwitchingKeyCompressed); +pub struct GLWEToLWESwitchingKeyCompressed(pub(crate) GLWESwitchingKeyCompressed); impl LWEInfos for GLWEToLWESwitchingKeyCompressed { fn base2k(&self) -> Base2K { @@ -84,25 +86,31 @@ impl WriterTo for GLWEToLWESwitchingKeyCompressed { } impl GLWEToLWESwitchingKeyCompressed> { - pub fn alloc(infos: &A) -> Self + pub fn alloc_from_infos(infos: &A) -> Self where A: GGLWEInfos, { - debug_assert_eq!( + assert_eq!( infos.rank_out().0, 1, "rank_out > 1 is unsupported for GLWEToLWESwitchingKeyCompressed" ); - debug_assert_eq!( + assert_eq!( infos.dsize().0, 1, "dsize > 1 is unsupported for GLWEToLWESwitchingKeyCompressed" ); - Self(GGLWESwitchingKeyCompressed::alloc(infos)) + Self::alloc( + infos.n(), + infos.base2k(), + infos.k(), + infos.rank_in(), + infos.dnum(), + ) } - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> Self { - Self(GGLWESwitchingKeyCompressed::alloc_with( + pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> Self { + GLWEToLWESwitchingKeyCompressed(GLWESwitchingKeyCompressed::alloc( n, base2k, k, @@ -113,24 +121,61 @@ impl GLWEToLWESwitchingKeyCompressed> { )) } - pub fn alloc_bytes(infos: &A) -> usize + pub fn bytes_of_from_infos(infos: &A) -> usize where A: GGLWEInfos, { - debug_assert_eq!( + assert_eq!( infos.rank_out().0, 1, "rank_out > 1 is unsupported for GLWEToLWESwitchingKeyCompressed" ); - debug_assert_eq!( + assert_eq!( infos.dsize().0, 1, "dsize > 1 is unsupported for GLWEToLWESwitchingKeyCompressed" ); - GGLWESwitchingKeyCompressed::alloc_bytes(infos) + GLWESwitchingKeyCompressed::bytes_of_from_infos(infos) } - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, dnum: Dnum, rank_in: Rank) -> usize { - GGLWESwitchingKeyCompressed::alloc_bytes_with(n, base2k, k, rank_in, dnum, Dsize(1)) + pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, dnum: Dnum, rank_in: Rank) -> usize { + GLWESwitchingKeyCompressed::bytes_of(n, base2k, k, rank_in, dnum, Dsize(1)) + } +} + +pub trait GLWEToLWESwitchingKeyDecompress +where + Self: GLWESwitchingKeyDecompress, +{ + fn decompress_glwe_to_lwe_switching_key(&self, res: &mut R, other: &O) + where + R: GGLWEToMut + GLWESwitchingKeyDegreesMut, + O: GGLWECompressedToRef + GLWESwitchingKeyDegrees, + { + self.decompress_glwe_switching_key(res, other); + } +} + +impl GLWEToLWESwitchingKeyDecompress for Module where Self: GLWESwitchingKeyDecompress {} + +impl GLWEToLWESwitchingKey { + pub fn decompress(&mut self, module: &M, other: &O) + where + O: GGLWECompressedToRef + GLWESwitchingKeyDegrees, + M: GLWEToLWESwitchingKeyDecompress, + { + module.decompress_glwe_to_lwe_switching_key(self, other); + } +} + +impl GGLWECompressedToRef for GLWEToLWESwitchingKeyCompressed { + fn to_ref(&self) -> GGLWECompressed<&[u8]> { + self.0.to_ref() + } +} + +impl GGLWECompressedToMut for GLWEToLWESwitchingKeyCompressed { + fn to_mut(&mut self) -> GGLWECompressed<&mut [u8]> { + self.0.to_mut() } } diff --git a/poulpy-core/src/layouts/compressed/lwe.rs b/poulpy-core/src/layouts/compressed/lwe.rs new file mode 100644 index 0000000..ce4c000 --- /dev/null +++ b/poulpy-core/src/layouts/compressed/lwe.rs @@ -0,0 +1,182 @@ +use std::fmt; + +use poulpy_hal::{ + api::ZnFillUniform, + layouts::{ + Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo, Zn, ZnToMut, ZnToRef, ZnxInfos, ZnxView, + ZnxViewMut, + }, + source::Source, +}; + +use crate::layouts::{Base2K, Degree, LWE, LWEInfos, LWEToMut, TorusPrecision}; + +#[derive(PartialEq, Eq, Clone)] +pub struct LWECompressed { + pub(crate) data: Zn, + pub(crate) k: TorusPrecision, + pub(crate) base2k: Base2K, + pub(crate) seed: [u8; 32], +} + +impl LWEInfos for LWECompressed { + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn n(&self) -> Degree { + Degree(self.data.n() as u32) + } + + fn size(&self) -> usize { + self.data.size() + } +} + +impl fmt::Debug for LWECompressed { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{self}") + } +} + +impl fmt::Display for LWECompressed { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "LWECompressed: base2k={} k={} seed={:?}: {}", + self.base2k(), + self.k(), + self.seed, + self.data + ) + } +} + +impl FillUniform for LWECompressed { + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + self.data.fill_uniform(log_bound, source); + } +} + +impl LWECompressed> { + pub fn alloc_from_infos(infos: &A) -> Self + where + A: LWEInfos, + { + Self::alloc(infos.base2k(), infos.k()) + } + + pub fn alloc(base2k: Base2K, k: TorusPrecision) -> Self { + LWECompressed { + data: Zn::alloc(1, 1, k.0.div_ceil(base2k.0) as usize), + k, + base2k, + seed: [0u8; 32], + } + } + + pub fn bytes_of_from_infos(infos: &A) -> usize + where + A: LWEInfos, + { + Self::bytes_of(infos.base2k(), infos.k()) + } + + pub fn bytes_of(base2k: Base2K, k: TorusPrecision) -> usize { + Zn::bytes_of(1, 1, k.0.div_ceil(base2k.0) as usize) + } +} + +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; + +impl ReaderFrom for LWECompressed { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + self.k = TorusPrecision(reader.read_u32::()?); + self.base2k = Base2K(reader.read_u32::()?); + reader.read_exact(&mut self.seed)?; + self.data.read_from(reader) + } +} + +impl WriterTo for LWECompressed { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + writer.write_u32::(self.k.into())?; + writer.write_u32::(self.base2k.into())?; + writer.write_all(&self.seed)?; + self.data.write_to(writer) + } +} + +pub trait LWEDecompress +where + Self: ZnFillUniform, +{ + fn decompress_lwe(&self, res: &mut R, other: &O) + where + R: LWEToMut, + O: LWECompressedToRef, + { + let res: &mut LWE<&mut [u8]> = &mut res.to_mut(); + let other: &LWECompressed<&[u8]> = &other.to_ref(); + + assert_eq!(res.lwe_layout(), other.lwe_layout()); + + let mut source: Source = Source::new(other.seed); + self.zn_fill_uniform( + res.n().into(), + other.base2k().into(), + &mut res.data, + 0, + &mut source, + ); + for i in 0..res.size() { + res.data.at_mut(0, i)[0] = other.data.at(0, i)[0]; + } + } +} + +impl LWEDecompress for Module where Self: ZnFillUniform {} + +impl LWE { + pub fn decompress(&mut self, module: &M, other: &O) + where + O: LWECompressedToRef, + M: LWEDecompress, + { + module.decompress_lwe(self, other); + } +} + +pub trait LWECompressedToRef { + fn to_ref(&self) -> LWECompressed<&[u8]>; +} + +impl LWECompressedToRef for LWECompressed { + fn to_ref(&self) -> LWECompressed<&[u8]> { + LWECompressed { + k: self.k, + base2k: self.base2k, + seed: self.seed, + data: self.data.to_ref(), + } + } +} + +pub trait LWECompressedToMut { + fn to_mut(&mut self) -> LWECompressed<&mut [u8]>; +} + +impl LWECompressedToMut for LWECompressed { + fn to_mut(&mut self) -> LWECompressed<&mut [u8]> { + LWECompressed { + k: self.k, + base2k: self.base2k, + seed: self.seed, + data: self.data.to_mut(), + } + } +} diff --git a/poulpy-core/src/layouts/compressed/lwe_ct.rs b/poulpy-core/src/layouts/compressed/lwe_ct.rs deleted file mode 100644 index e11b3f3..0000000 --- a/poulpy-core/src/layouts/compressed/lwe_ct.rs +++ /dev/null @@ -1,129 +0,0 @@ -use std::fmt; - -use poulpy_hal::{ - api::ZnFillUniform, - layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo, Zn, ZnxInfos, ZnxView, ZnxViewMut}, - source::Source, -}; - -use crate::layouts::{Base2K, Degree, LWECiphertext, LWEInfos, TorusPrecision, compressed::Decompress}; - -#[derive(PartialEq, Eq, Clone)] -pub struct LWECiphertextCompressed { - pub(crate) data: Zn, - pub(crate) k: TorusPrecision, - pub(crate) base2k: Base2K, - pub(crate) seed: [u8; 32], -} - -impl LWEInfos for LWECiphertextCompressed { - fn base2k(&self) -> Base2K { - self.base2k - } - - fn k(&self) -> TorusPrecision { - self.k - } - - fn n(&self) -> Degree { - Degree(self.data.n() as u32) - } - - fn size(&self) -> usize { - self.data.size() - } -} - -impl fmt::Debug for LWECiphertextCompressed { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{self}") - } -} - -impl fmt::Display for LWECiphertextCompressed { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "LWECiphertextCompressed: base2k={} k={} seed={:?}: {}", - self.base2k(), - self.k(), - self.seed, - self.data - ) - } -} - -impl FillUniform for LWECiphertextCompressed { - fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { - self.data.fill_uniform(log_bound, source); - } -} - -impl LWECiphertextCompressed> { - pub fn alloc(infos: &A) -> Self - where - A: LWEInfos, - { - Self::alloc_with(infos.base2k(), infos.k()) - } - - pub fn alloc_with(base2k: Base2K, k: TorusPrecision) -> Self { - Self { - data: Zn::alloc(1, 1, k.0.div_ceil(base2k.0) as usize), - k, - base2k, - seed: [0u8; 32], - } - } - - pub fn alloc_bytes(infos: &A) -> usize - where - A: LWEInfos, - { - Self::alloc_bytes_with(infos.base2k(), infos.k()) - } - - pub fn alloc_bytes_with(base2k: Base2K, k: TorusPrecision) -> usize { - Zn::alloc_bytes(1, 1, k.0.div_ceil(base2k.0) as usize) - } -} - -use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; - -impl ReaderFrom for LWECiphertextCompressed { - fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { - self.k = TorusPrecision(reader.read_u32::()?); - self.base2k = Base2K(reader.read_u32::()?); - reader.read_exact(&mut self.seed)?; - self.data.read_from(reader) - } -} - -impl WriterTo for LWECiphertextCompressed { - fn write_to(&self, writer: &mut W) -> std::io::Result<()> { - writer.write_u32::(self.k.into())?; - writer.write_u32::(self.base2k.into())?; - writer.write_all(&self.seed)?; - self.data.write_to(writer) - } -} - -impl Decompress> for LWECiphertext -where - Module: ZnFillUniform, -{ - fn decompress(&mut self, module: &Module, other: &LWECiphertextCompressed) { - debug_assert_eq!(self.size(), other.size()); - let mut source: Source = Source::new(other.seed); - module.zn_fill_uniform( - self.n().into(), - other.base2k().into(), - &mut self.data, - 0, - &mut source, - ); - (0..self.size()).for_each(|i| { - self.data.at_mut(0, i)[0] = other.data.at(0, i)[0]; - }); - } -} diff --git a/poulpy-core/src/layouts/compressed/lwe_ksk.rs b/poulpy-core/src/layouts/compressed/lwe_switching_key.rs similarity index 53% rename from poulpy-core/src/layouts/compressed/lwe_ksk.rs rename to poulpy-core/src/layouts/compressed/lwe_switching_key.rs index 480707b..764d423 100644 --- a/poulpy-core/src/layouts/compressed/lwe_ksk.rs +++ b/poulpy-core/src/layouts/compressed/lwe_switching_key.rs @@ -1,17 +1,17 @@ use poulpy_hal::{ - api::{VecZnxCopy, VecZnxFillUniform}, layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo}, source::Source, }; use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEInfos, GLWEInfos, LWEInfos, LWESwitchingKey, Rank, TorusPrecision, - compressed::{Decompress, GGLWESwitchingKeyCompressed}, + Base2K, Degree, Dnum, Dsize, GGLWECompressed, GGLWECompressedToMut, GGLWECompressedToRef, GGLWEInfos, GGLWEToMut, GLWEInfos, + GLWESwitchingKeyDegrees, GLWESwitchingKeyDegreesMut, LWEInfos, LWESwitchingKey, Rank, TorusPrecision, + compressed::{GLWESwitchingKeyCompressed, GLWESwitchingKeyDecompress}, }; use std::fmt; #[derive(PartialEq, Eq, Clone)] -pub struct LWESwitchingKeyCompressed(pub(crate) GGLWESwitchingKeyCompressed); +pub struct LWESwitchingKeyCompressed(pub(crate) GLWESwitchingKeyCompressed); impl LWEInfos for LWESwitchingKeyCompressed { fn base2k(&self) -> Base2K { @@ -84,30 +84,30 @@ impl WriterTo for LWESwitchingKeyCompressed { } impl LWESwitchingKeyCompressed> { - pub fn alloc(infos: &A) -> Self + pub fn alloc_from_infos(infos: &A) -> Self where A: GGLWEInfos, { - debug_assert_eq!( + assert_eq!( infos.dsize().0, 1, "dsize > 1 is not supported for LWESwitchingKeyCompressed" ); - debug_assert_eq!( + assert_eq!( infos.rank_in().0, 1, "rank_in > 1 is not supported for LWESwitchingKeyCompressed" ); - debug_assert_eq!( + assert_eq!( infos.rank_out().0, 1, "rank_out > 1 is not supported for LWESwitchingKeyCompressed" ); - Self(GGLWESwitchingKeyCompressed::alloc(infos)) + Self::alloc(infos.n(), infos.base2k(), infos.k(), infos.dnum()) } - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> Self { - Self(GGLWESwitchingKeyCompressed::alloc_with( + pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> Self { + LWESwitchingKeyCompressed(GLWESwitchingKeyCompressed::alloc( n, base2k, k, @@ -118,38 +118,66 @@ impl LWESwitchingKeyCompressed> { )) } - pub fn alloc_bytes(infos: &A) -> usize + pub fn bytes_of_from_infos(infos: &A) -> usize where A: GGLWEInfos, { - debug_assert_eq!( + assert_eq!( infos.dsize().0, 1, - "dsize > 1 is not supported for LWESwitchingKey" + "dsize > 1 is not supported for LWESwitchingKeyCompressed" ); - debug_assert_eq!( + assert_eq!( infos.rank_in().0, 1, - "rank_in > 1 is not supported for LWESwitchingKey" + "rank_in > 1 is not supported for LWESwitchingKeyCompressed" ); - debug_assert_eq!( + assert_eq!( infos.rank_out().0, 1, - "rank_out > 1 is not supported for LWESwitchingKey" + "rank_out > 1 is not supported for LWESwitchingKeyCompressed" ); - GGLWESwitchingKeyCompressed::alloc_bytes(infos) + GLWESwitchingKeyCompressed::bytes_of_from_infos(infos) } - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> usize { - GGLWESwitchingKeyCompressed::alloc_bytes_with(n, base2k, k, Rank(1), dnum, Dsize(1)) + pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> usize { + GLWESwitchingKeyCompressed::bytes_of(n, base2k, k, Rank(1), dnum, Dsize(1)) } } -impl Decompress> for LWESwitchingKey +pub trait LWESwitchingKeyDecompress where - Module: VecZnxFillUniform + VecZnxCopy, + Self: GLWESwitchingKeyDecompress, { - fn decompress(&mut self, module: &Module, other: &LWESwitchingKeyCompressed) { - self.0.decompress(module, &other.0); + fn decompress_lwe_switching_key(&self, res: &mut R, other: &O) + where + R: GGLWEToMut + GLWESwitchingKeyDegreesMut, + O: GGLWECompressedToRef + GLWESwitchingKeyDegrees, + { + self.decompress_glwe_switching_key(res, other); + } +} + +impl LWESwitchingKeyDecompress for Module where Self: GLWESwitchingKeyDecompress {} + +impl LWESwitchingKey { + pub fn decompress(&mut self, module: &M, other: &O) + where + O: GGLWECompressedToRef + GLWESwitchingKeyDegrees, + M: LWESwitchingKeyDecompress, + { + module.decompress_lwe_switching_key(self, other); + } +} + +impl GGLWECompressedToRef for LWESwitchingKeyCompressed { + fn to_ref(&self) -> GGLWECompressed<&[u8]> { + self.0.to_ref() + } +} + +impl GGLWECompressedToMut for LWESwitchingKeyCompressed { + fn to_mut(&mut self) -> GGLWECompressed<&mut [u8]> { + self.0.to_mut() } } diff --git a/poulpy-core/src/layouts/compressed/lwe_to_glwe_ksk.rs b/poulpy-core/src/layouts/compressed/lwe_to_glwe_switching_key.rs similarity index 55% rename from poulpy-core/src/layouts/compressed/lwe_to_glwe_ksk.rs rename to poulpy-core/src/layouts/compressed/lwe_to_glwe_switching_key.rs index 86c353b..7a724c9 100644 --- a/poulpy-core/src/layouts/compressed/lwe_to_glwe_ksk.rs +++ b/poulpy-core/src/layouts/compressed/lwe_to_glwe_switching_key.rs @@ -1,17 +1,17 @@ use poulpy_hal::{ - api::{VecZnxCopy, VecZnxFillUniform}, layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo}, source::Source, }; use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEInfos, GLWEInfos, LWEInfos, LWEToGLWESwitchingKey, Rank, TorusPrecision, - compressed::{Decompress, GGLWESwitchingKeyCompressed}, + Base2K, Degree, Dnum, Dsize, GGLWECompressed, GGLWECompressedToMut, GGLWECompressedToRef, GGLWEInfos, GGLWEToMut, GLWEInfos, + GLWESwitchingKeyDegrees, GLWESwitchingKeyDegreesMut, LWEInfos, LWEToGLWESwitchingKey, Rank, TorusPrecision, + compressed::{GLWESwitchingKeyCompressed, GLWESwitchingKeyDecompress}, }; use std::fmt; #[derive(PartialEq, Eq, Clone)] -pub struct LWEToGLWESwitchingKeyCompressed(pub(crate) GGLWESwitchingKeyCompressed); +pub struct LWEToGLWESwitchingKeyCompressed(pub(crate) GLWESwitchingKeyCompressed); impl LWEInfos for LWEToGLWESwitchingKeyCompressed { fn n(&self) -> Degree { @@ -84,25 +84,31 @@ impl WriterTo for LWEToGLWESwitchingKeyCompressed { } impl LWEToGLWESwitchingKeyCompressed> { - pub fn alloc(infos: &A) -> Self + pub fn alloc_from_infos(infos: &A) -> Self where A: GGLWEInfos, { - debug_assert_eq!( + assert_eq!( infos.dsize().0, 1, "dsize > 1 is not supported for LWEToGLWESwitchingKeyCompressed" ); - debug_assert_eq!( + assert_eq!( infos.rank_in().0, 1, "rank_in > 1 is not supported for LWEToGLWESwitchingKeyCompressed" ); - Self(GGLWESwitchingKeyCompressed::alloc(infos)) + Self::alloc( + infos.n(), + infos.base2k(), + infos.k(), + infos.rank_out(), + infos.dnum(), + ) } - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self { - Self(GGLWESwitchingKeyCompressed::alloc_with( + pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self { + LWEToGLWESwitchingKeyCompressed(GLWESwitchingKeyCompressed::alloc( n, base2k, k, @@ -113,33 +119,61 @@ impl LWEToGLWESwitchingKeyCompressed> { )) } - pub fn alloc_bytes(infos: &A) -> usize + pub fn bytes_of_from_infos(infos: &A) -> usize where A: GGLWEInfos, { - debug_assert_eq!( - infos.rank_in().0, - 1, - "rank_in > 1 is not supported for LWEToGLWESwitchingKey" - ); - debug_assert_eq!( + assert_eq!( infos.dsize().0, 1, - "dsize > 1 is not supported for LWEToGLWESwitchingKey" + "dsize > 1 is not supported for LWEToGLWESwitchingKeyCompressed" ); - GGLWESwitchingKeyCompressed::alloc_bytes(infos) + assert_eq!( + infos.rank_in().0, + 1, + "rank_in > 1 is not supported for LWEToGLWESwitchingKeyCompressed" + ); + GLWESwitchingKeyCompressed::bytes_of_from_infos(infos) } - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> usize { - GGLWESwitchingKeyCompressed::alloc_bytes_with(n, base2k, k, Rank(1), dnum, Dsize(1)) + pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> usize { + GLWESwitchingKeyCompressed::bytes_of(n, base2k, k, Rank(1), dnum, Dsize(1)) } } -impl Decompress> for LWEToGLWESwitchingKey +pub trait LWEToGLWESwitchingKeyDecompress where - Module: VecZnxFillUniform + VecZnxCopy, + Self: GLWESwitchingKeyDecompress, { - fn decompress(&mut self, module: &Module, other: &LWEToGLWESwitchingKeyCompressed) { - self.0.decompress(module, &other.0); + fn decompress_lwe_to_glwe_switching_key(&self, res: &mut R, other: &O) + where + R: GGLWEToMut + GLWESwitchingKeyDegreesMut, + O: GGLWECompressedToRef + GLWESwitchingKeyDegrees, + { + self.decompress_glwe_switching_key(res, other); + } +} + +impl LWEToGLWESwitchingKeyDecompress for Module where Self: GLWESwitchingKeyDecompress {} + +impl LWEToGLWESwitchingKey { + pub fn decompress(&mut self, module: &M, other: &O) + where + O: GGLWECompressedToRef + GLWESwitchingKeyDegrees, + M: LWEToGLWESwitchingKeyDecompress, + { + module.decompress_lwe_to_glwe_switching_key(self, other); + } +} + +impl GGLWECompressedToRef for LWEToGLWESwitchingKeyCompressed { + fn to_ref(&self) -> GGLWECompressed<&[u8]> { + self.0.to_ref() + } +} + +impl GGLWECompressedToMut for LWEToGLWESwitchingKeyCompressed { + fn to_mut(&mut self) -> GGLWECompressed<&mut [u8]> { + self.0.to_mut() } } diff --git a/poulpy-core/src/layouts/compressed/mod.rs b/poulpy-core/src/layouts/compressed/mod.rs index c1fcacf..b85d48d 100644 --- a/poulpy-core/src/layouts/compressed/mod.rs +++ b/poulpy-core/src/layouts/compressed/mod.rs @@ -1,27 +1,21 @@ -mod gglwe_atk; -mod gglwe_ct; -mod gglwe_ksk; -mod gglwe_tsk; -mod ggsw_ct; -mod glwe_ct; -mod glwe_to_lwe_ksk; -mod lwe_ct; -mod lwe_ksk; -mod lwe_to_glwe_ksk; +mod gglwe; +mod ggsw; +mod glwe; +mod glwe_automorphism_key; +mod glwe_switching_key; +mod glwe_tensor_key; +mod glwe_to_lwe_switching_key; +mod lwe; +mod lwe_switching_key; +mod lwe_to_glwe_switching_key; -pub use gglwe_atk::*; -pub use gglwe_ct::*; -pub use gglwe_ksk::*; -pub use gglwe_tsk::*; -pub use ggsw_ct::*; -pub use glwe_ct::*; -pub use glwe_to_lwe_ksk::*; -pub use lwe_ct::*; -pub use lwe_ksk::*; -pub use lwe_to_glwe_ksk::*; - -use poulpy_hal::layouts::{Backend, Module}; - -pub trait Decompress { - fn decompress(&mut self, module: &Module, other: &C); -} +pub use gglwe::*; +pub use ggsw::*; +pub use glwe::*; +pub use glwe_automorphism_key::*; +pub use glwe_switching_key::*; +pub use glwe_tensor_key::*; +pub use glwe_to_lwe_switching_key::*; +pub use lwe::*; +pub use lwe_switching_key::*; +pub use lwe_to_glwe_switching_key::*; diff --git a/poulpy-core/src/layouts/gglwe_ct.rs b/poulpy-core/src/layouts/gglwe.rs similarity index 53% rename from poulpy-core/src/layouts/gglwe_ct.rs rename to poulpy-core/src/layouts/gglwe.rs index ca8236c..a491186 100644 --- a/poulpy-core/src/layouts/gglwe_ct.rs +++ b/poulpy-core/src/layouts/gglwe.rs @@ -1,9 +1,9 @@ use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, FillUniform, MatZnx, ReaderFrom, WriterTo, ZnxInfos}, + layouts::{Data, DataMut, DataRef, FillUniform, MatZnx, MatZnxToMut, MatZnxToRef, ReaderFrom, WriterTo, ZnxInfos}, source::Source, }; -use crate::layouts::{Base2K, BuildError, Degree, Dnum, Dsize, GLWECiphertext, GLWEInfos, LWEInfos, Rank, TorusPrecision}; +use crate::layouts::{Base2K, Degree, Dnum, Dsize, GLWE, GLWEInfos, LWEInfos, Rank, TorusPrecision}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::fmt; @@ -16,8 +16,8 @@ where fn dsize(&self) -> Dsize; fn rank_in(&self) -> Rank; fn rank_out(&self) -> Rank; - fn layout(&self) -> GGLWECiphertextLayout { - GGLWECiphertextLayout { + fn gglwe_layout(&self) -> GGLWELayout { + GGLWELayout { n: self.n(), base2k: self.base2k(), k: self.k(), @@ -29,8 +29,12 @@ where } } +pub trait SetGGLWEInfos { + fn set_dsize(&mut self, dsize: usize); +} + #[derive(PartialEq, Eq, Copy, Clone, Debug)] -pub struct GGLWECiphertextLayout { +pub struct GGLWELayout { pub n: Degree, pub base2k: Base2K, pub k: TorusPrecision, @@ -40,7 +44,7 @@ pub struct GGLWECiphertextLayout { pub dsize: Dsize, } -impl LWEInfos for GGLWECiphertextLayout { +impl LWEInfos for GGLWELayout { fn base2k(&self) -> Base2K { self.base2k } @@ -54,13 +58,13 @@ impl LWEInfos for GGLWECiphertextLayout { } } -impl GLWEInfos for GGLWECiphertextLayout { +impl GLWEInfos for GGLWELayout { fn rank(&self) -> Rank { self.rank_out } } -impl GGLWEInfos for GGLWECiphertextLayout { +impl GGLWEInfos for GGLWELayout { fn rank_in(&self) -> Rank { self.rank_in } @@ -79,14 +83,14 @@ impl GGLWEInfos for GGLWECiphertextLayout { } #[derive(PartialEq, Eq, Clone)] -pub struct GGLWECiphertext { +pub struct GGLWE { pub(crate) data: MatZnx, pub(crate) k: TorusPrecision, pub(crate) base2k: Base2K, pub(crate) dsize: Dsize, } -impl LWEInfos for GGLWECiphertext { +impl LWEInfos for GGLWE { fn base2k(&self) -> Base2K { self.base2k } @@ -104,13 +108,13 @@ impl LWEInfos for GGLWECiphertext { } } -impl GLWEInfos for GGLWECiphertext { +impl GLWEInfos for GGLWE { fn rank(&self) -> Rank { self.rank_out() } } -impl GGLWEInfos for GGLWECiphertext { +impl GGLWEInfos for GGLWE { fn rank_in(&self) -> Rank { Rank(self.data.cols_in() as u32) } @@ -128,136 +132,35 @@ impl GGLWEInfos for GGLWECiphertext { } } -pub struct GGLWECiphertextBuilder { - data: Option>, - base2k: Option, - k: Option, - dsize: Option, -} - -impl GGLWECiphertext { - #[inline] - pub fn builder() -> GGLWECiphertextBuilder { - GGLWECiphertextBuilder { - data: None, - base2k: None, - k: None, - dsize: None, - } - } -} - -impl GGLWECiphertextBuilder> { - #[inline] - pub fn layout(mut self, infos: &A) -> Self - where - A: GGLWEInfos, - { - self.data = Some(MatZnx::alloc( - infos.n().into(), - infos.dnum().into(), - infos.rank_in().into(), - (infos.rank_out() + 1).into(), - infos.size(), - )); - self.base2k = Some(infos.base2k()); - self.k = Some(infos.k()); - self.dsize = Some(infos.dsize()); - self - } -} - -impl GGLWECiphertextBuilder { - #[inline] - pub fn data(mut self, data: MatZnx) -> Self { - self.data = Some(data); - self - } - #[inline] - pub fn base2k(mut self, base2k: Base2K) -> Self { - self.base2k = Some(base2k); - self - } - #[inline] - pub fn k(mut self, k: TorusPrecision) -> Self { - self.k = Some(k); - self - } - - #[inline] - pub fn dsize(mut self, dsize: Dsize) -> Self { - self.dsize = Some(dsize); - self - } - - pub fn build(self) -> Result, BuildError> { - let data: MatZnx = self.data.ok_or(BuildError::MissingData)?; - let base2k: Base2K = self.base2k.ok_or(BuildError::MissingBase2K)?; - let k: TorusPrecision = self.k.ok_or(BuildError::MissingK)?; - let dsize: Dsize = self.dsize.ok_or(BuildError::MissingDigits)?; - - if base2k == 0_u32 { - return Err(BuildError::ZeroBase2K); - } - - if dsize == 0_u32 { - return Err(BuildError::ZeroBase2K); - } - - if k == 0_u32 { - return Err(BuildError::ZeroTorusPrecision); - } - - if data.n() == 0 { - return Err(BuildError::ZeroDegree); - } - - if data.cols() == 0 { - return Err(BuildError::ZeroCols); - } - - if data.size() == 0 { - return Err(BuildError::ZeroLimbs); - } - - Ok(GGLWECiphertext { - data, - base2k, - k, - dsize, - }) - } -} - -impl GGLWECiphertext { +impl GGLWE { pub fn data(&self) -> &MatZnx { &self.data } } -impl GGLWECiphertext { +impl GGLWE { pub fn data_mut(&mut self) -> &mut MatZnx { &mut self.data } } -impl fmt::Debug for GGLWECiphertext { +impl fmt::Debug for GGLWE { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{self}") } } -impl FillUniform for GGLWECiphertext { +impl FillUniform for GGLWE { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.data.fill_uniform(log_bound, source); } } -impl fmt::Display for GGLWECiphertext { +impl fmt::Display for GGLWE { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "(GGLWECiphertext: k={} base2k={} dsize={}) {}", + "(GGLWE: k={} base2k={} dsize={}) {}", self.k().0, self.base2k().0, self.dsize().0, @@ -266,34 +169,32 @@ impl fmt::Display for GGLWECiphertext { } } -impl GGLWECiphertext { - pub fn at(&self, row: usize, col: usize) -> GLWECiphertext<&[u8]> { - GLWECiphertext::builder() - .data(self.data.at(row, col)) - .base2k(self.base2k()) - .k(self.k()) - .build() - .unwrap() +impl GGLWE { + pub fn at(&self, row: usize, col: usize) -> GLWE<&[u8]> { + GLWE { + k: self.k, + base2k: self.base2k, + data: self.data.at(row, col), + } } } -impl GGLWECiphertext { - pub fn at_mut(&mut self, row: usize, col: usize) -> GLWECiphertext<&mut [u8]> { - GLWECiphertext::builder() - .base2k(self.base2k()) - .k(self.k()) - .data(self.data.at_mut(row, col)) - .build() - .unwrap() +impl GGLWE { + pub fn at_mut(&mut self, row: usize, col: usize) -> GLWE<&mut [u8]> { + GLWE { + k: self.k, + base2k: self.base2k, + data: self.data.at_mut(row, col), + } } } -impl GGLWECiphertext> { - pub fn alloc(infos: &A) -> Self +impl GGLWE> { + pub fn alloc_from_infos(infos: &A) -> Self where A: GGLWEInfos, { - Self::alloc_with( + Self::alloc( infos.n(), infos.base2k(), infos.k(), @@ -304,15 +205,7 @@ impl GGLWECiphertext> { ) } - pub fn alloc_with( - n: Degree, - base2k: Base2K, - k: TorusPrecision, - rank_in: Rank, - rank_out: Rank, - dnum: Dnum, - dsize: Dsize, - ) -> Self { + pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank_in: Rank, rank_out: Rank, dnum: Dnum, dsize: Dsize) -> Self { let size: usize = k.0.div_ceil(base2k.0) as usize; debug_assert!( size as u32 > dsize.0, @@ -327,7 +220,7 @@ impl GGLWECiphertext> { dsize.0, ); - Self { + GGLWE { data: MatZnx::alloc( n.into(), dnum.into(), @@ -341,11 +234,11 @@ impl GGLWECiphertext> { } } - pub fn alloc_bytes(infos: &A) -> usize + pub fn bytes_of_from_infos(infos: &A) -> usize where A: GGLWEInfos, { - Self::alloc_bytes_with( + Self::bytes_of( infos.n(), infos.base2k(), infos.k(), @@ -356,7 +249,7 @@ impl GGLWECiphertext> { ) } - pub fn alloc_bytes_with( + pub fn bytes_of( n: Degree, base2k: Base2K, k: TorusPrecision, @@ -379,7 +272,7 @@ impl GGLWECiphertext> { dsize.0, ); - MatZnx::alloc_bytes( + MatZnx::bytes_of( n.into(), dnum.into(), rank_in.into(), @@ -389,7 +282,37 @@ impl GGLWECiphertext> { } } -impl ReaderFrom for GGLWECiphertext { +pub trait GGLWEToMut { + fn to_mut(&mut self) -> GGLWE<&mut [u8]>; +} + +impl GGLWEToMut for GGLWE { + fn to_mut(&mut self) -> GGLWE<&mut [u8]> { + GGLWE { + k: self.k(), + base2k: self.base2k(), + dsize: self.dsize(), + data: self.data.to_mut(), + } + } +} + +pub trait GGLWEToRef { + fn to_ref(&self) -> GGLWE<&[u8]>; +} + +impl GGLWEToRef for GGLWE { + fn to_ref(&self) -> GGLWE<&[u8]> { + GGLWE { + k: self.k(), + base2k: self.base2k(), + dsize: self.dsize(), + data: self.data.to_ref(), + } + } +} + +impl ReaderFrom for GGLWE { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.k = TorusPrecision(reader.read_u32::()?); self.base2k = Base2K(reader.read_u32::()?); @@ -398,7 +321,7 @@ impl ReaderFrom for GGLWECiphertext { } } -impl WriterTo for GGLWECiphertext { +impl WriterTo for GGLWE { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { writer.write_u32::(self.k.0)?; writer.write_u32::(self.base2k.0)?; diff --git a/poulpy-core/src/layouts/gglwe_ksk.rs b/poulpy-core/src/layouts/gglwe_ksk.rs deleted file mode 100644 index 31a483b..0000000 --- a/poulpy-core/src/layouts/gglwe_ksk.rs +++ /dev/null @@ -1,209 +0,0 @@ -use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo}, - source::Source, -}; - -use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWECiphertext, GGLWEInfos, GLWECiphertext, GLWEInfos, LWEInfos, Rank, TorusPrecision, -}; -use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; - -use std::fmt; - -#[derive(PartialEq, Eq, Copy, Clone, Debug)] -pub struct GGLWESwitchingKeyLayout { - pub n: Degree, - pub base2k: Base2K, - pub k: TorusPrecision, - pub rank_in: Rank, - pub rank_out: Rank, - pub dnum: Dnum, - pub dsize: Dsize, -} - -impl LWEInfos for GGLWESwitchingKeyLayout { - fn n(&self) -> Degree { - self.n - } - - fn base2k(&self) -> Base2K { - self.base2k - } - - fn k(&self) -> TorusPrecision { - self.k - } -} - -impl GLWEInfos for GGLWESwitchingKeyLayout { - fn rank(&self) -> Rank { - self.rank_out() - } -} - -impl GGLWEInfos for GGLWESwitchingKeyLayout { - fn rank_in(&self) -> Rank { - self.rank_in - } - - fn rank_out(&self) -> Rank { - self.rank_out - } - - fn dsize(&self) -> Dsize { - self.dsize - } - - fn dnum(&self) -> Dnum { - self.dnum - } -} - -#[derive(PartialEq, Eq, Clone)] -pub struct GGLWESwitchingKey { - pub(crate) key: GGLWECiphertext, - pub(crate) sk_in_n: usize, // Degree of sk_in - pub(crate) sk_out_n: usize, // Degree of sk_out -} - -impl LWEInfos for GGLWESwitchingKey { - fn n(&self) -> Degree { - self.key.n() - } - - fn base2k(&self) -> Base2K { - self.key.base2k() - } - - fn k(&self) -> TorusPrecision { - self.key.k() - } - - fn size(&self) -> usize { - self.key.size() - } -} - -impl GLWEInfos for GGLWESwitchingKey { - fn rank(&self) -> Rank { - self.rank_out() - } -} - -impl GGLWEInfos for GGLWESwitchingKey { - fn rank_in(&self) -> Rank { - self.key.rank_in() - } - - fn rank_out(&self) -> Rank { - self.key.rank_out() - } - - fn dsize(&self) -> Dsize { - self.key.dsize() - } - - fn dnum(&self) -> Dnum { - self.key.dnum() - } -} - -impl fmt::Debug for GGLWESwitchingKey { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{self}") - } -} - -impl fmt::Display for GGLWESwitchingKey { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "(GLWESwitchingKey: sk_in_n={} sk_out_n={}) {}", - self.sk_in_n, - self.sk_out_n, - self.key.data() - ) - } -} - -impl FillUniform for GGLWESwitchingKey { - fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { - self.key.fill_uniform(log_bound, source); - } -} - -impl GGLWESwitchingKey> { - pub fn alloc(infos: &A) -> Self - where - A: GGLWEInfos, - { - GGLWESwitchingKey { - key: GGLWECiphertext::alloc(infos), - sk_in_n: 0, - sk_out_n: 0, - } - } - - pub fn alloc_with( - n: Degree, - base2k: Base2K, - k: TorusPrecision, - rank_in: Rank, - rank_out: Rank, - dnum: Dnum, - dsize: Dsize, - ) -> Self { - GGLWESwitchingKey { - key: GGLWECiphertext::alloc_with(n, base2k, k, rank_in, rank_out, dnum, dsize), - sk_in_n: 0, - sk_out_n: 0, - } - } - - pub fn alloc_bytes(infos: &A) -> usize - where - A: GGLWEInfos, - { - GGLWECiphertext::alloc_bytes(infos) - } - - pub fn alloc_bytes_with( - n: Degree, - base2k: Base2K, - k: TorusPrecision, - rank_in: Rank, - rank_out: Rank, - dnum: Dnum, - dsize: Dsize, - ) -> usize { - GGLWECiphertext::alloc_bytes_with(n, base2k, k, rank_in, rank_out, dnum, dsize) - } -} - -impl GGLWESwitchingKey { - pub fn at(&self, row: usize, col: usize) -> GLWECiphertext<&[u8]> { - self.key.at(row, col) - } -} - -impl GGLWESwitchingKey { - pub fn at_mut(&mut self, row: usize, col: usize) -> GLWECiphertext<&mut [u8]> { - self.key.at_mut(row, col) - } -} - -impl ReaderFrom for GGLWESwitchingKey { - fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { - self.sk_in_n = reader.read_u64::()? as usize; - self.sk_out_n = reader.read_u64::()? as usize; - self.key.read_from(reader) - } -} - -impl WriterTo for GGLWESwitchingKey { - fn write_to(&self, writer: &mut W) -> std::io::Result<()> { - writer.write_u64::(self.sk_in_n as u64)?; - writer.write_u64::(self.sk_out_n as u64)?; - self.key.write_to(writer) - } -} diff --git a/poulpy-core/src/layouts/ggsw.rs b/poulpy-core/src/layouts/ggsw.rs new file mode 100644 index 0000000..a4109d0 --- /dev/null +++ b/poulpy-core/src/layouts/ggsw.rs @@ -0,0 +1,284 @@ +use poulpy_hal::{ + layouts::{Data, DataMut, DataRef, FillUniform, MatZnx, MatZnxToMut, MatZnxToRef, ReaderFrom, WriterTo, ZnxInfos}, + source::Source, +}; +use std::fmt; + +use crate::layouts::{Base2K, Degree, Dnum, Dsize, GLWE, GLWEInfos, LWEInfos, Rank, TorusPrecision}; + +pub trait GGSWInfos +where + Self: GLWEInfos, +{ + fn dnum(&self) -> Dnum; + fn dsize(&self) -> Dsize; + fn ggsw_layout(&self) -> GGSWLayout { + GGSWLayout { + n: self.n(), + base2k: self.base2k(), + k: self.k(), + rank: self.rank(), + dnum: self.dnum(), + dsize: self.dsize(), + } + } +} + +#[derive(PartialEq, Eq, Copy, Clone, Debug)] +pub struct GGSWLayout { + pub n: Degree, + pub base2k: Base2K, + pub k: TorusPrecision, + pub rank: Rank, + pub dnum: Dnum, + pub dsize: Dsize, +} + +impl LWEInfos for GGSWLayout { + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn n(&self) -> Degree { + self.n + } +} +impl GLWEInfos for GGSWLayout { + fn rank(&self) -> Rank { + self.rank + } +} + +impl GGSWInfos for GGSWLayout { + fn dsize(&self) -> Dsize { + self.dsize + } + + fn dnum(&self) -> Dnum { + self.dnum + } +} + +#[derive(PartialEq, Eq, Clone)] +pub struct GGSW { + pub(crate) data: MatZnx, + pub(crate) k: TorusPrecision, + pub(crate) base2k: Base2K, + pub(crate) dsize: Dsize, +} + +impl LWEInfos for GGSW { + fn n(&self) -> Degree { + Degree(self.data.n() as u32) + } + + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn size(&self) -> usize { + self.data.size() + } +} + +impl GLWEInfos for GGSW { + fn rank(&self) -> Rank { + Rank(self.data.cols_out() as u32 - 1) + } +} + +impl GGSWInfos for GGSW { + fn dsize(&self) -> Dsize { + self.dsize + } + + fn dnum(&self) -> Dnum { + Dnum(self.data.rows() as u32) + } +} + +impl fmt::Debug for GGSW { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.data) + } +} + +impl fmt::Display for GGSW { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "(GGSW: k: {} base2k: {} dsize: {}) {}", + self.k().0, + self.base2k().0, + self.dsize().0, + self.data + ) + } +} + +impl FillUniform for GGSW { + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + self.data.fill_uniform(log_bound, source); + } +} + +impl GGSW { + pub fn at(&self, row: usize, col: usize) -> GLWE<&[u8]> { + GLWE { + k: self.k, + base2k: self.base2k, + data: self.data.at(row, col), + } + } +} + +impl GGSW { + pub fn at_mut(&mut self, row: usize, col: usize) -> GLWE<&mut [u8]> { + GLWE { + k: self.k, + base2k: self.base2k, + data: self.data.at_mut(row, col), + } + } +} + +impl GGSW> { + pub fn alloc_from_infos(infos: &A) -> Self + where + A: GGSWInfos, + { + Self::alloc( + infos.n(), + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) + } + + pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self { + let size: usize = k.0.div_ceil(base2k.0) as usize; + debug_assert!( + size as u32 > dsize.0, + "invalid ggsw: ceil(k/base2k): {size} <= dsize: {}", + dsize.0 + ); + + assert!( + dnum.0 * dsize.0 <= size as u32, + "invalid ggsw: dnum: {} * dsize:{} > ceil(k/base2k): {size}", + dnum.0, + dsize.0, + ); + + GGSW { + data: MatZnx::alloc( + n.into(), + dnum.into(), + (rank + 1).into(), + (rank + 1).into(), + k.0.div_ceil(base2k.0) as usize, + ), + k, + base2k, + dsize, + } + } + + pub fn bytes_of_from_infos(infos: &A) -> usize + where + A: GGSWInfos, + { + Self::bytes_of( + infos.n(), + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) + } + + pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { + let size: usize = k.0.div_ceil(base2k.0) as usize; + debug_assert!( + size as u32 > dsize.0, + "invalid ggsw: ceil(k/base2k): {size} <= dsize: {}", + dsize.0 + ); + + assert!( + dnum.0 * dsize.0 <= size as u32, + "invalid ggsw: dnum: {} * dsize:{} > ceil(k/base2k): {size}", + dnum.0, + dsize.0, + ); + + MatZnx::bytes_of( + n.into(), + dnum.into(), + (rank + 1).into(), + (rank + 1).into(), + k.0.div_ceil(base2k.0) as usize, + ) + } +} + +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; + +impl ReaderFrom for GGSW { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + self.k = TorusPrecision(reader.read_u32::()?); + self.base2k = Base2K(reader.read_u32::()?); + self.dsize = Dsize(reader.read_u32::()?); + self.data.read_from(reader) + } +} + +impl WriterTo for GGSW { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + writer.write_u32::(self.k.into())?; + writer.write_u32::(self.base2k.into())?; + writer.write_u32::(self.dsize.into())?; + self.data.write_to(writer) + } +} + +pub trait GGSWToMut { + fn to_mut(&mut self) -> GGSW<&mut [u8]>; +} + +impl GGSWToMut for GGSW { + fn to_mut(&mut self) -> GGSW<&mut [u8]> { + GGSW { + dsize: self.dsize, + k: self.k, + base2k: self.base2k, + data: self.data.to_mut(), + } + } +} + +pub trait GGSWToRef { + fn to_ref(&self) -> GGSW<&[u8]>; +} + +impl GGSWToRef for GGSW { + fn to_ref(&self) -> GGSW<&[u8]> { + GGSW { + dsize: self.dsize, + k: self.k, + base2k: self.base2k, + data: self.data.to_ref(), + } + } +} diff --git a/poulpy-core/src/layouts/ggsw_ct.rs b/poulpy-core/src/layouts/ggsw_ct.rs deleted file mode 100644 index f1bb228..0000000 --- a/poulpy-core/src/layouts/ggsw_ct.rs +++ /dev/null @@ -1,372 +0,0 @@ -use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, FillUniform, MatZnx, ReaderFrom, WriterTo, ZnxInfos}, - source::Source, -}; -use std::fmt; - -use crate::layouts::{Base2K, BuildError, Degree, Dnum, Dsize, GLWECiphertext, GLWEInfos, LWEInfos, Rank, TorusPrecision}; - -pub trait GGSWInfos -where - Self: GLWEInfos, -{ - fn dnum(&self) -> Dnum; - fn dsize(&self) -> Dsize; - fn ggsw_layout(&self) -> GGSWCiphertextLayout { - GGSWCiphertextLayout { - n: self.n(), - base2k: self.base2k(), - k: self.k(), - rank: self.rank(), - dnum: self.dnum(), - dsize: self.dsize(), - } - } -} - -#[derive(PartialEq, Eq, Copy, Clone, Debug)] -pub struct GGSWCiphertextLayout { - pub n: Degree, - pub base2k: Base2K, - pub k: TorusPrecision, - pub rank: Rank, - pub dnum: Dnum, - pub dsize: Dsize, -} - -impl LWEInfos for GGSWCiphertextLayout { - fn base2k(&self) -> Base2K { - self.base2k - } - - fn k(&self) -> TorusPrecision { - self.k - } - - fn n(&self) -> Degree { - self.n - } -} -impl GLWEInfos for GGSWCiphertextLayout { - fn rank(&self) -> Rank { - self.rank - } -} - -impl GGSWInfos for GGSWCiphertextLayout { - fn dsize(&self) -> Dsize { - self.dsize - } - - fn dnum(&self) -> Dnum { - self.dnum - } -} - -#[derive(PartialEq, Eq, Clone)] -pub struct GGSWCiphertext { - pub(crate) data: MatZnx, - pub(crate) k: TorusPrecision, - pub(crate) base2k: Base2K, - pub(crate) dsize: Dsize, -} - -impl LWEInfos for GGSWCiphertext { - fn n(&self) -> Degree { - Degree(self.data.n() as u32) - } - - fn base2k(&self) -> Base2K { - self.base2k - } - - fn k(&self) -> TorusPrecision { - self.k - } - - fn size(&self) -> usize { - self.data.size() - } -} - -impl GLWEInfos for GGSWCiphertext { - fn rank(&self) -> Rank { - Rank(self.data.cols_out() as u32 - 1) - } -} - -impl GGSWInfos for GGSWCiphertext { - fn dsize(&self) -> Dsize { - self.dsize - } - - fn dnum(&self) -> Dnum { - Dnum(self.data.rows() as u32) - } -} - -pub struct GGSWCiphertextBuilder { - data: Option>, - base2k: Option, - k: Option, - dsize: Option, -} - -impl GGSWCiphertext { - #[inline] - pub fn builder() -> GGSWCiphertextBuilder { - GGSWCiphertextBuilder { - data: None, - base2k: None, - k: None, - dsize: None, - } - } -} - -impl GGSWCiphertextBuilder> { - #[inline] - pub fn layout(mut self, infos: &A) -> Self - where - A: GGSWInfos, - { - debug_assert!( - infos.size() as u32 > infos.dsize().0, - "invalid ggsw: ceil(k/base2k): {} <= dsize: {}", - infos.size(), - infos.dsize() - ); - - assert!( - infos.dnum().0 * infos.dsize().0 <= infos.size() as u32, - "invalid ggsw: dnum: {} * dsize:{} > ceil(k/base2k): {}", - infos.dnum(), - infos.dsize(), - infos.size(), - ); - - self.data = Some(MatZnx::alloc( - infos.n().into(), - infos.dnum().into(), - (infos.rank() + 1).into(), - (infos.rank() + 1).into(), - infos.size(), - )); - self.base2k = Some(infos.base2k()); - self.k = Some(infos.k()); - self.dsize = Some(infos.dsize()); - self - } -} - -impl GGSWCiphertextBuilder { - #[inline] - pub fn data(mut self, data: MatZnx) -> Self { - self.data = Some(data); - self - } - #[inline] - pub fn base2k(mut self, base2k: Base2K) -> Self { - self.base2k = Some(base2k); - self - } - #[inline] - pub fn k(mut self, k: TorusPrecision) -> Self { - self.k = Some(k); - self - } - - #[inline] - pub fn dsize(mut self, dsize: Dsize) -> Self { - self.dsize = Some(dsize); - self - } - - pub fn build(self) -> Result, BuildError> { - let data: MatZnx = self.data.ok_or(BuildError::MissingData)?; - let base2k: Base2K = self.base2k.ok_or(BuildError::MissingBase2K)?; - let k: TorusPrecision = self.k.ok_or(BuildError::MissingK)?; - let dsize: Dsize = self.dsize.ok_or(BuildError::MissingDigits)?; - - if base2k == 0_u32 { - return Err(BuildError::ZeroBase2K); - } - - if dsize == 0_u32 { - return Err(BuildError::ZeroBase2K); - } - - if k == 0_u32 { - return Err(BuildError::ZeroTorusPrecision); - } - - if data.n() == 0 { - return Err(BuildError::ZeroDegree); - } - - if data.cols() == 0 { - return Err(BuildError::ZeroCols); - } - - if data.size() == 0 { - return Err(BuildError::ZeroLimbs); - } - - Ok(GGSWCiphertext { - data, - base2k, - k, - dsize, - }) - } -} - -impl fmt::Debug for GGSWCiphertext { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.data) - } -} - -impl fmt::Display for GGSWCiphertext { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "(GGSWCiphertext: k: {} base2k: {} dsize: {}) {}", - self.k().0, - self.base2k().0, - self.dsize().0, - self.data - ) - } -} - -impl FillUniform for GGSWCiphertext { - fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { - self.data.fill_uniform(log_bound, source); - } -} - -impl GGSWCiphertext { - pub fn at(&self, row: usize, col: usize) -> GLWECiphertext<&[u8]> { - GLWECiphertext::builder() - .data(self.data.at(row, col)) - .base2k(self.base2k()) - .k(self.k()) - .build() - .unwrap() - } -} - -impl GGSWCiphertext { - pub fn at_mut(&mut self, row: usize, col: usize) -> GLWECiphertext<&mut [u8]> { - GLWECiphertext::builder() - .base2k(self.base2k()) - .k(self.k()) - .data(self.data.at_mut(row, col)) - .build() - .unwrap() - } -} - -impl GGSWCiphertext> { - pub fn alloc(infos: &A) -> Self - where - A: GGSWInfos, - { - Self::alloc_with( - infos.n(), - infos.base2k(), - infos.k(), - infos.rank(), - infos.dnum(), - infos.dsize(), - ) - } - - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self { - let size: usize = k.0.div_ceil(base2k.0) as usize; - debug_assert!( - size as u32 > dsize.0, - "invalid ggsw: ceil(k/base2k): {size} <= dsize: {}", - dsize.0 - ); - - assert!( - dnum.0 * dsize.0 <= size as u32, - "invalid ggsw: dnum: {} * dsize:{} > ceil(k/base2k): {size}", - dnum.0, - dsize.0, - ); - - Self { - data: MatZnx::alloc( - n.into(), - dnum.into(), - (rank + 1).into(), - (rank + 1).into(), - k.0.div_ceil(base2k.0) as usize, - ), - k, - base2k, - dsize, - } - } - - pub fn alloc_bytes(infos: &A) -> usize - where - A: GGSWInfos, - { - Self::alloc_bytes_with( - infos.n(), - infos.base2k(), - infos.k(), - infos.rank(), - infos.dnum(), - infos.dsize(), - ) - } - - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { - let size: usize = k.0.div_ceil(base2k.0) as usize; - debug_assert!( - size as u32 > dsize.0, - "invalid ggsw: ceil(k/base2k): {size} <= dsize: {}", - dsize.0 - ); - - assert!( - dnum.0 * dsize.0 <= size as u32, - "invalid ggsw: dnum: {} * dsize:{} > ceil(k/base2k): {size}", - dnum.0, - dsize.0, - ); - - MatZnx::alloc_bytes( - n.into(), - dnum.into(), - (rank + 1).into(), - (rank + 1).into(), - k.0.div_ceil(base2k.0) as usize, - ) - } -} - -use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; - -impl ReaderFrom for GGSWCiphertext { - fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { - self.k = TorusPrecision(reader.read_u32::()?); - self.base2k = Base2K(reader.read_u32::()?); - self.dsize = Dsize(reader.read_u32::()?); - self.data.read_from(reader) - } -} - -impl WriterTo for GGSWCiphertext { - fn write_to(&self, writer: &mut W) -> std::io::Result<()> { - writer.write_u32::(self.k.into())?; - writer.write_u32::(self.base2k.into())?; - writer.write_u32::(self.dsize.into())?; - self.data.write_to(writer) - } -} diff --git a/poulpy-core/src/layouts/glwe.rs b/poulpy-core/src/layouts/glwe.rs new file mode 100644 index 0000000..d47f8cd --- /dev/null +++ b/poulpy-core/src/layouts/glwe.rs @@ -0,0 +1,218 @@ +use poulpy_hal::{ + layouts::{ + Data, DataMut, DataRef, FillUniform, ReaderFrom, ToOwnedDeep, VecZnx, VecZnxToMut, VecZnxToRef, WriterTo, ZnxInfos, + }, + source::Source, +}; + +use crate::layouts::{Base2K, Degree, LWEInfos, Rank, TorusPrecision}; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; +use std::fmt; + +pub trait GLWEInfos +where + Self: LWEInfos, +{ + fn rank(&self) -> Rank; + fn glwe_layout(&self) -> GLWELayout { + GLWELayout { + n: self.n(), + base2k: self.base2k(), + k: self.k(), + rank: self.rank(), + } + } +} + +pub trait SetGLWEInfos { + fn set_k(&mut self, k: TorusPrecision); + fn set_base2k(&mut self, base2k: Base2K); +} + +#[derive(PartialEq, Eq, Copy, Clone, Debug)] +pub struct GLWELayout { + pub n: Degree, + pub base2k: Base2K, + pub k: TorusPrecision, + pub rank: Rank, +} + +impl LWEInfos for GLWELayout { + fn n(&self) -> Degree { + self.n + } + + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } +} + +impl GLWEInfos for GLWELayout { + fn rank(&self) -> Rank { + self.rank + } +} + +#[derive(PartialEq, Eq, Clone)] +pub struct GLWE { + pub(crate) data: VecZnx, + pub(crate) base2k: Base2K, + pub(crate) k: TorusPrecision, +} + +impl SetGLWEInfos for GLWE { + fn set_base2k(&mut self, base2k: Base2K) { + self.base2k = base2k + } + + fn set_k(&mut self, k: TorusPrecision) { + self.k = k + } +} + +impl GLWE { + pub fn data(&self) -> &VecZnx { + &self.data + } +} + +impl GLWE { + pub fn data_mut(&mut self) -> &mut VecZnx { + &mut self.data + } +} + +impl LWEInfos for GLWE { + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn n(&self) -> Degree { + Degree(self.data.n() as u32) + } + + fn size(&self) -> usize { + self.data.size() + } +} + +impl GLWEInfos for GLWE { + fn rank(&self) -> Rank { + Rank(self.data.cols() as u32 - 1) + } +} + +impl ToOwnedDeep for GLWE { + type Owned = GLWE>; + fn to_owned_deep(&self) -> Self::Owned { + GLWE { + data: self.data.to_owned_deep(), + k: self.k, + base2k: self.base2k, + } + } +} + +impl fmt::Debug for GLWE { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{self}") + } +} + +impl fmt::Display for GLWE { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "GLWE: base2k={} k={}: {}", + self.base2k().0, + self.k().0, + self.data + ) + } +} + +impl FillUniform for GLWE { + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + self.data.fill_uniform(log_bound, source); + } +} + +impl GLWE> { + pub fn alloc_from_infos(infos: &A) -> Self + where + A: GLWEInfos, + { + Self::alloc(infos.n(), infos.base2k(), infos.k(), infos.rank()) + } + + pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self { + GLWE { + data: VecZnx::alloc(n.into(), (rank + 1).into(), k.0.div_ceil(base2k.0) as usize), + base2k, + k, + } + } + + pub fn bytes_of_from_infos(infos: &A) -> usize + where + A: GLWEInfos, + { + Self::bytes_of(infos.n(), infos.base2k(), infos.k(), infos.rank()) + } + + pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize { + VecZnx::bytes_of(n.into(), (rank + 1).into(), k.0.div_ceil(base2k.0) as usize) + } +} + +impl ReaderFrom for GLWE { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + self.k = TorusPrecision(reader.read_u32::()?); + self.base2k = Base2K(reader.read_u32::()?); + self.data.read_from(reader) + } +} + +impl WriterTo for GLWE { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + writer.write_u32::(self.k.0)?; + writer.write_u32::(self.base2k.0)?; + self.data.write_to(writer) + } +} + +pub trait GLWEToRef { + fn to_ref(&self) -> GLWE<&[u8]>; +} + +impl GLWEToRef for GLWE { + fn to_ref(&self) -> GLWE<&[u8]> { + GLWE { + k: self.k, + base2k: self.base2k, + data: self.data.to_ref(), + } + } +} + +pub trait GLWEToMut { + fn to_mut(&mut self) -> GLWE<&mut [u8]>; +} + +impl GLWEToMut for GLWE { + fn to_mut(&mut self) -> GLWE<&mut [u8]> { + GLWE { + k: self.k, + base2k: self.base2k, + data: self.data.to_mut(), + } + } +} diff --git a/poulpy-core/src/layouts/gglwe_atk.rs b/poulpy-core/src/layouts/glwe_automorphism_key.rs similarity index 51% rename from poulpy-core/src/layouts/gglwe_atk.rs rename to poulpy-core/src/layouts/glwe_automorphism_key.rs index 5c786d2..a378e6d 100644 --- a/poulpy-core/src/layouts/gglwe_atk.rs +++ b/poulpy-core/src/layouts/glwe_automorphism_key.rs @@ -4,14 +4,14 @@ use poulpy_hal::{ }; use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWESwitchingKey, GLWECiphertext, GLWEInfos, LWEInfos, Rank, TorusPrecision, + Base2K, Degree, Dnum, Dsize, GGLWE, GGLWEInfos, GGLWEToMut, GGLWEToRef, GLWE, GLWEInfos, LWEInfos, Rank, TorusPrecision, }; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::fmt; #[derive(PartialEq, Eq, Copy, Clone, Debug)] -pub struct GGLWEAutomorphismKeyLayout { +pub struct GLWEAutomorphismKeyLayout { pub n: Degree, pub base2k: Base2K, pub k: TorusPrecision, @@ -21,18 +21,38 @@ pub struct GGLWEAutomorphismKeyLayout { } #[derive(PartialEq, Eq, Clone)] -pub struct GGLWEAutomorphismKey { - pub(crate) key: GGLWESwitchingKey, +pub struct GLWEAutomorphismKey { + pub(crate) key: GGLWE, pub(crate) p: i64, } -impl GGLWEAutomorphismKey { +pub trait GetGaloisElement { + fn p(&self) -> i64; +} + +pub trait SetGaloisElement { + fn set_p(&mut self, p: i64); +} + +impl SetGaloisElement for GLWEAutomorphismKey { + fn set_p(&mut self, p: i64) { + self.p = p + } +} + +impl GetGaloisElement for GLWEAutomorphismKey { + fn p(&self) -> i64 { + self.p + } +} + +impl GLWEAutomorphismKey { pub fn p(&self) -> i64 { self.p } } -impl LWEInfos for GGLWEAutomorphismKey { +impl LWEInfos for GLWEAutomorphismKey { fn n(&self) -> Degree { self.key.n() } @@ -50,13 +70,13 @@ impl LWEInfos for GGLWEAutomorphismKey { } } -impl GLWEInfos for GGLWEAutomorphismKey { +impl GLWEInfos for GLWEAutomorphismKey { fn rank(&self) -> Rank { self.rank_out() } } -impl GGLWEInfos for GGLWEAutomorphismKey { +impl GGLWEInfos for GLWEAutomorphismKey { fn rank_in(&self) -> Rank { self.key.rank_in() } @@ -74,7 +94,7 @@ impl GGLWEInfos for GGLWEAutomorphismKey { } } -impl LWEInfos for GGLWEAutomorphismKeyLayout { +impl LWEInfos for GLWEAutomorphismKeyLayout { fn base2k(&self) -> Base2K { self.base2k } @@ -88,13 +108,13 @@ impl LWEInfos for GGLWEAutomorphismKeyLayout { } } -impl GLWEInfos for GGLWEAutomorphismKeyLayout { +impl GLWEInfos for GLWEAutomorphismKeyLayout { fn rank(&self) -> Rank { self.rank } } -impl GGLWEInfos for GGLWEAutomorphismKeyLayout { +impl GGLWEInfos for GLWEAutomorphismKeyLayout { fn rank_in(&self) -> Rank { self.rank } @@ -112,84 +132,102 @@ impl GGLWEInfos for GGLWEAutomorphismKeyLayout { } } -impl fmt::Debug for GGLWEAutomorphismKey { +impl fmt::Debug for GLWEAutomorphismKey { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{self}") } } -impl FillUniform for GGLWEAutomorphismKey { +impl FillUniform for GLWEAutomorphismKey { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.key.fill_uniform(log_bound, source); } } -impl fmt::Display for GGLWEAutomorphismKey { +impl fmt::Display for GLWEAutomorphismKey { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "(AutomorphismKey: p={}) {}", self.p, self.key) } } -impl GGLWEAutomorphismKey> { - pub fn alloc(infos: &A) -> Self +impl GLWEAutomorphismKey> { + pub fn alloc_from_infos(infos: &A) -> Self + where + A: GGLWEInfos, + { + Self::alloc( + infos.n(), + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) + } + + pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self { + GLWEAutomorphismKey { + key: GGLWE::alloc(n, base2k, k, rank, rank, dnum, dsize), + p: 0, + } + } + + pub fn bytes_of_from_infos(infos: &A) -> usize where A: GGLWEInfos, { assert_eq!( infos.rank_in(), infos.rank_out(), - "rank_in != rank_out is not supported for GGLWEAutomorphismKey" + "rank_in != rank_out is not supported for AutomorphismKey" ); - GGLWEAutomorphismKey { - key: GGLWESwitchingKey::alloc(infos), - p: 0, - } - } - - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self { - GGLWEAutomorphismKey { - key: GGLWESwitchingKey::alloc_with(n, base2k, k, rank, rank, dnum, dsize), - p: 0, - } - } - - pub fn alloc_bytes(infos: &A) -> usize - where - A: GGLWEInfos, - { - assert_eq!( - infos.rank_in(), - infos.rank_out(), - "rank_in != rank_out is not supported for GGLWEAutomorphismKey" - ); - GGLWESwitchingKey::alloc_bytes(infos) + Self::bytes_of( + infos.n(), + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) } pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { - GGLWESwitchingKey::alloc_bytes_with(n, base2k, k, rank, rank, dnum, dsize) + GGLWE::bytes_of(n, base2k, k, rank, rank, dnum, dsize) } } -impl GGLWEAutomorphismKey { - pub fn at(&self, row: usize, col: usize) -> GLWECiphertext<&[u8]> { +impl GGLWEToMut for GLWEAutomorphismKey { + fn to_mut(&mut self) -> GGLWE<&mut [u8]> { + self.key.to_mut() + } +} + +impl GGLWEToRef for GLWEAutomorphismKey { + fn to_ref(&self) -> GGLWE<&[u8]> { + self.key.to_ref() + } +} + +impl GLWEAutomorphismKey { + pub fn at(&self, row: usize, col: usize) -> GLWE<&[u8]> { self.key.at(row, col) } } -impl GGLWEAutomorphismKey { - pub fn at_mut(&mut self, row: usize, col: usize) -> GLWECiphertext<&mut [u8]> { +impl GLWEAutomorphismKey { + pub fn at_mut(&mut self, row: usize, col: usize) -> GLWE<&mut [u8]> { self.key.at_mut(row, col) } } -impl ReaderFrom for GGLWEAutomorphismKey { +impl ReaderFrom for GLWEAutomorphismKey { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.p = reader.read_u64::()? as i64; self.key.read_from(reader) } } -impl WriterTo for GGLWEAutomorphismKey { +impl WriterTo for GLWEAutomorphismKey { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { writer.write_u64::(self.p as u64)?; self.key.write_to(writer) diff --git a/poulpy-core/src/layouts/glwe_ct.rs b/poulpy-core/src/layouts/glwe_ct.rs deleted file mode 100644 index 23b6ef9..0000000 --- a/poulpy-core/src/layouts/glwe_ct.rs +++ /dev/null @@ -1,300 +0,0 @@ -use poulpy_hal::{ - layouts::{ - Data, DataMut, DataRef, FillUniform, ReaderFrom, ToOwnedDeep, VecZnx, VecZnxToMut, VecZnxToRef, WriterTo, ZnxInfos, - }, - source::Source, -}; - -use crate::layouts::{Base2K, BuildError, Degree, LWEInfos, Rank, TorusPrecision}; -use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; -use std::fmt; - -pub trait GLWEInfos -where - Self: LWEInfos, -{ - fn rank(&self) -> Rank; - fn glwe_layout(&self) -> GLWECiphertextLayout { - GLWECiphertextLayout { - n: self.n(), - base2k: self.base2k(), - k: self.k(), - rank: self.rank(), - } - } -} - -pub trait GLWELayoutSet { - fn set_k(&mut self, k: TorusPrecision); - fn set_basek(&mut self, base2k: Base2K); -} - -#[derive(PartialEq, Eq, Copy, Clone, Debug)] -pub struct GLWECiphertextLayout { - pub n: Degree, - pub base2k: Base2K, - pub k: TorusPrecision, - pub rank: Rank, -} - -impl LWEInfos for GLWECiphertextLayout { - fn n(&self) -> Degree { - self.n - } - - fn base2k(&self) -> Base2K { - self.base2k - } - - fn k(&self) -> TorusPrecision { - self.k - } -} - -impl GLWEInfos for GLWECiphertextLayout { - fn rank(&self) -> Rank { - self.rank - } -} - -#[derive(PartialEq, Eq, Clone)] -pub struct GLWECiphertext { - pub(crate) data: VecZnx, - pub(crate) base2k: Base2K, - pub(crate) k: TorusPrecision, -} - -impl GLWELayoutSet for GLWECiphertext { - fn set_basek(&mut self, base2k: Base2K) { - self.base2k = base2k - } - - fn set_k(&mut self, k: TorusPrecision) { - self.k = k - } -} - -impl GLWECiphertext { - pub fn data(&self) -> &VecZnx { - &self.data - } -} - -impl GLWECiphertext { - pub fn data_mut(&mut self) -> &mut VecZnx { - &mut self.data - } -} - -pub struct GLWECiphertextBuilder { - data: Option>, - base2k: Option, - k: Option, -} - -impl GLWECiphertext { - #[inline] - pub fn builder() -> GLWECiphertextBuilder { - GLWECiphertextBuilder { - data: None, - base2k: None, - k: None, - } - } -} - -impl GLWECiphertextBuilder> { - #[inline] - pub fn layout(mut self, layout: &A) -> Self - where - A: GLWEInfos, - { - self.data = Some(VecZnx::alloc( - layout.n().into(), - (layout.rank() + 1).into(), - layout.size(), - )); - self.base2k = Some(layout.base2k()); - self.k = Some(layout.k()); - self - } -} - -impl GLWECiphertextBuilder { - #[inline] - pub fn data(mut self, data: VecZnx) -> Self { - self.data = Some(data); - self - } - #[inline] - pub fn base2k(mut self, base2k: Base2K) -> Self { - self.base2k = Some(base2k); - self - } - #[inline] - pub fn k(mut self, k: TorusPrecision) -> Self { - self.k = Some(k); - self - } - - pub fn build(self) -> Result, BuildError> { - let data: VecZnx = self.data.ok_or(BuildError::MissingData)?; - let base2k: Base2K = self.base2k.ok_or(BuildError::MissingBase2K)?; - let k: TorusPrecision = self.k.ok_or(BuildError::MissingK)?; - - if base2k == 0_u32 { - return Err(BuildError::ZeroBase2K); - } - - if k == 0_u32 { - return Err(BuildError::ZeroTorusPrecision); - } - - if data.n() == 0 { - return Err(BuildError::ZeroDegree); - } - - if data.cols() == 0 { - return Err(BuildError::ZeroCols); - } - - if data.size() == 0 { - return Err(BuildError::ZeroLimbs); - } - - Ok(GLWECiphertext { data, base2k, k }) - } -} - -impl LWEInfos for GLWECiphertext { - fn base2k(&self) -> Base2K { - self.base2k - } - - fn k(&self) -> TorusPrecision { - self.k - } - - fn n(&self) -> Degree { - Degree(self.data.n() as u32) - } - - fn size(&self) -> usize { - self.data.size() - } -} - -impl GLWEInfos for GLWECiphertext { - fn rank(&self) -> Rank { - Rank(self.data.cols() as u32 - 1) - } -} - -impl ToOwnedDeep for GLWECiphertext { - type Owned = GLWECiphertext>; - fn to_owned_deep(&self) -> Self::Owned { - GLWECiphertext { - data: self.data.to_owned_deep(), - k: self.k, - base2k: self.base2k, - } - } -} - -impl fmt::Debug for GLWECiphertext { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{self}") - } -} - -impl fmt::Display for GLWECiphertext { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "GLWECiphertext: base2k={} k={}: {}", - self.base2k().0, - self.k().0, - self.data - ) - } -} - -impl FillUniform for GLWECiphertext { - fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { - self.data.fill_uniform(log_bound, source); - } -} - -impl GLWECiphertext> { - pub fn alloc(infos: &A) -> Self - where - A: GLWEInfos, - { - Self::alloc_with(infos.n(), infos.base2k(), infos.k(), infos.rank()) - } - - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self { - Self { - data: VecZnx::alloc(n.into(), (rank + 1).into(), k.0.div_ceil(base2k.0) as usize), - base2k, - k, - } - } - - pub fn alloc_bytes(infos: &A) -> usize - where - A: GLWEInfos, - { - Self::alloc_bytes_with(infos.n(), infos.base2k(), infos.k(), infos.rank()) - } - - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize { - VecZnx::alloc_bytes(n.into(), (rank + 1).into(), k.0.div_ceil(base2k.0) as usize) - } -} - -pub trait GLWECiphertextToRef { - fn to_ref(&self) -> GLWECiphertext<&[u8]>; -} - -impl GLWECiphertextToRef for GLWECiphertext { - fn to_ref(&self) -> GLWECiphertext<&[u8]> { - GLWECiphertext::builder() - .k(self.k()) - .base2k(self.base2k()) - .data(self.data.to_ref()) - .build() - .unwrap() - } -} - -pub trait GLWECiphertextToMut { - fn to_mut(&mut self) -> GLWECiphertext<&mut [u8]>; -} - -impl GLWECiphertextToMut for GLWECiphertext { - fn to_mut(&mut self) -> GLWECiphertext<&mut [u8]> { - GLWECiphertext::builder() - .k(self.k()) - .base2k(self.base2k()) - .data(self.data.to_mut()) - .build() - .unwrap() - } -} - -impl ReaderFrom for GLWECiphertext { - fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { - self.k = TorusPrecision(reader.read_u32::()?); - self.base2k = Base2K(reader.read_u32::()?); - self.data.read_from(reader) - } -} - -impl WriterTo for GLWECiphertext { - fn write_to(&self, writer: &mut W) -> std::io::Result<()> { - writer.write_u32::(self.k.0)?; - writer.write_u32::(self.base2k.0)?; - self.data.write_to(writer) - } -} diff --git a/poulpy-core/src/layouts/glwe_pk.rs b/poulpy-core/src/layouts/glwe_pk.rs deleted file mode 100644 index fc4b0fa..0000000 --- a/poulpy-core/src/layouts/glwe_pk.rs +++ /dev/null @@ -1,209 +0,0 @@ -use poulpy_hal::layouts::{Data, DataMut, DataRef, ReaderFrom, VecZnx, WriterTo, ZnxInfos}; - -use crate::{ - dist::Distribution, - layouts::{Base2K, BuildError, Degree, GLWEInfos, LWEInfos, Rank, TorusPrecision}, -}; -use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; - -#[derive(PartialEq, Eq)] -pub struct GLWEPublicKey { - pub(crate) data: VecZnx, - pub(crate) base2k: Base2K, - pub(crate) k: TorusPrecision, - pub(crate) dist: Distribution, -} - -#[derive(PartialEq, Eq, Copy, Clone, Debug)] -pub struct GLWEPublicKeyLayout { - pub n: Degree, - pub base2k: Base2K, - pub k: TorusPrecision, - pub rank: Rank, -} - -impl LWEInfos for GLWEPublicKey { - fn base2k(&self) -> Base2K { - self.base2k - } - - fn k(&self) -> TorusPrecision { - self.k - } - - fn n(&self) -> Degree { - Degree(self.data.n() as u32) - } - - fn size(&self) -> usize { - self.data.size() - } -} - -impl GLWEInfos for GLWEPublicKey { - fn rank(&self) -> Rank { - Rank(self.data.cols() as u32 - 1) - } -} - -impl LWEInfos for GLWEPublicKeyLayout { - fn base2k(&self) -> Base2K { - self.base2k - } - - fn k(&self) -> TorusPrecision { - self.k - } - - fn n(&self) -> Degree { - self.n - } - - fn size(&self) -> usize { - self.k.0.div_ceil(self.base2k.0) as usize - } -} - -impl GLWEInfos for GLWEPublicKeyLayout { - fn rank(&self) -> Rank { - self.rank - } -} - -pub struct GLWEPublicKeyBuilder { - data: Option>, - base2k: Option, - k: Option, -} - -impl GLWEPublicKey { - #[inline] - pub fn builder() -> GLWEPublicKeyBuilder { - GLWEPublicKeyBuilder { - data: None, - base2k: None, - k: None, - } - } -} - -impl GLWEPublicKeyBuilder> { - #[inline] - pub fn layout(mut self, layout: &A) -> Self - where - A: GLWEInfos, - { - self.data = Some(VecZnx::alloc( - layout.n().into(), - (layout.rank() + 1).into(), - layout.size(), - )); - self.base2k = Some(layout.base2k()); - self.k = Some(layout.k()); - self - } -} - -impl GLWEPublicKeyBuilder { - #[inline] - pub fn data(mut self, data: VecZnx) -> Self { - self.data = Some(data); - self - } - #[inline] - pub fn base2k(mut self, base2k: Base2K) -> Self { - self.base2k = Some(base2k); - self - } - #[inline] - pub fn k(mut self, k: TorusPrecision) -> Self { - self.k = Some(k); - self - } - - pub fn build(self) -> Result, BuildError> { - let data: VecZnx = self.data.ok_or(BuildError::MissingData)?; - let base2k: Base2K = self.base2k.ok_or(BuildError::MissingBase2K)?; - let k: TorusPrecision = self.k.ok_or(BuildError::MissingK)?; - - if base2k == 0_u32 { - return Err(BuildError::ZeroBase2K); - } - - if k == 0_u32 { - return Err(BuildError::ZeroTorusPrecision); - } - - if data.n() == 0 { - return Err(BuildError::ZeroDegree); - } - - if data.cols() == 0 { - return Err(BuildError::ZeroCols); - } - - if data.size() == 0 { - return Err(BuildError::ZeroLimbs); - } - - Ok(GLWEPublicKey { - data, - base2k, - k, - dist: Distribution::NONE, - }) - } -} - -impl GLWEPublicKey> { - pub fn alloc(infos: &A) -> Self - where - A: GLWEInfos, - { - Self::alloc_with(infos.n(), infos.base2k(), infos.k(), infos.rank()) - } - - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self { - Self { - data: VecZnx::alloc(n.into(), (rank + 1).into(), k.0.div_ceil(base2k.0) as usize), - base2k, - k, - dist: Distribution::NONE, - } - } - - pub fn alloc_bytes(infos: &A) -> usize - where - A: GLWEInfos, - { - Self::alloc_bytes_with(infos.n(), infos.base2k(), infos.k(), infos.rank()) - } - - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize { - VecZnx::alloc_bytes(n.into(), (rank + 1).into(), k.0.div_ceil(base2k.0) as usize) - } -} - -impl ReaderFrom for GLWEPublicKey { - fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { - self.k = TorusPrecision(reader.read_u32::()?); - self.base2k = Base2K(reader.read_u32::()?); - match Distribution::read_from(reader) { - Ok(dist) => self.dist = dist, - Err(e) => return Err(e), - } - self.data.read_from(reader) - } -} - -impl WriterTo for GLWEPublicKey { - fn write_to(&self, writer: &mut W) -> std::io::Result<()> { - writer.write_u32::(self.k.0)?; - writer.write_u32::(self.base2k.0)?; - match self.dist.write_to(writer) { - Ok(()) => {} - Err(e) => return Err(e), - } - self.data.write_to(writer) - } -} diff --git a/poulpy-core/src/layouts/glwe_plaintext.rs b/poulpy-core/src/layouts/glwe_plaintext.rs new file mode 100644 index 0000000..3261d3d --- /dev/null +++ b/poulpy-core/src/layouts/glwe_plaintext.rs @@ -0,0 +1,160 @@ +use std::fmt; + +use poulpy_hal::layouts::{Data, DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos}; + +use crate::layouts::{Base2K, Degree, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, Rank, SetGLWEInfos, TorusPrecision}; + +#[derive(PartialEq, Eq, Copy, Clone, Debug)] +pub struct GLWEPlaintextLayout { + pub n: Degree, + pub base2k: Base2K, + pub k: TorusPrecision, +} + +impl LWEInfos for GLWEPlaintextLayout { + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn n(&self) -> Degree { + self.n + } +} + +impl GLWEInfos for GLWEPlaintextLayout { + fn rank(&self) -> Rank { + Rank(0) + } +} + +pub struct GLWEPlaintext { + pub data: VecZnx, + pub base2k: Base2K, + pub k: TorusPrecision, +} + +impl SetGLWEInfos for GLWEPlaintext { + fn set_base2k(&mut self, base2k: Base2K) { + self.base2k = base2k + } + + fn set_k(&mut self, k: TorusPrecision) { + self.k = k + } +} + +impl LWEInfos for GLWEPlaintext { + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn size(&self) -> usize { + self.data.size() + } + + fn n(&self) -> Degree { + Degree(self.data.n() as u32) + } +} + +impl GLWEInfos for GLWEPlaintext { + fn rank(&self) -> Rank { + Rank(self.data.cols() as u32 - 1) + } +} + +impl fmt::Display for GLWEPlaintext { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "GLWEPlaintext: base2k={} k={}: {}", + self.base2k().0, + self.k().0, + self.data + ) + } +} + +impl GLWEPlaintext> { + pub fn alloc_from_infos(infos: &A) -> Self + where + A: GLWEInfos, + { + Self::alloc(infos.n(), infos.base2k(), infos.k()) + } + + pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision) -> Self { + GLWEPlaintext { + data: VecZnx::alloc(n.into(), 1, k.0.div_ceil(base2k.0) as usize), + base2k, + k, + } + } + + pub fn bytes_of_from_infos(infos: &A) -> usize + where + A: GLWEInfos, + { + Self::bytes_of(infos.n(), infos.base2k(), infos.k()) + } + + pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision) -> usize { + VecZnx::bytes_of(n.into(), 1, k.0.div_ceil(base2k.0) as usize) + } +} + +impl GLWEToRef for GLWEPlaintext { + fn to_ref(&self) -> GLWE<&[u8]> { + GLWE { + k: self.k, + base2k: self.base2k, + data: self.data.to_ref(), + } + } +} + +impl GLWEToMut for GLWEPlaintext { + fn to_mut(&mut self) -> GLWE<&mut [u8]> { + GLWE { + k: self.k, + base2k: self.base2k, + data: self.data.to_mut(), + } + } +} + +pub trait GLWEPlaintextToRef { + fn to_ref(&self) -> GLWEPlaintext<&[u8]>; +} + +impl GLWEPlaintextToRef for GLWEPlaintext { + fn to_ref(&self) -> GLWEPlaintext<&[u8]> { + GLWEPlaintext { + data: self.data.to_ref(), + base2k: self.base2k, + k: self.k, + } + } +} + +pub trait GLWEPlaintextToMut { + fn to_ref(&mut self) -> GLWEPlaintext<&mut [u8]>; +} + +impl GLWEPlaintextToMut for GLWEPlaintext { + fn to_ref(&mut self) -> GLWEPlaintext<&mut [u8]> { + GLWEPlaintext { + base2k: self.base2k, + k: self.k, + data: self.data.to_mut(), + } + } +} diff --git a/poulpy-core/src/layouts/glwe_pt.rs b/poulpy-core/src/layouts/glwe_pt.rs deleted file mode 100644 index b565055..0000000 --- a/poulpy-core/src/layouts/glwe_pt.rs +++ /dev/null @@ -1,202 +0,0 @@ -use std::fmt; - -use poulpy_hal::layouts::{Data, DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos}; - -use crate::layouts::{ - Base2K, BuildError, Degree, GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEInfos, GLWELayoutSet, LWEInfos, - Rank, TorusPrecision, -}; - -#[derive(PartialEq, Eq, Copy, Clone, Debug)] -pub struct GLWEPlaintextLayout { - pub n: Degree, - pub base2k: Base2K, - pub k: TorusPrecision, -} - -impl LWEInfos for GLWEPlaintextLayout { - fn base2k(&self) -> Base2K { - self.base2k - } - - fn k(&self) -> TorusPrecision { - self.k - } - - fn n(&self) -> Degree { - self.n - } -} - -impl GLWEInfos for GLWEPlaintextLayout { - fn rank(&self) -> Rank { - Rank(0) - } -} - -pub struct GLWEPlaintext { - pub data: VecZnx, - pub base2k: Base2K, - pub k: TorusPrecision, -} - -impl GLWELayoutSet for GLWEPlaintext { - fn set_basek(&mut self, base2k: Base2K) { - self.base2k = base2k - } - - fn set_k(&mut self, k: TorusPrecision) { - self.k = k - } -} - -impl LWEInfos for GLWEPlaintext { - fn base2k(&self) -> Base2K { - self.base2k - } - - fn k(&self) -> TorusPrecision { - self.k - } - - fn size(&self) -> usize { - self.data.size() - } - - fn n(&self) -> Degree { - Degree(self.data.n() as u32) - } -} - -impl GLWEInfos for GLWEPlaintext { - fn rank(&self) -> Rank { - Rank(self.data.cols() as u32 - 1) - } -} - -pub struct GLWEPlaintextBuilder { - data: Option>, - base2k: Option, - k: Option, -} - -impl GLWEPlaintext { - #[inline] - pub fn builder() -> GLWEPlaintextBuilder { - GLWEPlaintextBuilder { - data: None, - base2k: None, - k: None, - } - } -} - -impl GLWEPlaintextBuilder { - #[inline] - pub fn data(mut self, data: VecZnx) -> Self { - self.data = Some(data); - self - } - #[inline] - pub fn base2k(mut self, base2k: Base2K) -> Self { - self.base2k = Some(base2k); - self - } - #[inline] - pub fn k(mut self, k: TorusPrecision) -> Self { - self.k = Some(k); - self - } - - pub fn build(self) -> Result, BuildError> { - let data: VecZnx = self.data.ok_or(BuildError::MissingData)?; - let base2k: Base2K = self.base2k.ok_or(BuildError::MissingBase2K)?; - let k: TorusPrecision = self.k.ok_or(BuildError::MissingK)?; - - if base2k.0 == 0 { - return Err(BuildError::ZeroBase2K); - } - - if k.0 == 0 { - return Err(BuildError::ZeroTorusPrecision); - } - - if data.n() == 0 { - return Err(BuildError::ZeroDegree); - } - - if data.cols() != 1 { - return Err(BuildError::ZeroCols); - } - - if data.size() == 0 { - return Err(BuildError::ZeroLimbs); - } - - Ok(GLWEPlaintext { data, base2k, k }) - } -} - -impl fmt::Display for GLWEPlaintext { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "GLWEPlaintext: base2k={} k={}: {}", - self.base2k().0, - self.k().0, - self.data - ) - } -} - -impl GLWEPlaintext> { - pub fn alloc(infos: &A) -> Self - where - A: GLWEInfos, - { - Self::alloc_with(infos.n(), infos.base2k(), infos.k(), Rank(0)) - } - - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self { - debug_assert!(rank.0 == 0); - Self { - data: VecZnx::alloc(n.into(), (rank + 1).into(), k.0.div_ceil(base2k.0) as usize), - base2k, - k, - } - } - - pub fn alloc_bytes(infos: &A) -> usize - where - A: GLWEInfos, - { - Self::alloc_bytes_with(infos.n(), infos.base2k(), infos.k(), Rank(0)) - } - - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize { - debug_assert!(rank.0 == 0); - VecZnx::alloc_bytes(n.into(), (rank + 1).into(), k.0.div_ceil(base2k.0) as usize) - } -} - -impl GLWECiphertextToRef for GLWEPlaintext { - fn to_ref(&self) -> GLWECiphertext<&[u8]> { - GLWECiphertext::builder() - .data(self.data.to_ref()) - .k(self.k()) - .base2k(self.base2k()) - .build() - .unwrap() - } -} - -impl GLWECiphertextToMut for GLWEPlaintext { - fn to_mut(&mut self) -> GLWECiphertext<&mut [u8]> { - GLWECiphertext::builder() - .k(self.k()) - .base2k(self.base2k()) - .data(self.data.to_mut()) - .build() - .unwrap() - } -} diff --git a/poulpy-core/src/layouts/glwe_public_key.rs b/poulpy-core/src/layouts/glwe_public_key.rs new file mode 100644 index 0000000..f9c8c41 --- /dev/null +++ b/poulpy-core/src/layouts/glwe_public_key.rs @@ -0,0 +1,140 @@ +use poulpy_hal::layouts::{Data, DataMut, DataRef, ReaderFrom, VecZnx, WriterTo}; + +use crate::{ + GetDistribution, GetDistributionMut, + dist::Distribution, + layouts::{Base2K, Degree, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, Rank, TorusPrecision}, +}; + +#[derive(PartialEq, Eq)] +pub struct GLWEPublicKey { + pub(crate) key: GLWE, + pub(crate) dist: Distribution, +} + +impl GetDistributionMut for GLWEPublicKey { + fn dist_mut(&mut self) -> &mut Distribution { + &mut self.dist + } +} + +impl GetDistribution for GLWEPublicKey { + fn dist(&self) -> &Distribution { + &self.dist + } +} + +#[derive(PartialEq, Eq, Copy, Clone, Debug)] +pub struct GLWEPublicKeyLayout { + pub n: Degree, + pub base2k: Base2K, + pub k: TorusPrecision, + pub rank: Rank, +} + +impl LWEInfos for GLWEPublicKey { + fn base2k(&self) -> Base2K { + self.key.base2k() + } + + fn k(&self) -> TorusPrecision { + self.key.k() + } + + fn n(&self) -> Degree { + self.key.n() + } + + fn size(&self) -> usize { + self.key.size() + } +} + +impl GLWEInfos for GLWEPublicKey { + fn rank(&self) -> Rank { + self.key.rank() + } +} + +impl LWEInfos for GLWEPublicKeyLayout { + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn n(&self) -> Degree { + self.n + } + + fn size(&self) -> usize { + self.k.0.div_ceil(self.base2k.0) as usize + } +} + +impl GLWEInfos for GLWEPublicKeyLayout { + fn rank(&self) -> Rank { + self.rank + } +} + +impl GLWEPublicKey> { + pub fn alloc_from_infos(infos: &A) -> Self + where + A: GLWEInfos, + { + Self::alloc(infos.n(), infos.base2k(), infos.k(), infos.rank()) + } + + pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self { + GLWEPublicKey { + key: GLWE::alloc(n, base2k, k, rank), + dist: Distribution::NONE, + } + } + + pub fn bytes_of_from_infos(infos: &A) -> usize + where + A: GLWEInfos, + { + Self::bytes_of(infos.n(), infos.base2k(), infos.k(), infos.rank()) + } + + pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize { + VecZnx::bytes_of(n.into(), (rank + 1).into(), k.0.div_ceil(base2k.0) as usize) + } +} + +impl ReaderFrom for GLWEPublicKey { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + match Distribution::read_from(reader) { + Ok(dist) => self.dist = dist, + Err(e) => return Err(e), + } + self.key.read_from(reader) + } +} + +impl WriterTo for GLWEPublicKey { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + match self.dist.write_to(writer) { + Ok(()) => {} + Err(e) => return Err(e), + } + self.key.write_to(writer) + } +} + +impl GLWEToRef for GLWEPublicKey { + fn to_ref(&self) -> GLWE<&[u8]> { + self.key.to_ref() + } +} + +impl GLWEToMut for GLWEPublicKey { + fn to_mut(&mut self) -> GLWE<&mut [u8]> { + self.key.to_mut() + } +} diff --git a/poulpy-core/src/layouts/glwe_sk.rs b/poulpy-core/src/layouts/glwe_secret.rs similarity index 73% rename from poulpy-core/src/layouts/glwe_sk.rs rename to poulpy-core/src/layouts/glwe_secret.rs index 8870d35..b99a78e 100644 --- a/poulpy-core/src/layouts/glwe_sk.rs +++ b/poulpy-core/src/layouts/glwe_secret.rs @@ -1,9 +1,10 @@ use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, ReaderFrom, ScalarZnx, WriterTo, ZnxInfos, ZnxZero}, + layouts::{Data, DataMut, DataRef, ReaderFrom, ScalarZnx, ScalarZnxToMut, ScalarZnxToRef, WriterTo, ZnxInfos, ZnxZero}, source::Source, }; use crate::{ + GetDistribution, dist::Distribution, layouts::{Base2K, Degree, GLWEInfos, LWEInfos, Rank, TorusPrecision}, }; @@ -61,6 +62,12 @@ impl LWEInfos for GLWESecret { } } +impl GetDistribution for GLWESecret { + fn dist(&self) -> &Distribution { + &self.dist + } +} + impl GLWEInfos for GLWESecret { fn rank(&self) -> Rank { Rank(self.data.cols() as u32) @@ -68,29 +75,29 @@ impl GLWEInfos for GLWESecret { } impl GLWESecret> { - pub fn alloc(infos: &A) -> Self + pub fn alloc_from_infos(infos: &A) -> Self where A: GLWEInfos, { - Self::alloc_with(infos.n(), infos.rank()) + Self::alloc(infos.n(), infos.rank()) } - pub fn alloc_with(n: Degree, rank: Rank) -> Self { - Self { + pub fn alloc(n: Degree, rank: Rank) -> Self { + GLWESecret { data: ScalarZnx::alloc(n.into(), rank.into()), dist: Distribution::NONE, } } - pub fn alloc_bytes(infos: &A) -> usize + pub fn bytes_of_from_infos(infos: &A) -> usize where A: GLWEInfos, { - Self::alloc_bytes_with(infos.n(), infos.rank()) + Self::bytes_of(infos.n(), infos.rank()) } - pub fn alloc_bytes_with(n: Degree, rank: Rank) -> usize { - ScalarZnx::alloc_bytes(n.into(), rank.into()) + pub fn bytes_of(n: Degree, rank: Rank) -> usize { + ScalarZnx::bytes_of(n.into(), rank.into()) } } @@ -136,6 +143,32 @@ impl GLWESecret { } } +pub trait GLWESecretToMut { + fn to_mut(&mut self) -> GLWESecret<&mut [u8]>; +} + +impl GLWESecretToMut for GLWESecret { + fn to_mut(&mut self) -> GLWESecret<&mut [u8]> { + GLWESecret { + dist: self.dist, + data: self.data.to_mut(), + } + } +} + +pub trait GLWESecretToRef { + fn to_ref(&self) -> GLWESecret<&[u8]>; +} + +impl GLWESecretToRef for GLWESecret { + fn to_ref(&self) -> GLWESecret<&[u8]> { + GLWESecret { + data: self.data.to_ref(), + dist: self.dist, + } + } +} + impl ReaderFrom for GLWESecret { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { match Distribution::read_from(reader) { diff --git a/poulpy-core/src/layouts/glwe_switching_key.rs b/poulpy-core/src/layouts/glwe_switching_key.rs new file mode 100644 index 0000000..3f94c06 --- /dev/null +++ b/poulpy-core/src/layouts/glwe_switching_key.rs @@ -0,0 +1,255 @@ +use poulpy_hal::{ + layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo}, + source::Source, +}; + +use crate::layouts::{ + Base2K, Degree, Dnum, Dsize, GGLWE, GGLWEInfos, GGLWEToMut, GGLWEToRef, GLWE, GLWEInfos, LWEInfos, Rank, TorusPrecision, +}; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; + +use std::fmt; + +#[derive(PartialEq, Eq, Copy, Clone, Debug)] +pub struct GLWESwitchingKeyLayout { + pub n: Degree, + pub base2k: Base2K, + pub k: TorusPrecision, + pub rank_in: Rank, + pub rank_out: Rank, + pub dnum: Dnum, + pub dsize: Dsize, +} + +impl LWEInfos for GLWESwitchingKeyLayout { + fn n(&self) -> Degree { + self.n + } + + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } +} + +impl GLWEInfos for GLWESwitchingKeyLayout { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWEInfos for GLWESwitchingKeyLayout { + fn rank_in(&self) -> Rank { + self.rank_in + } + + fn rank_out(&self) -> Rank { + self.rank_out + } + + fn dsize(&self) -> Dsize { + self.dsize + } + + fn dnum(&self) -> Dnum { + self.dnum + } +} + +#[derive(PartialEq, Eq, Clone)] +pub struct GLWESwitchingKey { + pub(crate) key: GGLWE, + pub(crate) input_degree: Degree, // Degree of sk_in + pub(crate) output_degree: Degree, // Degree of sk_out +} + +pub trait GLWESwitchingKeyDegrees { + fn input_degree(&self) -> &Degree; + fn output_degree(&self) -> &Degree; +} + +impl GLWESwitchingKeyDegrees for GLWESwitchingKey { + fn output_degree(&self) -> &Degree { + &self.output_degree + } + + fn input_degree(&self) -> &Degree { + &self.input_degree + } +} + +pub trait GLWESwitchingKeyDegreesMut { + fn input_degree(&mut self) -> &mut Degree; + fn output_degree(&mut self) -> &mut Degree; +} + +impl GLWESwitchingKeyDegreesMut for GLWESwitchingKey { + fn output_degree(&mut self) -> &mut Degree { + &mut self.output_degree + } + + fn input_degree(&mut self) -> &mut Degree { + &mut self.input_degree + } +} + +impl LWEInfos for GLWESwitchingKey { + fn n(&self) -> Degree { + self.key.n() + } + + fn base2k(&self) -> Base2K { + self.key.base2k() + } + + fn k(&self) -> TorusPrecision { + self.key.k() + } + + fn size(&self) -> usize { + self.key.size() + } +} + +impl GLWEInfos for GLWESwitchingKey { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWEInfos for GLWESwitchingKey { + fn rank_in(&self) -> Rank { + self.key.rank_in() + } + + fn rank_out(&self) -> Rank { + self.key.rank_out() + } + + fn dsize(&self) -> Dsize { + self.key.dsize() + } + + fn dnum(&self) -> Dnum { + self.key.dnum() + } +} + +impl fmt::Debug for GLWESwitchingKey { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{self}") + } +} + +impl fmt::Display for GLWESwitchingKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "(GLWESwitchingKey: sk_in_n={} sk_out_n={}) {}", + self.input_degree, + self.output_degree, + self.key.data() + ) + } +} + +impl FillUniform for GLWESwitchingKey { + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + self.key.fill_uniform(log_bound, source); + } +} + +impl GLWESwitchingKey> { + pub fn alloc_from_infos(infos: &A) -> Self + where + A: GGLWEInfos, + { + Self::alloc( + infos.n(), + infos.base2k(), + infos.k(), + infos.rank_in(), + infos.rank_out(), + infos.dnum(), + infos.dsize(), + ) + } + + pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank_in: Rank, rank_out: Rank, dnum: Dnum, dsize: Dsize) -> Self { + GLWESwitchingKey { + key: GGLWE::alloc(n, base2k, k, rank_in, rank_out, dnum, dsize), + input_degree: Degree(0), + output_degree: Degree(0), + } + } + + pub fn bytes_of_from_infos(infos: &A) -> usize + where + A: GGLWEInfos, + { + Self::bytes_of( + infos.n(), + infos.base2k(), + infos.k(), + infos.rank_in(), + infos.rank_out(), + infos.dnum(), + infos.dsize(), + ) + } + + pub fn bytes_of( + n: Degree, + base2k: Base2K, + k: TorusPrecision, + rank_in: Rank, + rank_out: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> usize { + GGLWE::bytes_of(n, base2k, k, rank_in, rank_out, dnum, dsize) + } +} + +impl GGLWEToMut for GLWESwitchingKey { + fn to_mut(&mut self) -> GGLWE<&mut [u8]> { + self.key.to_mut() + } +} + +impl GGLWEToRef for GLWESwitchingKey { + fn to_ref(&self) -> GGLWE<&[u8]> { + self.key.to_ref() + } +} + +impl GLWESwitchingKey { + pub fn at(&self, row: usize, col: usize) -> GLWE<&[u8]> { + self.key.at(row, col) + } +} + +impl GLWESwitchingKey { + pub fn at_mut(&mut self, row: usize, col: usize) -> GLWE<&mut [u8]> { + self.key.at_mut(row, col) + } +} + +impl ReaderFrom for GLWESwitchingKey { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + self.input_degree = Degree(reader.read_u32::()?); + self.output_degree = Degree(reader.read_u32::()?); + self.key.read_from(reader) + } +} + +impl WriterTo for GLWESwitchingKey { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + writer.write_u32::(self.input_degree.into())?; + writer.write_u32::(self.output_degree.into())?; + self.key.write_to(writer) + } +} diff --git a/poulpy-core/src/layouts/gglwe_tsk.rs b/poulpy-core/src/layouts/glwe_tensor_key.rs similarity index 59% rename from poulpy-core/src/layouts/gglwe_tsk.rs rename to poulpy-core/src/layouts/glwe_tensor_key.rs index a949b7e..bc0100f 100644 --- a/poulpy-core/src/layouts/gglwe_tsk.rs +++ b/poulpy-core/src/layouts/glwe_tensor_key.rs @@ -3,13 +3,15 @@ use poulpy_hal::{ source::Source, }; -use crate::layouts::{Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWESwitchingKey, GLWEInfos, LWEInfos, Rank, TorusPrecision}; +use crate::layouts::{ + Base2K, Degree, Dnum, Dsize, GGLWE, GGLWEInfos, GGLWEToMut, GGLWEToRef, GLWEInfos, LWEInfos, Rank, TorusPrecision, +}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::fmt; #[derive(PartialEq, Eq, Copy, Clone, Debug)] -pub struct GGLWETensorKeyLayout { +pub struct GLWETensorKeyLayout { pub n: Degree, pub base2k: Base2K, pub k: TorusPrecision, @@ -19,11 +21,11 @@ pub struct GGLWETensorKeyLayout { } #[derive(PartialEq, Eq, Clone)] -pub struct GGLWETensorKey { - pub(crate) keys: Vec>, +pub struct GLWETensorKey { + pub(crate) keys: Vec>, } -impl LWEInfos for GGLWETensorKey { +impl LWEInfos for GLWETensorKey { fn n(&self) -> Degree { self.keys[0].n() } @@ -41,13 +43,13 @@ impl LWEInfos for GGLWETensorKey { } } -impl GLWEInfos for GGLWETensorKey { +impl GLWEInfos for GLWETensorKey { fn rank(&self) -> Rank { self.keys[0].rank_out() } } -impl GGLWEInfos for GGLWETensorKey { +impl GGLWEInfos for GLWETensorKey { fn rank_in(&self) -> Rank { self.rank_out() } @@ -65,7 +67,7 @@ impl GGLWEInfos for GGLWETensorKey { } } -impl LWEInfos for GGLWETensorKeyLayout { +impl LWEInfos for GLWETensorKeyLayout { fn n(&self) -> Degree { self.n } @@ -79,13 +81,13 @@ impl LWEInfos for GGLWETensorKeyLayout { } } -impl GLWEInfos for GGLWETensorKeyLayout { +impl GLWEInfos for GLWETensorKeyLayout { fn rank(&self) -> Rank { self.rank_out() } } -impl GGLWEInfos for GGLWETensorKeyLayout { +impl GGLWEInfos for GLWETensorKeyLayout { fn rank_in(&self) -> Rank { self.rank } @@ -103,21 +105,21 @@ impl GGLWEInfos for GGLWETensorKeyLayout { } } -impl fmt::Debug for GGLWETensorKey { +impl fmt::Debug for GLWETensorKey { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{self}") } } -impl FillUniform for GGLWETensorKey { +impl FillUniform for GLWETensorKey { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.keys .iter_mut() - .for_each(|key: &mut GGLWESwitchingKey| key.fill_uniform(log_bound, source)) + .for_each(|key: &mut GGLWE| key.fill_uniform(log_bound, source)) } } -impl fmt::Display for GGLWETensorKey { +impl fmt::Display for GLWETensorKey { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!(f, "(GLWETensorKey)",)?; for (i, key) in self.keys.iter().enumerate() { @@ -127,8 +129,8 @@ impl fmt::Display for GGLWETensorKey { } } -impl GGLWETensorKey> { - pub fn alloc(infos: &A) -> Self +impl GLWETensorKey> { + pub fn alloc_from_infos(infos: &A) -> Self where A: GGLWEInfos, { @@ -137,34 +139,26 @@ impl GGLWETensorKey> { infos.rank_out(), "rank_in != rank_out is not supported for GGLWETensorKey" ); - Self::alloc_with( + Self::alloc( infos.n(), infos.base2k(), infos.k(), - infos.rank_out(), + infos.rank(), infos.dnum(), infos.dsize(), ) } - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self { - let mut keys: Vec>> = Vec::new(); + pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self { let pairs: u32 = (((rank.0 + 1) * rank.0) >> 1).max(1); - (0..pairs).for_each(|_| { - keys.push(GGLWESwitchingKey::alloc_with( - n, - base2k, - k, - Rank(1), - rank, - dnum, - dsize, - )); - }); - Self { keys } + GLWETensorKey { + keys: (0..pairs) + .map(|_| GGLWE::alloc(n, base2k, k, Rank(1), rank, dnum, dsize)) + .collect(), + } } - pub fn alloc_bytes(infos: &A) -> usize + pub fn bytes_of_from_infos(infos: &A) -> usize where A: GGLWEInfos, { @@ -173,29 +167,25 @@ impl GGLWETensorKey> { infos.rank_out(), "rank_in != rank_out is not supported for GGLWETensorKey" ); - let rank_out: usize = infos.rank_out().into(); - let pairs: usize = (((rank_out + 1) * rank_out) >> 1).max(1); - pairs - * GGLWESwitchingKey::alloc_bytes_with( - infos.n(), - infos.base2k(), - infos.k(), - Rank(1), - infos.rank_out(), - infos.dnum(), - infos.dsize(), - ) + Self::bytes_of( + infos.n(), + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) } - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { + pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize; - pairs * GGLWESwitchingKey::alloc_bytes_with(n, base2k, k, Rank(1), rank, dnum, dsize) + pairs * GGLWE::bytes_of(n, base2k, k, Rank(1), rank, dnum, dsize) } } -impl GGLWETensorKey { - // Returns a mutable reference to GLWESwitchingKey_{s}(s[i] * s[j]) - pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GGLWESwitchingKey { +impl GLWETensorKey { + // Returns a mutable reference to GGLWE_{s}(s[i] * s[j]) + pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GGLWE { if i > j { std::mem::swap(&mut i, &mut j); }; @@ -204,9 +194,9 @@ impl GGLWETensorKey { } } -impl GGLWETensorKey { - // Returns a reference to GLWESwitchingKey_{s}(s[i] * s[j]) - pub fn at(&self, mut i: usize, mut j: usize) -> &GGLWESwitchingKey { +impl GLWETensorKey { + // Returns a reference to GGLWE_{s}(s[i] * s[j]) + pub fn at(&self, mut i: usize, mut j: usize) -> &GGLWE { if i > j { std::mem::swap(&mut i, &mut j); }; @@ -215,7 +205,7 @@ impl GGLWETensorKey { } } -impl ReaderFrom for GGLWETensorKey { +impl ReaderFrom for GLWETensorKey { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { let len: usize = reader.read_u64::()? as usize; if self.keys.len() != len { @@ -231,7 +221,7 @@ impl ReaderFrom for GGLWETensorKey { } } -impl WriterTo for GGLWETensorKey { +impl WriterTo for GLWETensorKey { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { writer.write_u64::(self.keys.len() as u64)?; for key in &self.keys { @@ -240,3 +230,33 @@ impl WriterTo for GGLWETensorKey { Ok(()) } } + +pub trait GLWETensorKeyToRef { + fn to_ref(&self) -> GLWETensorKey<&[u8]>; +} + +impl GLWETensorKeyToRef for GLWETensorKey +where + GGLWE: GGLWEToRef, +{ + fn to_ref(&self) -> GLWETensorKey<&[u8]> { + GLWETensorKey { + keys: self.keys.iter().map(|c| c.to_ref()).collect(), + } + } +} + +pub trait GLWETensorKeyToMut { + fn to_mut(&mut self) -> GLWETensorKey<&mut [u8]>; +} + +impl GLWETensorKeyToMut for GLWETensorKey +where + GGLWE: GGLWEToMut, +{ + fn to_mut(&mut self) -> GLWETensorKey<&mut [u8]> { + GLWETensorKey { + keys: self.keys.iter_mut().map(|c| c.to_mut()).collect(), + } + } +} diff --git a/poulpy-core/src/layouts/glwe_to_lwe_ksk.rs b/poulpy-core/src/layouts/glwe_to_lwe_switching_key.rs similarity index 54% rename from poulpy-core/src/layouts/glwe_to_lwe_ksk.rs rename to poulpy-core/src/layouts/glwe_to_lwe_switching_key.rs index f227c9c..bc3ee4b 100644 --- a/poulpy-core/src/layouts/glwe_to_lwe_ksk.rs +++ b/poulpy-core/src/layouts/glwe_to_lwe_switching_key.rs @@ -3,7 +3,10 @@ use poulpy_hal::{ source::Source, }; -use crate::layouts::{Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWESwitchingKey, GLWEInfos, LWEInfos, Rank, TorusPrecision}; +use crate::layouts::{ + Base2K, Degree, Dnum, Dsize, GGLWE, GGLWEInfos, GGLWEToMut, GGLWEToRef, GLWEInfos, GLWESwitchingKey, GLWESwitchingKeyDegrees, + GLWESwitchingKeyDegreesMut, LWEInfos, Rank, TorusPrecision, +}; use std::fmt; @@ -54,11 +57,11 @@ impl GGLWEInfos for GLWEToLWEKeyLayout { } } -/// A special [GLWESwitchingKey] required to for the conversion from [GLWECiphertext] to [LWECiphertext]. +/// A special [GLWESwitchingKey] required to for the conversion from [GLWE] to [LWE]. #[derive(PartialEq, Eq, Clone)] -pub struct GLWEToLWEKey(pub(crate) GGLWESwitchingKey); +pub struct GLWEToLWESwitchingKey(pub(crate) GLWESwitchingKey); -impl LWEInfos for GLWEToLWEKey { +impl LWEInfos for GLWEToLWESwitchingKey { fn base2k(&self) -> Base2K { self.0.base2k() } @@ -76,12 +79,12 @@ impl LWEInfos for GLWEToLWEKey { } } -impl GLWEInfos for GLWEToLWEKey { +impl GLWEInfos for GLWEToLWESwitchingKey { fn rank(&self) -> Rank { self.rank_out() } } -impl GGLWEInfos for GLWEToLWEKey { +impl GGLWEInfos for GLWEToLWESwitchingKey { fn rank_in(&self) -> Rank { self.0.rank_in() } @@ -99,56 +102,62 @@ impl GGLWEInfos for GLWEToLWEKey { } } -impl fmt::Debug for GLWEToLWEKey { +impl fmt::Debug for GLWEToLWESwitchingKey { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{self}") } } -impl FillUniform for GLWEToLWEKey { +impl FillUniform for GLWEToLWESwitchingKey { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.0.fill_uniform(log_bound, source); } } -impl fmt::Display for GLWEToLWEKey { +impl fmt::Display for GLWEToLWESwitchingKey { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "(GLWEToLWESwitchingKey) {}", self.0) } } -impl ReaderFrom for GLWEToLWEKey { +impl ReaderFrom for GLWEToLWESwitchingKey { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.0.read_from(reader) } } -impl WriterTo for GLWEToLWEKey { +impl WriterTo for GLWEToLWESwitchingKey { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { self.0.write_to(writer) } } -impl GLWEToLWEKey> { - pub fn alloc(infos: &A) -> Self +impl GLWEToLWESwitchingKey> { + pub fn alloc_from_infos(infos: &A) -> Self where A: GGLWEInfos, { - debug_assert_eq!( + assert_eq!( infos.rank_out().0, 1, "rank_out > 1 is not supported for GLWEToLWESwitchingKey" ); - debug_assert_eq!( + assert_eq!( infos.dsize().0, 1, "dsize > 1 is not supported for GLWEToLWESwitchingKey" ); - Self(GGLWESwitchingKey::alloc(infos)) + Self::alloc( + infos.n(), + infos.base2k(), + infos.k(), + infos.rank_in(), + infos.dnum(), + ) } - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> Self { - Self(GGLWESwitchingKey::alloc_with( + pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> Self { + GLWEToLWESwitchingKey(GLWESwitchingKey::alloc( n, base2k, k, @@ -159,24 +168,62 @@ impl GLWEToLWEKey> { )) } - pub fn alloc_bytes(infos: &A) -> usize + pub fn bytes_of_from_infos(infos: &A) -> usize where A: GGLWEInfos, { - debug_assert_eq!( + assert_eq!( infos.rank_out().0, 1, "rank_out > 1 is not supported for GLWEToLWESwitchingKey" ); - debug_assert_eq!( + assert_eq!( infos.dsize().0, 1, "dsize > 1 is not supported for GLWEToLWESwitchingKey" ); - GGLWESwitchingKey::alloc_bytes(infos) + Self::bytes_of( + infos.n(), + infos.base2k(), + infos.k(), + infos.rank_in(), + infos.dnum(), + ) } - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> usize { - GGLWESwitchingKey::alloc_bytes_with(n, base2k, k, rank_in, Rank(1), dnum, Dsize(1)) + pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> usize { + GLWESwitchingKey::bytes_of(n, base2k, k, rank_in, Rank(1), dnum, Dsize(1)) + } +} + +impl GGLWEToRef for GLWEToLWESwitchingKey { + fn to_ref(&self) -> GGLWE<&[u8]> { + self.0.to_ref() + } +} + +impl GGLWEToMut for GLWEToLWESwitchingKey { + fn to_mut(&mut self) -> GGLWE<&mut [u8]> { + self.0.to_mut() + } +} + +impl GLWESwitchingKeyDegreesMut for GLWEToLWESwitchingKey { + fn input_degree(&mut self) -> &mut Degree { + &mut self.0.input_degree + } + + fn output_degree(&mut self) -> &mut Degree { + &mut self.0.output_degree + } +} + +impl GLWESwitchingKeyDegrees for GLWEToLWESwitchingKey { + fn input_degree(&self) -> &Degree { + &self.0.input_degree + } + + fn output_degree(&self) -> &Degree { + &self.0.output_degree } } diff --git a/poulpy-core/src/layouts/lwe.rs b/poulpy-core/src/layouts/lwe.rs new file mode 100644 index 0000000..6f8cdce --- /dev/null +++ b/poulpy-core/src/layouts/lwe.rs @@ -0,0 +1,202 @@ +use std::fmt; + +use poulpy_hal::{ + layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo, Zn, ZnToMut, ZnToRef, ZnxInfos}, + source::Source, +}; + +use crate::layouts::{Base2K, Degree, TorusPrecision}; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; + +pub trait LWEInfos { + fn n(&self) -> Degree; + fn log_n(&self) -> usize { + (u64::BITS - (self.n().as_usize() as u64 - 1).leading_zeros()) as usize + } + fn k(&self) -> TorusPrecision; + fn max_k(&self) -> TorusPrecision { + TorusPrecision(self.k().0 * self.size() as u32) + } + fn base2k(&self) -> Base2K; + fn size(&self) -> usize { + self.k().0.div_ceil(self.base2k().0) as usize + } + fn lwe_layout(&self) -> LWELayout { + LWELayout { + n: self.n(), + k: self.k(), + base2k: self.base2k(), + } + } +} + +pub trait SetLWEInfos { + fn set_k(&mut self, k: TorusPrecision); + fn set_base2k(&mut self, base2k: Base2K); +} + +#[derive(PartialEq, Eq, Copy, Clone, Debug)] +pub struct LWELayout { + pub n: Degree, + pub k: TorusPrecision, + pub base2k: Base2K, +} + +impl LWEInfos for LWELayout { + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn n(&self) -> Degree { + self.n + } +} +#[derive(PartialEq, Eq, Clone)] +pub struct LWE { + pub(crate) data: Zn, + pub(crate) k: TorusPrecision, + pub(crate) base2k: Base2K, +} + +impl LWEInfos for LWE { + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + fn n(&self) -> Degree { + Degree(self.data.n() as u32 - 1) + } + + fn size(&self) -> usize { + self.data.size() + } +} + +impl SetLWEInfos for LWE { + fn set_base2k(&mut self, base2k: Base2K) { + self.base2k = base2k + } + + fn set_k(&mut self, k: TorusPrecision) { + self.k = k + } +} + +impl LWE { + pub fn data(&self) -> &Zn { + &self.data + } +} + +impl LWE { + pub fn data_mut(&mut self) -> &Zn { + &mut self.data + } +} + +impl fmt::Debug for LWE { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{self}") + } +} + +impl fmt::Display for LWE { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "LWE: base2k={} k={}: {}", + self.base2k().0, + self.k().0, + self.data + ) + } +} + +impl FillUniform for LWE +where + Zn: FillUniform, +{ + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + self.data.fill_uniform(log_bound, source); + } +} + +impl LWE> { + pub fn alloc_from_infos(infos: &A) -> Self + where + A: LWEInfos, + { + Self::alloc(infos.n(), infos.base2k(), infos.k()) + } + + pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision) -> Self { + LWE { + data: Zn::alloc((n + 1).into(), 1, k.0.div_ceil(base2k.0) as usize), + k, + base2k, + } + } + + pub fn bytes_of_from_infos(infos: &A) -> usize + where + A: LWEInfos, + { + Self::bytes_of(infos.n(), infos.base2k(), infos.k()) + } + + pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision) -> usize { + Zn::bytes_of((n + 1).into(), 1, k.0.div_ceil(base2k.0) as usize) + } +} + +pub trait LWEToRef { + fn to_ref(&self) -> LWE<&[u8]>; +} + +impl LWEToRef for LWE { + fn to_ref(&self) -> LWE<&[u8]> { + LWE { + k: self.k, + base2k: self.base2k, + data: self.data.to_ref(), + } + } +} + +pub trait LWEToMut { + #[allow(dead_code)] + fn to_mut(&mut self) -> LWE<&mut [u8]>; +} + +impl LWEToMut for LWE { + fn to_mut(&mut self) -> LWE<&mut [u8]> { + LWE { + k: self.k, + base2k: self.base2k, + data: self.data.to_mut(), + } + } +} + +impl ReaderFrom for LWE { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + self.k = TorusPrecision(reader.read_u32::()?); + self.base2k = Base2K(reader.read_u32::()?); + self.data.read_from(reader) + } +} + +impl WriterTo for LWE { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + writer.write_u32::(self.k.into())?; + writer.write_u32::(self.base2k.into())?; + self.data.write_to(writer) + } +} diff --git a/poulpy-core/src/layouts/lwe_ct.rs b/poulpy-core/src/layouts/lwe_ct.rs deleted file mode 100644 index 1560ea4..0000000 --- a/poulpy-core/src/layouts/lwe_ct.rs +++ /dev/null @@ -1,263 +0,0 @@ -use std::fmt; - -use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo, Zn, ZnToMut, ZnToRef, ZnxInfos}, - source::Source, -}; - -use crate::layouts::{Base2K, BuildError, Degree, TorusPrecision}; -use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; - -pub trait LWEInfos { - fn n(&self) -> Degree; - fn k(&self) -> TorusPrecision; - fn max_k(&self) -> TorusPrecision { - TorusPrecision(self.k().0 * self.size() as u32) - } - fn base2k(&self) -> Base2K; - fn size(&self) -> usize { - self.k().0.div_ceil(self.base2k().0) as usize - } - fn lwe_layout(&self) -> LWECiphertextLayout { - LWECiphertextLayout { - n: self.n(), - k: self.k(), - base2k: self.base2k(), - } - } -} - -#[derive(PartialEq, Eq, Copy, Clone, Debug)] -pub struct LWECiphertextLayout { - pub n: Degree, - pub k: TorusPrecision, - pub base2k: Base2K, -} - -impl LWEInfos for LWECiphertextLayout { - fn base2k(&self) -> Base2K { - self.base2k - } - - fn k(&self) -> TorusPrecision { - self.k - } - - fn n(&self) -> Degree { - self.n - } -} - -#[derive(PartialEq, Eq, Clone)] -pub struct LWECiphertext { - pub(crate) data: Zn, - pub(crate) k: TorusPrecision, - pub(crate) base2k: Base2K, -} - -impl LWEInfos for LWECiphertext { - fn base2k(&self) -> Base2K { - self.base2k - } - - fn k(&self) -> TorusPrecision { - self.k - } - fn n(&self) -> Degree { - Degree(self.data.n() as u32 - 1) - } - - fn size(&self) -> usize { - self.data.size() - } -} - -impl LWECiphertext { - pub fn data(&self) -> &Zn { - &self.data - } -} - -impl LWECiphertext { - pub fn data_mut(&mut self) -> &Zn { - &mut self.data - } -} - -impl fmt::Debug for LWECiphertext { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{self}") - } -} - -impl fmt::Display for LWECiphertext { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "LWECiphertext: base2k={} k={}: {}", - self.base2k().0, - self.k().0, - self.data - ) - } -} - -impl FillUniform for LWECiphertext -where - Zn: FillUniform, -{ - fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { - self.data.fill_uniform(log_bound, source); - } -} - -impl LWECiphertext> { - pub fn alloc(infos: &A) -> Self - where - A: LWEInfos, - { - Self::alloc_with(infos.n(), infos.base2k(), infos.k()) - } - - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision) -> Self { - Self { - data: Zn::alloc((n + 1).into(), 1, k.0.div_ceil(base2k.0) as usize), - k, - base2k, - } - } - - pub fn alloc_bytes(infos: &A) -> usize - where - A: LWEInfos, - { - Self::alloc_bytes_with(infos.n(), infos.base2k(), infos.k()) - } - - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision) -> usize { - Zn::alloc_bytes((n + 1).into(), 1, k.0.div_ceil(base2k.0) as usize) - } -} - -impl LWECiphertextBuilder> { - #[inline] - pub fn layout(mut self, layout: A) -> Self - where - A: LWEInfos, - { - self.data = Some(Zn::alloc((layout.n() + 1).into(), 1, layout.size())); - self.base2k = Some(layout.base2k()); - self.k = Some(layout.k()); - self - } -} - -pub struct LWECiphertextBuilder { - data: Option>, - base2k: Option, - k: Option, -} - -impl LWECiphertext { - #[inline] - pub fn builder() -> LWECiphertextBuilder { - LWECiphertextBuilder { - data: None, - base2k: None, - k: None, - } - } -} - -impl LWECiphertextBuilder { - #[inline] - pub fn data(mut self, data: Zn) -> Self { - self.data = Some(data); - self - } - #[inline] - pub fn base2k(mut self, base2k: Base2K) -> Self { - self.base2k = Some(base2k); - self - } - #[inline] - pub fn k(mut self, k: TorusPrecision) -> Self { - self.k = Some(k); - self - } - - pub fn build(self) -> Result, BuildError> { - let data: Zn = self.data.ok_or(BuildError::MissingData)?; - let base2k: Base2K = self.base2k.ok_or(BuildError::MissingBase2K)?; - let k: TorusPrecision = self.k.ok_or(BuildError::MissingK)?; - - if base2k.0 == 0 { - return Err(BuildError::ZeroBase2K); - } - - if k.0 == 0 { - return Err(BuildError::ZeroTorusPrecision); - } - - if data.n() == 0 { - return Err(BuildError::ZeroDegree); - } - - if data.cols() == 0 { - return Err(BuildError::ZeroCols); - } - - if data.size() == 0 { - return Err(BuildError::ZeroLimbs); - } - - Ok(LWECiphertext { data, base2k, k }) - } -} - -pub trait LWECiphertextToRef { - fn to_ref(&self) -> LWECiphertext<&[u8]>; -} - -impl LWECiphertextToRef for LWECiphertext { - fn to_ref(&self) -> LWECiphertext<&[u8]> { - LWECiphertext::builder() - .base2k(self.base2k()) - .k(self.k()) - .data(self.data.to_ref()) - .build() - .unwrap() - } -} - -pub trait LWECiphertextToMut { - #[allow(dead_code)] - fn to_mut(&mut self) -> LWECiphertext<&mut [u8]>; -} - -impl LWECiphertextToMut for LWECiphertext { - fn to_mut(&mut self) -> LWECiphertext<&mut [u8]> { - LWECiphertext::builder() - .base2k(self.base2k()) - .k(self.k()) - .data(self.data.to_mut()) - .build() - .unwrap() - } -} - -impl ReaderFrom for LWECiphertext { - fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { - self.k = TorusPrecision(reader.read_u32::()?); - self.base2k = Base2K(reader.read_u32::()?); - self.data.read_from(reader) - } -} - -impl WriterTo for LWECiphertext { - fn write_to(&self, writer: &mut W) -> std::io::Result<()> { - writer.write_u32::(self.k.into())?; - writer.write_u32::(self.base2k.into())?; - self.data.write_to(writer) - } -} diff --git a/poulpy-core/src/layouts/lwe_pt.rs b/poulpy-core/src/layouts/lwe_plaintext.rs similarity index 92% rename from poulpy-core/src/layouts/lwe_pt.rs rename to poulpy-core/src/layouts/lwe_plaintext.rs index e739722..966ceb5 100644 --- a/poulpy-core/src/layouts/lwe_pt.rs +++ b/poulpy-core/src/layouts/lwe_plaintext.rs @@ -53,15 +53,15 @@ impl LWEInfos for LWEPlaintext { } impl LWEPlaintext> { - pub fn alloc(infos: &A) -> Self + pub fn alloc_from_infos(infos: &A) -> Self where A: LWEInfos, { - Self::alloc_with(infos.base2k(), infos.k()) + Self::alloc(infos.base2k(), infos.k()) } - pub fn alloc_with(base2k: Base2K, k: TorusPrecision) -> Self { - Self { + pub fn alloc(base2k: Base2K, k: TorusPrecision) -> Self { + LWEPlaintext { data: Zn::alloc(1, 1, k.0.div_ceil(base2k.0) as usize), k, base2k, diff --git a/poulpy-core/src/layouts/lwe_sk.rs b/poulpy-core/src/layouts/lwe_secret.rs similarity index 70% rename from poulpy-core/src/layouts/lwe_sk.rs rename to poulpy-core/src/layouts/lwe_secret.rs index a5b7d4e..00a7849 100644 --- a/poulpy-core/src/layouts/lwe_sk.rs +++ b/poulpy-core/src/layouts/lwe_secret.rs @@ -1,9 +1,10 @@ use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, ScalarZnx, ZnxInfos, ZnxView, ZnxZero}, + layouts::{Data, DataMut, DataRef, ScalarZnx, ScalarZnxToMut, ScalarZnxToRef, ZnxInfos, ZnxView, ZnxZero}, source::Source, }; use crate::{ + GetDistribution, dist::Distribution, layouts::{Base2K, Degree, LWEInfos, TorusPrecision}, }; @@ -15,13 +16,19 @@ pub struct LWESecret { impl LWESecret> { pub fn alloc(n: Degree) -> Self { - Self { + LWESecret { data: ScalarZnx::alloc(n.into(), 1), dist: Distribution::NONE, } } } +impl GetDistribution for LWESecret { + fn dist(&self) -> &Distribution { + &self.dist + } +} + impl LWESecret { pub fn raw(&self) -> &[i64] { self.data.at(0, 0) @@ -84,3 +91,29 @@ impl LWESecret { self.dist = Distribution::ZERO; } } + +pub trait LWESecretToRef { + fn to_ref(&self) -> LWESecret<&[u8]>; +} + +impl LWESecretToRef for LWESecret { + fn to_ref(&self) -> LWESecret<&[u8]> { + LWESecret { + dist: self.dist, + data: self.data.to_ref(), + } + } +} + +pub trait LWESecretToMut { + fn to_mut(&mut self) -> LWESecret<&mut [u8]>; +} + +impl LWESecretToMut for LWESecret { + fn to_mut(&mut self) -> LWESecret<&mut [u8]> { + LWESecret { + dist: self.dist, + data: self.data.to_mut(), + } + } +} diff --git a/poulpy-core/src/layouts/lwe_ksk.rs b/poulpy-core/src/layouts/lwe_switching_key.rs similarity index 67% rename from poulpy-core/src/layouts/lwe_ksk.rs rename to poulpy-core/src/layouts/lwe_switching_key.rs index 314322c..2ae6032 100644 --- a/poulpy-core/src/layouts/lwe_ksk.rs +++ b/poulpy-core/src/layouts/lwe_switching_key.rs @@ -5,7 +5,10 @@ use poulpy_hal::{ source::Source, }; -use crate::layouts::{Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWESwitchingKey, GLWEInfos, LWEInfos, Rank, TorusPrecision}; +use crate::layouts::{ + Base2K, Degree, Dnum, Dsize, GGLWE, GGLWEInfos, GGLWEToMut, GGLWEToRef, GLWEInfos, GLWESwitchingKey, GLWESwitchingKeyDegrees, + GLWESwitchingKeyDegreesMut, LWEInfos, Rank, TorusPrecision, +}; #[derive(PartialEq, Eq, Copy, Clone, Debug)] pub struct LWESwitchingKeyLayout { @@ -54,7 +57,7 @@ impl GGLWEInfos for LWESwitchingKeyLayout { } #[derive(PartialEq, Eq, Clone)] -pub struct LWESwitchingKey(pub(crate) GGLWESwitchingKey); +pub struct LWESwitchingKey(pub(crate) GLWESwitchingKey); impl LWEInfos for LWESwitchingKey { fn base2k(&self) -> Base2K { @@ -99,30 +102,30 @@ impl GGLWEInfos for LWESwitchingKey { } impl LWESwitchingKey> { - pub fn alloc(infos: &A) -> Self + pub fn alloc_from_infos(infos: &A) -> Self where A: GGLWEInfos, { - debug_assert_eq!( + assert_eq!( infos.dsize().0, 1, "dsize > 1 is not supported for LWESwitchingKey" ); - debug_assert_eq!( + assert_eq!( infos.rank_in().0, 1, "rank_in > 1 is not supported for LWESwitchingKey" ); - debug_assert_eq!( + assert_eq!( infos.rank_out().0, 1, "rank_out > 1 is not supported for LWESwitchingKey" ); - Self(GGLWESwitchingKey::alloc(infos)) + Self::alloc(infos.n(), infos.base2k(), infos.k(), infos.dnum()) } - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> Self { - Self(GGLWESwitchingKey::alloc_with( + pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> Self { + LWESwitchingKey(GLWESwitchingKey::alloc( n, base2k, k, @@ -133,30 +136,30 @@ impl LWESwitchingKey> { )) } - pub fn alloc_bytes(infos: &A) -> usize + pub fn bytes_of_from_infos(infos: &A) -> usize where A: GGLWEInfos, { - debug_assert_eq!( + assert_eq!( infos.dsize().0, 1, "dsize > 1 is not supported for LWESwitchingKey" ); - debug_assert_eq!( + assert_eq!( infos.rank_in().0, 1, "rank_in > 1 is not supported for LWESwitchingKey" ); - debug_assert_eq!( + assert_eq!( infos.rank_out().0, 1, "rank_out > 1 is not supported for LWESwitchingKey" ); - GGLWESwitchingKey::alloc_bytes(infos) + Self::bytes_of(infos.n(), infos.base2k(), infos.k(), infos.dnum()) } - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> usize { - GGLWESwitchingKey::alloc_bytes_with(n, base2k, k, Rank(1), Rank(1), dnum, Dsize(1)) + pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> usize { + GLWESwitchingKey::bytes_of(n, base2k, k, Rank(1), Rank(1), dnum, Dsize(1)) } } @@ -189,3 +192,35 @@ impl WriterTo for LWESwitchingKey { self.0.write_to(writer) } } + +impl GGLWEToRef for LWESwitchingKey { + fn to_ref(&self) -> GGLWE<&[u8]> { + self.0.to_ref() + } +} + +impl GGLWEToMut for LWESwitchingKey { + fn to_mut(&mut self) -> GGLWE<&mut [u8]> { + self.0.to_mut() + } +} + +impl GLWESwitchingKeyDegreesMut for LWESwitchingKey { + fn input_degree(&mut self) -> &mut Degree { + &mut self.0.input_degree + } + + fn output_degree(&mut self) -> &mut Degree { + &mut self.0.output_degree + } +} + +impl GLWESwitchingKeyDegrees for LWESwitchingKey { + fn input_degree(&self) -> &Degree { + &self.0.input_degree + } + + fn output_degree(&self) -> &Degree { + &self.0.output_degree + } +} diff --git a/poulpy-core/src/layouts/lwe_to_glwe_ksk.rs b/poulpy-core/src/layouts/lwe_to_glwe_switching_key.rs similarity index 63% rename from poulpy-core/src/layouts/lwe_to_glwe_ksk.rs rename to poulpy-core/src/layouts/lwe_to_glwe_switching_key.rs index b3ba74b..caa676d 100644 --- a/poulpy-core/src/layouts/lwe_to_glwe_ksk.rs +++ b/poulpy-core/src/layouts/lwe_to_glwe_switching_key.rs @@ -5,7 +5,10 @@ use poulpy_hal::{ source::Source, }; -use crate::layouts::{Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWESwitchingKey, GLWEInfos, LWEInfos, Rank, TorusPrecision}; +use crate::layouts::{ + Base2K, Degree, Dnum, Dsize, GGLWE, GGLWEInfos, GGLWEToMut, GGLWEToRef, GLWEInfos, GLWESwitchingKey, GLWESwitchingKeyDegrees, + GLWESwitchingKeyDegreesMut, LWEInfos, Rank, TorusPrecision, +}; #[derive(PartialEq, Eq, Copy, Clone, Debug)] pub struct LWEToGLWESwitchingKeyLayout { @@ -55,7 +58,7 @@ impl GGLWEInfos for LWEToGLWESwitchingKeyLayout { } #[derive(PartialEq, Eq, Clone)] -pub struct LWEToGLWESwitchingKey(pub(crate) GGLWESwitchingKey); +pub struct LWEToGLWESwitchingKey(pub(crate) GLWESwitchingKey); impl LWEInfos for LWEToGLWESwitchingKey { fn base2k(&self) -> Base2K { @@ -129,25 +132,32 @@ impl WriterTo for LWEToGLWESwitchingKey { } impl LWEToGLWESwitchingKey> { - pub fn alloc(infos: &A) -> Self + pub fn alloc_from_infos(infos: &A) -> Self where A: GGLWEInfos, { - debug_assert_eq!( + assert_eq!( infos.rank_in().0, 1, "rank_in > 1 is not supported for LWEToGLWESwitchingKey" ); - debug_assert_eq!( + assert_eq!( infos.dsize().0, 1, "dsize > 1 is not supported for LWEToGLWESwitchingKey" ); - Self(GGLWESwitchingKey::alloc(infos)) + + Self::alloc( + infos.n(), + infos.base2k(), + infos.k(), + infos.rank_out(), + infos.dnum(), + ) } - pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self { - Self(GGLWESwitchingKey::alloc_with( + pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self { + LWEToGLWESwitchingKey(GLWESwitchingKey::alloc( n, base2k, k, @@ -158,24 +168,62 @@ impl LWEToGLWESwitchingKey> { )) } - pub fn alloc_bytes(infos: &A) -> usize + pub fn bytes_of_from_infos(infos: &A) -> usize where A: GGLWEInfos, { - debug_assert_eq!( + assert_eq!( infos.rank_in().0, 1, "rank_in > 1 is not supported for LWEToGLWESwitchingKey" ); - debug_assert_eq!( + assert_eq!( infos.dsize().0, 1, "dsize > 1 is not supported for LWEToGLWESwitchingKey" ); - GGLWESwitchingKey::alloc_bytes(infos) + Self::bytes_of( + infos.n(), + infos.base2k(), + infos.k(), + infos.rank_out(), + infos.dnum(), + ) } - pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, dnum: Dnum, rank_out: Rank) -> usize { - GGLWESwitchingKey::alloc_bytes_with(n, base2k, k, Rank(1), rank_out, dnum, Dsize(1)) + pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> usize { + GLWESwitchingKey::bytes_of(n, base2k, k, Rank(1), rank_out, dnum, Dsize(1)) + } +} + +impl GGLWEToRef for LWEToGLWESwitchingKey { + fn to_ref(&self) -> GGLWE<&[u8]> { + self.0.to_ref() + } +} + +impl GGLWEToMut for LWEToGLWESwitchingKey { + fn to_mut(&mut self) -> GGLWE<&mut [u8]> { + self.0.to_mut() + } +} + +impl GLWESwitchingKeyDegreesMut for LWEToGLWESwitchingKey { + fn input_degree(&mut self) -> &mut Degree { + &mut self.0.input_degree + } + + fn output_degree(&mut self) -> &mut Degree { + &mut self.0.output_degree + } +} + +impl GLWESwitchingKeyDegrees for LWEToGLWESwitchingKey { + fn input_degree(&self) -> &Degree { + &self.0.input_degree + } + + fn output_degree(&self) -> &Degree { + &self.0.output_degree } } diff --git a/poulpy-core/src/layouts/mod.rs b/poulpy-core/src/layouts/mod.rs index 2b13751..1168ac8 100644 --- a/poulpy-core/src/layouts/mod.rs +++ b/poulpy-core/src/layouts/mod.rs @@ -1,52 +1,50 @@ -mod gglwe_atk; -mod gglwe_ct; -mod gglwe_ksk; -mod gglwe_tsk; -mod ggsw_ct; -mod glwe_ct; -mod glwe_pk; -mod glwe_pt; -mod glwe_sk; -mod glwe_to_lwe_ksk; -mod lwe_ct; -mod lwe_ksk; -mod lwe_pt; -mod lwe_sk; -mod lwe_to_glwe_ksk; +mod gglwe; +mod ggsw; +mod glwe; +mod glwe_automorphism_key; +mod glwe_plaintext; +mod glwe_public_key; +mod glwe_secret; +mod glwe_switching_key; +mod glwe_tensor_key; +mod glwe_to_lwe_switching_key; +mod lwe; +mod lwe_plaintext; +mod lwe_secret; +mod lwe_switching_key; +mod lwe_to_glwe_switching_key; pub mod compressed; pub mod prepared; -pub use gglwe_atk::*; -pub use gglwe_ct::*; -pub use gglwe_ksk::*; -pub use gglwe_tsk::*; -pub use ggsw_ct::*; -pub use glwe_ct::*; -pub use glwe_pk::*; -pub use glwe_pt::*; -pub use glwe_sk::*; -pub use glwe_to_lwe_ksk::*; -pub use lwe_ct::*; -pub use lwe_ksk::*; -pub use lwe_pt::*; -pub use lwe_sk::*; -pub use lwe_to_glwe_ksk::*; +pub use compressed::*; +pub use gglwe::*; +pub use ggsw::*; +pub use glwe::*; +pub use glwe_automorphism_key::*; +pub use glwe_plaintext::*; +pub use glwe_public_key::*; +pub use glwe_secret::*; +pub use glwe_switching_key::*; +pub use glwe_tensor_key::*; +pub use glwe_to_lwe_switching_key::*; +pub use lwe::*; +pub use lwe_plaintext::*; +pub use lwe_secret::*; +pub use lwe_switching_key::*; +pub use lwe_to_glwe_switching_key::*; +pub use prepared::*; -#[derive(Debug)] -pub enum BuildError { - MissingData, - MissingBase2K, - MissingK, - MissingDigits, - ZeroDegree, - NonPowerOfTwoDegree, - ZeroBase2K, - ZeroTorusPrecision, - ZeroCols, - ZeroLimbs, - ZeroRank, - ZeroDigits, +use poulpy_hal::layouts::{Backend, Module}; + +pub trait GetDegree { + fn ring_degree(&self) -> Degree; +} + +impl GetDegree for Module { + fn ring_degree(&self) -> Degree { + Self::n(self).into() + } } /// Newtype over `u32` with arithmetic and comparisons against same type and `u32`. diff --git a/poulpy-core/src/layouts/prepared/gglwe.rs b/poulpy-core/src/layouts/prepared/gglwe.rs new file mode 100644 index 0000000..792adbf --- /dev/null +++ b/poulpy-core/src/layouts/prepared/gglwe.rs @@ -0,0 +1,281 @@ +use poulpy_hal::{ + api::{VmpPMatAlloc, VmpPMatBytesOf, VmpPrepare, VmpPrepareTmpBytes}, + layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat, VmpPMatToMut, VmpPMatToRef, ZnxInfos}, +}; + +use crate::layouts::{ + Base2K, Degree, Dnum, Dsize, GGLWE, GGLWEInfos, GGLWEToRef, GLWEInfos, GetDegree, LWEInfos, Rank, TorusPrecision, +}; + +#[derive(PartialEq, Eq)] +pub struct GGLWEPrepared { + pub(crate) data: VmpPMat, + pub(crate) k: TorusPrecision, + pub(crate) base2k: Base2K, + pub(crate) dsize: Dsize, +} + +impl LWEInfos for GGLWEPrepared { + fn n(&self) -> Degree { + Degree(self.data.n() as u32) + } + + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn size(&self) -> usize { + self.data.size() + } +} + +impl GLWEInfos for GGLWEPrepared { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWEInfos for GGLWEPrepared { + fn rank_in(&self) -> Rank { + Rank(self.data.cols_in() as u32) + } + + fn rank_out(&self) -> Rank { + Rank(self.data.cols_out() as u32 - 1) + } + + fn dsize(&self) -> Dsize { + self.dsize + } + + fn dnum(&self) -> Dnum { + Dnum(self.data.rows() as u32) + } +} + +pub trait GGLWEPreparedFactory +where + Self: GetDegree + VmpPMatAlloc + VmpPMatBytesOf + VmpPrepare + VmpPrepareTmpBytes, +{ + fn alloc_gglwe_prepared( + &self, + base2k: Base2K, + k: TorusPrecision, + rank_in: Rank, + rank_out: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> GGLWEPrepared, BE> { + let size: usize = k.0.div_ceil(base2k.0) as usize; + debug_assert!( + size as u32 > dsize.0, + "invalid gglwe: ceil(k/base2k): {size} <= dsize: {}", + dsize.0 + ); + + assert!( + dnum.0 * dsize.0 <= size as u32, + "invalid gglwe: dnum: {} * dsize:{} > ceil(k/base2k): {size}", + dnum.0, + dsize.0, + ); + + GGLWEPrepared { + data: self.vmp_pmat_alloc(dnum.into(), rank_in.into(), (rank_out + 1).into(), size), + k, + base2k, + dsize, + } + } + + fn alloc_gglwe_prepared_from_infos(&self, infos: &A) -> GGLWEPrepared, BE> + where + A: GGLWEInfos, + { + assert_eq!(self.ring_degree(), infos.n()); + self.alloc_gglwe_prepared( + infos.base2k(), + infos.k(), + infos.rank_in(), + infos.rank_out(), + infos.dnum(), + infos.dsize(), + ) + } + + fn bytes_of_gglwe_prepared( + &self, + base2k: Base2K, + k: TorusPrecision, + rank_in: Rank, + rank_out: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> usize { + let size: usize = k.0.div_ceil(base2k.0) as usize; + debug_assert!( + size as u32 > dsize.0, + "invalid gglwe: ceil(k/base2k): {size} <= dsize: {}", + dsize.0 + ); + + assert!( + dnum.0 * dsize.0 <= size as u32, + "invalid gglwe: dnum: {} * dsize:{} > ceil(k/base2k): {size}", + dnum.0, + dsize.0, + ); + + self.bytes_of_vmp_pmat(dnum.into(), rank_in.into(), (rank_out + 1).into(), size) + } + + fn bytes_of_gglwe_prepared_from_infos(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + assert_eq!(self.ring_degree(), infos.n()); + self.bytes_of_gglwe_prepared( + infos.base2k(), + infos.k(), + infos.rank_in(), + infos.rank_out(), + infos.dnum(), + infos.dsize(), + ) + } + + fn prepare_gglwe_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + self.vmp_prepare_tmp_bytes( + infos.dnum().into(), + infos.rank_in().into(), + (infos.rank() + 1).into(), + infos.size(), + ) + } + + fn prepare_gglwe(&self, res: &mut R, other: &O, scratch: &mut Scratch) + where + R: GGLWEPreparedToMut, + O: GGLWEToRef, + { + let mut res: GGLWEPrepared<&mut [u8], BE> = res.to_mut(); + let other: GGLWE<&[u8]> = other.to_ref(); + + assert_eq!(res.n(), self.ring_degree()); + assert_eq!(other.n(), self.ring_degree()); + assert_eq!(res.base2k, other.base2k); + assert_eq!(res.k, other.k); + assert_eq!(res.dsize, other.dsize); + + self.vmp_prepare(&mut res.data, &other.data, scratch); + } +} + +impl GGLWEPreparedFactory for Module where + Module: GetDegree + VmpPMatAlloc + VmpPMatBytesOf + VmpPrepare + VmpPrepareTmpBytes +{ +} + +impl GGLWEPrepared, B> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GGLWEInfos, + M: GGLWEPreparedFactory, + { + module.alloc_gglwe_prepared_from_infos(infos) + } + + pub fn alloc( + module: &M, + base2k: Base2K, + k: TorusPrecision, + rank_in: Rank, + rank_out: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> Self + where + M: GGLWEPreparedFactory, + { + module.alloc_gglwe_prepared(base2k, k, rank_in, rank_out, dnum, dsize) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: GGLWEPreparedFactory, + { + module.bytes_of_gglwe_prepared_from_infos(infos) + } + + pub fn bytes_of( + module: &M, + base2k: Base2K, + k: TorusPrecision, + rank_in: Rank, + rank_out: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> usize + where + M: GGLWEPreparedFactory, + { + module.bytes_of_gglwe_prepared(base2k, k, rank_in, rank_out, dnum, dsize) + } +} + +impl GGLWEPrepared { + pub fn prepare(&mut self, module: &M, other: &O, scratch: &mut Scratch) + where + O: GGLWEToRef, + M: GGLWEPreparedFactory, + { + module.prepare_gglwe(self, other, scratch); + } +} + +impl GGLWEPrepared, B> { + pub fn prepare_tmp_bytes(&self, module: &M) -> usize + where + M: GGLWEPreparedFactory, + { + module.prepare_gglwe_tmp_bytes(self) + } +} + +pub trait GGLWEPreparedToMut { + fn to_mut(&mut self) -> GGLWEPrepared<&mut [u8], B>; +} + +impl GGLWEPreparedToMut for GGLWEPrepared { + fn to_mut(&mut self) -> GGLWEPrepared<&mut [u8], B> { + GGLWEPrepared { + k: self.k, + base2k: self.base2k, + dsize: self.dsize, + data: self.data.to_mut(), + } + } +} + +pub trait GGLWEPreparedToRef { + fn to_ref(&self) -> GGLWEPrepared<&[u8], B>; +} + +impl GGLWEPreparedToRef for GGLWEPrepared { + fn to_ref(&self) -> GGLWEPrepared<&[u8], B> { + GGLWEPrepared { + k: self.k, + base2k: self.base2k, + dsize: self.dsize, + data: self.data.to_ref(), + } + } +} diff --git a/poulpy-core/src/layouts/prepared/gglwe_atk.rs b/poulpy-core/src/layouts/prepared/gglwe_atk.rs deleted file mode 100644 index 594aa0a..0000000 --- a/poulpy-core/src/layouts/prepared/gglwe_atk.rs +++ /dev/null @@ -1,141 +0,0 @@ -use poulpy_hal::{ - api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare}, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}, -}; - -use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEAutomorphismKey, GGLWEInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision, - prepared::{GGLWESwitchingKeyPrepared, Prepare, PrepareAlloc, PrepareScratchSpace}, -}; - -#[derive(PartialEq, Eq)] -pub struct GGLWEAutomorphismKeyPrepared { - pub(crate) key: GGLWESwitchingKeyPrepared, - pub(crate) p: i64, -} - -impl GGLWEAutomorphismKeyPrepared { - pub fn p(&self) -> i64 { - self.p - } -} - -impl LWEInfos for GGLWEAutomorphismKeyPrepared { - fn n(&self) -> Degree { - self.key.n() - } - - fn base2k(&self) -> Base2K { - self.key.base2k() - } - - fn k(&self) -> TorusPrecision { - self.key.k() - } - - fn size(&self) -> usize { - self.key.size() - } -} - -impl GLWEInfos for GGLWEAutomorphismKeyPrepared { - fn rank(&self) -> Rank { - self.rank_out() - } -} - -impl GGLWEInfos for GGLWEAutomorphismKeyPrepared { - fn rank_in(&self) -> Rank { - self.key.rank_in() - } - - fn rank_out(&self) -> Rank { - self.key.rank_out() - } - - fn dsize(&self) -> Dsize { - self.key.dsize() - } - - fn dnum(&self) -> Dnum { - self.key.dnum() - } -} - -impl GGLWEAutomorphismKeyPrepared, B> { - pub fn alloc(module: &Module, infos: &A) -> Self - where - A: GGLWEInfos, - Module: VmpPMatAlloc, - { - assert_eq!( - infos.rank_in(), - infos.rank_out(), - "rank_in != rank_out is not supported for GGLWEAutomorphismKeyPrepared" - ); - GGLWEAutomorphismKeyPrepared::, B> { - key: GGLWESwitchingKeyPrepared::alloc(module, infos), - p: 0, - } - } - - pub fn alloc_with(module: &Module, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self - where - Module: VmpPMatAlloc, - { - GGLWEAutomorphismKeyPrepared { - key: GGLWESwitchingKeyPrepared::alloc_with(module, base2k, k, rank, rank, dnum, dsize), - p: 0, - } - } - - pub fn alloc_bytes(module: &Module, infos: &A) -> usize - where - A: GGLWEInfos, - Module: VmpPMatAllocBytes, - { - assert_eq!( - infos.rank_in(), - infos.rank_out(), - "rank_in != rank_out is not supported for GGLWEAutomorphismKeyPrepared" - ); - GGLWESwitchingKeyPrepared::alloc_bytes(module, infos) - } - - pub fn alloc_bytes_with(module: &Module, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize - where - Module: VmpPMatAllocBytes, - { - GGLWESwitchingKeyPrepared::alloc_bytes_with(module, base2k, k, rank, rank, dnum, dsize) - } -} - -impl PrepareScratchSpace for GGLWEAutomorphismKeyPrepared, B> -where - GGLWESwitchingKeyPrepared, B>: PrepareScratchSpace, -{ - fn prepare_scratch_space(module: &Module, infos: &A) -> usize { - GGLWESwitchingKeyPrepared::prepare_scratch_space(module, infos) - } -} - -impl Prepare> for GGLWEAutomorphismKeyPrepared -where - Module: VmpPrepare, -{ - fn prepare(&mut self, module: &Module, other: &GGLWEAutomorphismKey, scratch: &mut Scratch) { - self.key.prepare(module, &other.key, scratch); - self.p = other.p; - } -} - -impl PrepareAlloc, B>> for GGLWEAutomorphismKey -where - Module: VmpPMatAlloc + VmpPrepare, -{ - fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> GGLWEAutomorphismKeyPrepared, B> { - let mut atk_prepared: GGLWEAutomorphismKeyPrepared, B> = GGLWEAutomorphismKeyPrepared::alloc(module, self); - atk_prepared.prepare(module, self, scratch); - atk_prepared - } -} diff --git a/poulpy-core/src/layouts/prepared/gglwe_ct.rs b/poulpy-core/src/layouts/prepared/gglwe_ct.rs deleted file mode 100644 index 4f22e6e..0000000 --- a/poulpy-core/src/layouts/prepared/gglwe_ct.rs +++ /dev/null @@ -1,298 +0,0 @@ -use poulpy_hal::{ - api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare, VmpPrepareTmpBytes}, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat, ZnxInfos}, - oep::VmpPMatAllocBytesImpl, -}; - -use crate::layouts::{ - Base2K, BuildError, Degree, Dnum, Dsize, GGLWECiphertext, GGLWEInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision, - prepared::{Prepare, PrepareAlloc, PrepareScratchSpace}, -}; - -#[derive(PartialEq, Eq)] -pub struct GGLWECiphertextPrepared { - pub(crate) data: VmpPMat, - pub(crate) k: TorusPrecision, - pub(crate) base2k: Base2K, - pub(crate) dsize: Dsize, -} - -impl LWEInfos for GGLWECiphertextPrepared { - fn n(&self) -> Degree { - Degree(self.data.n() as u32) - } - - fn base2k(&self) -> Base2K { - self.base2k - } - - fn k(&self) -> TorusPrecision { - self.k - } - - fn size(&self) -> usize { - self.data.size() - } -} - -impl GLWEInfos for GGLWECiphertextPrepared { - fn rank(&self) -> Rank { - self.rank_out() - } -} - -impl GGLWEInfos for GGLWECiphertextPrepared { - fn rank_in(&self) -> Rank { - Rank(self.data.cols_in() as u32) - } - - fn rank_out(&self) -> Rank { - Rank(self.data.cols_out() as u32 - 1) - } - - fn dsize(&self) -> Dsize { - self.dsize - } - - fn dnum(&self) -> Dnum { - Dnum(self.data.rows() as u32) - } -} - -pub struct GGLWECiphertextPreparedBuilder { - data: Option>, - base2k: Option, - k: Option, - dsize: Option, -} - -impl GGLWECiphertextPrepared { - #[inline] - pub fn builder() -> GGLWECiphertextPreparedBuilder { - GGLWECiphertextPreparedBuilder { - data: None, - base2k: None, - k: None, - dsize: None, - } - } -} - -impl GGLWECiphertextPreparedBuilder, B> { - #[inline] - pub fn layout(mut self, infos: &A) -> Self - where - A: GGLWEInfos, - B: VmpPMatAllocBytesImpl, - { - self.data = Some(VmpPMat::alloc( - infos.n().into(), - infos.dnum().into(), - infos.rank_in().into(), - (infos.rank_out() + 1).into(), - infos.size(), - )); - self.base2k = Some(infos.base2k()); - self.k = Some(infos.k()); - self.dsize = Some(infos.dsize()); - self - } -} - -impl GGLWECiphertextPreparedBuilder { - #[inline] - pub fn data(mut self, data: VmpPMat) -> Self { - self.data = Some(data); - self - } - #[inline] - pub fn base2k(mut self, base2k: Base2K) -> Self { - self.base2k = Some(base2k); - self - } - #[inline] - pub fn k(mut self, k: TorusPrecision) -> Self { - self.k = Some(k); - self - } - - #[inline] - pub fn dsize(mut self, dsize: Dsize) -> Self { - self.dsize = Some(dsize); - self - } - - pub fn build(self) -> Result, BuildError> { - let data: VmpPMat = self.data.ok_or(BuildError::MissingData)?; - let base2k: Base2K = self.base2k.ok_or(BuildError::MissingBase2K)?; - let k: TorusPrecision = self.k.ok_or(BuildError::MissingK)?; - let dsize: Dsize = self.dsize.ok_or(BuildError::MissingDigits)?; - - if base2k == 0_u32 { - return Err(BuildError::ZeroBase2K); - } - - if dsize == 0_u32 { - return Err(BuildError::ZeroBase2K); - } - - if k == 0_u32 { - return Err(BuildError::ZeroTorusPrecision); - } - - if data.n() == 0 { - return Err(BuildError::ZeroDegree); - } - - if data.cols() == 0 { - return Err(BuildError::ZeroCols); - } - - if data.size() == 0 { - return Err(BuildError::ZeroLimbs); - } - - Ok(GGLWECiphertextPrepared { - data, - base2k, - k, - dsize, - }) - } -} - -impl GGLWECiphertextPrepared, B> { - pub fn alloc(module: &Module, infos: &A) -> Self - where - A: GGLWEInfos, - Module: VmpPMatAlloc, - { - debug_assert_eq!(module.n(), infos.n().0 as usize, "module.n() != infos.n()"); - Self::alloc_with( - module, - infos.base2k(), - infos.k(), - infos.rank_in(), - infos.rank_out(), - infos.dnum(), - infos.dsize(), - ) - } - - pub fn alloc_with( - module: &Module, - base2k: Base2K, - k: TorusPrecision, - rank_in: Rank, - rank_out: Rank, - dnum: Dnum, - dsize: Dsize, - ) -> Self - where - Module: VmpPMatAlloc, - { - let size: usize = k.0.div_ceil(base2k.0) as usize; - debug_assert!( - size as u32 > dsize.0, - "invalid gglwe: ceil(k/base2k): {size} <= dsize: {}", - dsize.0 - ); - - assert!( - dnum.0 * dsize.0 <= size as u32, - "invalid gglwe: dnum: {} * dsize:{} > ceil(k/base2k): {size}", - dnum.0, - dsize.0, - ); - - Self { - data: module.vmp_pmat_alloc(dnum.into(), rank_in.into(), (rank_out + 1).into(), size), - k, - base2k, - dsize, - } - } - - pub fn alloc_bytes(module: &Module, infos: &A) -> usize - where - A: GGLWEInfos, - Module: VmpPMatAllocBytes, - { - debug_assert_eq!(module.n(), infos.n().0 as usize, "module.n() != infos.n()"); - Self::alloc_bytes_with( - module, - infos.base2k(), - infos.k(), - infos.rank_in(), - infos.rank_out(), - infos.dnum(), - infos.dsize(), - ) - } - - pub fn alloc_bytes_with( - module: &Module, - base2k: Base2K, - k: TorusPrecision, - rank_in: Rank, - rank_out: Rank, - dnum: Dnum, - dsize: Dsize, - ) -> usize - where - Module: VmpPMatAllocBytes, - { - let size: usize = k.0.div_ceil(base2k.0) as usize; - debug_assert!( - size as u32 > dsize.0, - "invalid gglwe: ceil(k/base2k): {size} <= dsize: {}", - dsize.0 - ); - - assert!( - dnum.0 * dsize.0 <= size as u32, - "invalid gglwe: dnum: {} * dsize:{} > ceil(k/base2k): {size}", - dnum.0, - dsize.0, - ); - - module.vmp_pmat_alloc_bytes(dnum.into(), rank_in.into(), (rank_out + 1).into(), size) - } -} - -impl PrepareScratchSpace for GGLWECiphertextPrepared, B> -where - Module: VmpPrepareTmpBytes, -{ - fn prepare_scratch_space(module: &Module, infos: &A) -> usize { - module.vmp_prepare_tmp_bytes( - infos.dnum().into(), - infos.rank_in().into(), - (infos.rank() + 1).into(), - infos.size(), - ) - } -} - -impl Prepare> for GGLWECiphertextPrepared -where - Module: VmpPrepare, -{ - fn prepare(&mut self, module: &Module, other: &GGLWECiphertext, scratch: &mut Scratch) { - module.vmp_prepare(&mut self.data, &other.data, scratch); - self.k = other.k; - self.base2k = other.base2k; - self.dsize = other.dsize; - } -} - -impl PrepareAlloc, B>> for GGLWECiphertext -where - Module: VmpPMatAlloc + VmpPrepare, -{ - fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> GGLWECiphertextPrepared, B> { - let mut atk_prepared: GGLWECiphertextPrepared, B> = GGLWECiphertextPrepared::alloc(module, self); - atk_prepared.prepare(module, self, scratch); - atk_prepared - } -} diff --git a/poulpy-core/src/layouts/prepared/gglwe_ksk.rs b/poulpy-core/src/layouts/prepared/gglwe_ksk.rs deleted file mode 100644 index c9110c1..0000000 --- a/poulpy-core/src/layouts/prepared/gglwe_ksk.rs +++ /dev/null @@ -1,147 +0,0 @@ -use poulpy_hal::{ - api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare}, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}, -}; - -use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWESwitchingKey, GLWEInfos, LWEInfos, Rank, TorusPrecision, - prepared::{GGLWECiphertextPrepared, Prepare, PrepareAlloc, PrepareScratchSpace}, -}; - -#[derive(PartialEq, Eq)] -pub struct GGLWESwitchingKeyPrepared { - pub(crate) key: GGLWECiphertextPrepared, - pub(crate) sk_in_n: usize, // Degree of sk_in - pub(crate) sk_out_n: usize, // Degree of sk_out -} - -impl LWEInfos for GGLWESwitchingKeyPrepared { - fn n(&self) -> Degree { - self.key.n() - } - - fn base2k(&self) -> Base2K { - self.key.base2k() - } - - fn k(&self) -> TorusPrecision { - self.key.k() - } - - fn size(&self) -> usize { - self.key.size() - } -} - -impl GLWEInfos for GGLWESwitchingKeyPrepared { - fn rank(&self) -> Rank { - self.rank_out() - } -} - -impl GGLWEInfos for GGLWESwitchingKeyPrepared { - fn rank_in(&self) -> Rank { - self.key.rank_in() - } - - fn rank_out(&self) -> Rank { - self.key.rank_out() - } - - fn dsize(&self) -> Dsize { - self.key.dsize() - } - - fn dnum(&self) -> Dnum { - self.key.dnum() - } -} - -impl GGLWESwitchingKeyPrepared, B> { - pub fn alloc(module: &Module, infos: &A) -> Self - where - A: GGLWEInfos, - Module: VmpPMatAlloc, - { - debug_assert_eq!(module.n() as u32, infos.n(), "module.n() != infos.n()"); - GGLWESwitchingKeyPrepared::, B> { - key: GGLWECiphertextPrepared::alloc(module, infos), - sk_in_n: 0, - sk_out_n: 0, - } - } - - pub fn alloc_with( - module: &Module, - base2k: Base2K, - k: TorusPrecision, - rank_in: Rank, - rank_out: Rank, - dnum: Dnum, - dsize: Dsize, - ) -> Self - where - Module: VmpPMatAlloc, - { - GGLWESwitchingKeyPrepared::, B> { - key: GGLWECiphertextPrepared::alloc_with(module, base2k, k, rank_in, rank_out, dnum, dsize), - sk_in_n: 0, - sk_out_n: 0, - } - } - - pub fn alloc_bytes(module: &Module, infos: &A) -> usize - where - A: GGLWEInfos, - Module: VmpPMatAllocBytes, - { - debug_assert_eq!(module.n() as u32, infos.n(), "module.n() != infos.n()"); - GGLWECiphertextPrepared::alloc_bytes(module, infos) - } - - pub fn alloc_bytes_with( - module: &Module, - base2k: Base2K, - k: TorusPrecision, - rank_in: Rank, - rank_out: Rank, - dnum: Dnum, - dsize: Dsize, - ) -> usize - where - Module: VmpPMatAllocBytes, - { - GGLWECiphertextPrepared::alloc_bytes_with(module, base2k, k, rank_in, rank_out, dnum, dsize) - } -} - -impl PrepareScratchSpace for GGLWESwitchingKeyPrepared, B> -where - GGLWECiphertextPrepared, B>: PrepareScratchSpace, -{ - fn prepare_scratch_space(module: &Module, infos: &A) -> usize { - GGLWECiphertextPrepared::prepare_scratch_space(module, infos) - } -} - -impl Prepare> for GGLWESwitchingKeyPrepared -where - Module: VmpPrepare, -{ - fn prepare(&mut self, module: &Module, other: &GGLWESwitchingKey, scratch: &mut Scratch) { - self.key.prepare(module, &other.key, scratch); - self.sk_in_n = other.sk_in_n; - self.sk_out_n = other.sk_out_n; - } -} - -impl PrepareAlloc, B>> for GGLWESwitchingKey -where - Module: VmpPMatAlloc + VmpPrepare, -{ - fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> GGLWESwitchingKeyPrepared, B> { - let mut atk_prepared: GGLWESwitchingKeyPrepared, B> = GGLWESwitchingKeyPrepared::alloc(module, self); - atk_prepared.prepare(module, self, scratch); - atk_prepared - } -} diff --git a/poulpy-core/src/layouts/prepared/gglwe_tsk.rs b/poulpy-core/src/layouts/prepared/gglwe_tsk.rs deleted file mode 100644 index 4343e30..0000000 --- a/poulpy-core/src/layouts/prepared/gglwe_tsk.rs +++ /dev/null @@ -1,190 +0,0 @@ -use poulpy_hal::{ - api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare}, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}, -}; - -use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWETensorKey, GLWEInfos, LWEInfos, Rank, TorusPrecision, - prepared::{GGLWESwitchingKeyPrepared, Prepare, PrepareAlloc, PrepareScratchSpace}, -}; - -#[derive(PartialEq, Eq)] -pub struct GGLWETensorKeyPrepared { - pub(crate) keys: Vec>, -} - -impl LWEInfos for GGLWETensorKeyPrepared { - fn n(&self) -> Degree { - self.keys[0].n() - } - - fn base2k(&self) -> Base2K { - self.keys[0].base2k() - } - - fn k(&self) -> TorusPrecision { - self.keys[0].k() - } - - fn size(&self) -> usize { - self.keys[0].size() - } -} - -impl GLWEInfos for GGLWETensorKeyPrepared { - fn rank(&self) -> Rank { - self.rank_out() - } -} - -impl GGLWEInfos for GGLWETensorKeyPrepared { - fn rank_in(&self) -> Rank { - self.rank_out() - } - - fn rank_out(&self) -> Rank { - self.keys[0].rank_out() - } - - fn dsize(&self) -> Dsize { - self.keys[0].dsize() - } - - fn dnum(&self) -> Dnum { - self.keys[0].dnum() - } -} - -impl GGLWETensorKeyPrepared, B> { - pub fn alloc(module: &Module, infos: &A) -> Self - where - A: GGLWEInfos, - Module: VmpPMatAlloc, - { - assert_eq!( - infos.rank_in(), - infos.rank_out(), - "rank_in != rank_out is not supported for GGLWETensorKeyPrepared" - ); - Self::alloc_with( - module, - infos.base2k(), - infos.k(), - infos.dnum(), - infos.dsize(), - infos.rank_out(), - ) - } - - pub fn alloc_with(module: &Module, base2k: Base2K, k: TorusPrecision, dnum: Dnum, dsize: Dsize, rank: Rank) -> Self - where - Module: VmpPMatAlloc, - { - let mut keys: Vec, B>> = Vec::new(); - let pairs: u32 = (((rank.0 + 1) * rank.0) >> 1).max(1); - (0..pairs).for_each(|_| { - keys.push(GGLWESwitchingKeyPrepared::alloc_with( - module, - base2k, - k, - Rank(1), - rank, - dnum, - dsize, - )); - }); - Self { keys } - } - - pub fn alloc_bytes(module: &Module, infos: &A) -> usize - where - A: GGLWEInfos, - Module: VmpPMatAllocBytes, - { - assert_eq!( - infos.rank_in(), - infos.rank_out(), - "rank_in != rank_out is not supported for GGLWETensorKey" - ); - let rank_out: usize = infos.rank_out().into(); - let pairs: usize = (((rank_out + 1) * rank_out) >> 1).max(1); - pairs - * GGLWESwitchingKeyPrepared::alloc_bytes_with( - module, - infos.base2k(), - infos.k(), - Rank(1), - infos.rank_out(), - infos.dnum(), - infos.dsize(), - ) - } - - pub fn alloc_bytes_with(module: &Module, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize - where - Module: VmpPMatAllocBytes, - { - let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize; - pairs * GGLWESwitchingKeyPrepared::alloc_bytes_with(module, base2k, k, Rank(1), rank, dnum, dsize) - } -} - -impl GGLWETensorKeyPrepared { - // Returns a mutable reference to GLWESwitchingKey_{s}(s[i] * s[j]) - pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GGLWESwitchingKeyPrepared { - if i > j { - std::mem::swap(&mut i, &mut j); - }; - let rank: usize = self.rank_out().into(); - &mut self.keys[i * rank + j - (i * (i + 1) / 2)] - } -} - -impl GGLWETensorKeyPrepared { - // Returns a reference to GLWESwitchingKey_{s}(s[i] * s[j]) - pub fn at(&self, mut i: usize, mut j: usize) -> &GGLWESwitchingKeyPrepared { - if i > j { - std::mem::swap(&mut i, &mut j); - }; - let rank: usize = self.rank_out().into(); - &self.keys[i * rank + j - (i * (i + 1) / 2)] - } -} - -impl PrepareScratchSpace for GGLWETensorKeyPrepared, B> -where - GGLWESwitchingKeyPrepared, B>: PrepareScratchSpace, -{ - fn prepare_scratch_space(module: &Module, infos: &A) -> usize { - GGLWESwitchingKeyPrepared::prepare_scratch_space(module, infos) - } -} - -impl Prepare> for GGLWETensorKeyPrepared -where - Module: VmpPrepare, -{ - fn prepare(&mut self, module: &Module, other: &GGLWETensorKey, scratch: &mut Scratch) { - #[cfg(debug_assertions)] - { - assert_eq!(self.keys.len(), other.keys.len()); - } - self.keys - .iter_mut() - .zip(other.keys.iter()) - .for_each(|(a, b)| { - a.prepare(module, b, scratch); - }); - } -} - -impl PrepareAlloc, B>> for GGLWETensorKey -where - Module: VmpPMatAlloc + VmpPrepare, -{ - fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> GGLWETensorKeyPrepared, B> { - let mut tsk_prepared: GGLWETensorKeyPrepared, B> = GGLWETensorKeyPrepared::alloc(module, self); - tsk_prepared.prepare(module, self, scratch); - tsk_prepared - } -} diff --git a/poulpy-core/src/layouts/prepared/ggsw.rs b/poulpy-core/src/layouts/prepared/ggsw.rs new file mode 100644 index 0000000..3115980 --- /dev/null +++ b/poulpy-core/src/layouts/prepared/ggsw.rs @@ -0,0 +1,256 @@ +use poulpy_hal::{ + api::{VmpPMatAlloc, VmpPMatBytesOf, VmpPrepare, VmpPrepareTmpBytes}, + layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat, VmpPMatToMut, VmpPMatToRef, ZnxInfos}, +}; + +use crate::layouts::{ + Base2K, Degree, Dnum, Dsize, GGSW, GGSWInfos, GGSWToRef, GLWEInfos, GetDegree, LWEInfos, Rank, TorusPrecision, +}; + +#[derive(PartialEq, Eq)] +pub struct GGSWPrepared { + pub(crate) data: VmpPMat, + pub(crate) k: TorusPrecision, + pub(crate) base2k: Base2K, + pub(crate) dsize: Dsize, +} + +impl LWEInfos for GGSWPrepared { + fn n(&self) -> Degree { + Degree(self.data.n() as u32) + } + + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn size(&self) -> usize { + self.data.size() + } +} + +impl GLWEInfos for GGSWPrepared { + fn rank(&self) -> Rank { + Rank(self.data.cols_out() as u32 - 1) + } +} + +impl GGSWInfos for GGSWPrepared { + fn dsize(&self) -> Dsize { + self.dsize + } + + fn dnum(&self) -> Dnum { + Dnum(self.data.rows() as u32) + } +} + +pub trait GGSWPreparedFactory +where + Self: GetDegree + VmpPMatAlloc + VmpPMatBytesOf + VmpPrepareTmpBytes + VmpPrepare, +{ + fn alloc_ggsw_prepared( + &self, + base2k: Base2K, + k: TorusPrecision, + dnum: Dnum, + dsize: Dsize, + rank: Rank, + ) -> GGSWPrepared, B> { + let size: usize = k.0.div_ceil(base2k.0) as usize; + debug_assert!( + size as u32 > dsize.0, + "invalid ggsw: ceil(k/base2k): {size} <= dsize: {}", + dsize.0 + ); + + assert!( + dnum.0 * dsize.0 <= size as u32, + "invalid ggsw: dnum: {} * dsize:{} > ceil(k/base2k): {size}", + dnum.0, + dsize.0, + ); + + GGSWPrepared { + data: self.vmp_pmat_alloc( + dnum.into(), + (rank + 1).into(), + (rank + 1).into(), + k.0.div_ceil(base2k.0) as usize, + ), + k, + base2k, + dsize, + } + } + + fn alloc_ggsw_prepared_from_infos(&self, infos: &A) -> GGSWPrepared, B> + where + A: GGSWInfos, + { + assert_eq!(self.ring_degree(), infos.n()); + self.alloc_ggsw_prepared( + infos.base2k(), + infos.k(), + infos.dnum(), + infos.dsize(), + infos.rank(), + ) + } + + fn bytes_of_ggsw_prepared(&self, base2k: Base2K, k: TorusPrecision, dnum: Dnum, dsize: Dsize, rank: Rank) -> usize { + let size: usize = k.0.div_ceil(base2k.0) as usize; + debug_assert!( + size as u32 > dsize.0, + "invalid ggsw: ceil(k/base2k): {size} <= dsize: {}", + dsize.0 + ); + + assert!( + dnum.0 * dsize.0 <= size as u32, + "invalid ggsw: dnum: {} * dsize:{} > ceil(k/base2k): {size}", + dnum.0, + dsize.0, + ); + + self.bytes_of_vmp_pmat(dnum.into(), (rank + 1).into(), (rank + 1).into(), size) + } + + fn bytes_of_ggsw_prepared_from_infos(&self, infos: &A) -> usize + where + A: GGSWInfos, + { + assert_eq!(self.ring_degree(), infos.n()); + self.bytes_of_ggsw_prepared( + infos.base2k(), + infos.k(), + infos.dnum(), + infos.dsize(), + infos.rank(), + ) + } + + fn ggsw_prepare_tmp_bytes(&self, infos: &A) -> usize + where + A: GGSWInfos, + { + assert_eq!(self.ring_degree(), infos.n()); + self.vmp_prepare_tmp_bytes( + infos.dnum().into(), + (infos.rank() + 1).into(), + (infos.rank() + 1).into(), + infos.size(), + ) + } + fn ggsw_prepare(&self, res: &mut R, other: &O, scratch: &mut Scratch) + where + R: GGSWPreparedToMut, + O: GGSWToRef, + { + let mut res: GGSWPrepared<&mut [u8], B> = res.to_mut(); + let other: GGSW<&[u8]> = other.to_ref(); + assert_eq!(res.n(), self.ring_degree()); + assert_eq!(other.n(), self.ring_degree()); + assert_eq!(res.k, other.k); + assert_eq!(res.base2k, other.base2k); + assert_eq!(res.dsize, other.dsize); + self.vmp_prepare(&mut res.data, &other.data, scratch); + } +} + +impl GGSWPreparedFactory for Module where + Self: GetDegree + VmpPMatAlloc + VmpPMatBytesOf + VmpPrepareTmpBytes + VmpPrepare +{ +} + +impl GGSWPrepared, B> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GGSWInfos, + M: GGSWPreparedFactory, + { + module.alloc_ggsw_prepared_from_infos(infos) + } + + pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, dnum: Dnum, dsize: Dsize, rank: Rank) -> Self + where + M: GGSWPreparedFactory, + { + module.alloc_ggsw_prepared(base2k, k, dnum, dsize, rank) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GGSWInfos, + M: GGSWPreparedFactory, + { + module.bytes_of_ggsw_prepared_from_infos(infos) + } + + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, dnum: Dnum, dsize: Dsize, rank: Rank) -> usize + where + M: GGSWPreparedFactory, + { + module.bytes_of_ggsw_prepared(base2k, k, dnum, dsize, rank) + } +} + +impl GGSWPrepared { + pub fn data(&self) -> &VmpPMat { + &self.data + } +} + +impl GGSWPrepared, B> { + pub fn prepare_tmp_bytes(&self, module: &M, infos: &A) -> usize + where + A: GGSWInfos, + M: GGSWPreparedFactory, + { + module.ggsw_prepare_tmp_bytes(infos) + } +} + +impl GGSWPrepared { + pub fn prepare(&mut self, module: &M, other: &O, scratch: &mut Scratch) + where + O: GGSWToRef, + M: GGSWPreparedFactory, + { + module.ggsw_prepare(self, other, scratch); + } +} + +pub trait GGSWPreparedToMut { + fn to_mut(&mut self) -> GGSWPrepared<&mut [u8], B>; +} + +impl GGSWPreparedToMut for GGSWPrepared { + fn to_mut(&mut self) -> GGSWPrepared<&mut [u8], B> { + GGSWPrepared { + base2k: self.base2k, + k: self.k, + dsize: self.dsize, + data: self.data.to_mut(), + } + } +} + +pub trait GGSWPreparedToRef { + fn to_ref(&self) -> GGSWPrepared<&[u8], B>; +} + +impl GGSWPreparedToRef for GGSWPrepared { + fn to_ref(&self) -> GGSWPrepared<&[u8], B> { + GGSWPrepared { + base2k: self.base2k, + k: self.k, + dsize: self.dsize, + data: self.data.to_ref(), + } + } +} diff --git a/poulpy-core/src/layouts/prepared/ggsw_ct.rs b/poulpy-core/src/layouts/prepared/ggsw_ct.rs deleted file mode 100644 index eb79a5a..0000000 --- a/poulpy-core/src/layouts/prepared/ggsw_ct.rs +++ /dev/null @@ -1,312 +0,0 @@ -use poulpy_hal::{ - api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare, VmpPrepareTmpBytes}, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat, VmpPMatToRef, ZnxInfos}, - oep::VmpPMatAllocBytesImpl, -}; - -use crate::layouts::{ - Base2K, BuildError, Degree, Dnum, Dsize, GGSWCiphertext, GGSWInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision, - prepared::{Prepare, PrepareAlloc, PrepareScratchSpace}, -}; - -#[derive(PartialEq, Eq)] -pub struct GGSWCiphertextPrepared { - pub(crate) data: VmpPMat, - pub(crate) k: TorusPrecision, - pub(crate) base2k: Base2K, - pub(crate) dsize: Dsize, -} - -impl LWEInfos for GGSWCiphertextPrepared { - fn n(&self) -> Degree { - Degree(self.data.n() as u32) - } - - fn base2k(&self) -> Base2K { - self.base2k - } - - fn k(&self) -> TorusPrecision { - self.k - } - - fn size(&self) -> usize { - self.data.size() - } -} - -impl GLWEInfos for GGSWCiphertextPrepared { - fn rank(&self) -> Rank { - Rank(self.data.cols_out() as u32 - 1) - } -} - -impl GGSWInfos for GGSWCiphertextPrepared { - fn dsize(&self) -> Dsize { - self.dsize - } - - fn dnum(&self) -> Dnum { - Dnum(self.data.rows() as u32) - } -} - -pub struct GGSWCiphertextPreparedBuilder { - data: Option>, - base2k: Option, - k: Option, - dsize: Option, -} - -impl GGSWCiphertextPrepared { - #[inline] - pub fn builder() -> GGSWCiphertextPreparedBuilder { - GGSWCiphertextPreparedBuilder { - data: None, - base2k: None, - k: None, - dsize: None, - } - } -} - -impl GGSWCiphertextPreparedBuilder, B> { - #[inline] - pub fn layout(mut self, infos: &A) -> Self - where - A: GGSWInfos, - B: VmpPMatAllocBytesImpl, - { - debug_assert!( - infos.size() as u32 > infos.dsize().0, - "invalid ggsw: ceil(k/base2k): {} <= dsize: {}", - infos.size(), - infos.dsize() - ); - - assert!( - infos.dnum().0 * infos.dsize().0 <= infos.size() as u32, - "invalid ggsw: dnum: {} * dsize:{} > ceil(k/base2k): {}", - infos.dnum(), - infos.dsize(), - infos.size(), - ); - - self.data = Some(VmpPMat::alloc( - infos.n().into(), - infos.dnum().into(), - (infos.rank() + 1).into(), - (infos.rank() + 1).into(), - infos.size(), - )); - self.base2k = Some(infos.base2k()); - self.k = Some(infos.k()); - self.dsize = Some(infos.dsize()); - self - } -} - -impl GGSWCiphertextPreparedBuilder { - #[inline] - pub fn data(mut self, data: VmpPMat) -> Self { - self.data = Some(data); - self - } - #[inline] - pub fn base2k(mut self, base2k: Base2K) -> Self { - self.base2k = Some(base2k); - self - } - #[inline] - pub fn k(mut self, k: TorusPrecision) -> Self { - self.k = Some(k); - self - } - - #[inline] - pub fn dsize(mut self, dsize: Dsize) -> Self { - self.dsize = Some(dsize); - self - } - - pub fn build(self) -> Result, BuildError> { - let data: VmpPMat = self.data.ok_or(BuildError::MissingData)?; - let base2k: Base2K = self.base2k.ok_or(BuildError::MissingBase2K)?; - let k: TorusPrecision = self.k.ok_or(BuildError::MissingK)?; - let dsize: Dsize = self.dsize.ok_or(BuildError::MissingDigits)?; - - if base2k == 0_u32 { - return Err(BuildError::ZeroBase2K); - } - - if dsize == 0_u32 { - return Err(BuildError::ZeroBase2K); - } - - if k == 0_u32 { - return Err(BuildError::ZeroTorusPrecision); - } - - if data.n() == 0 { - return Err(BuildError::ZeroDegree); - } - - if data.cols() == 0 { - return Err(BuildError::ZeroCols); - } - - if data.size() == 0 { - return Err(BuildError::ZeroLimbs); - } - - Ok(GGSWCiphertextPrepared { - data, - base2k, - k, - dsize, - }) - } -} - -impl GGSWCiphertextPrepared, B> { - pub fn alloc(module: &Module, infos: &A) -> Self - where - A: GGSWInfos, - Module: VmpPMatAlloc, - { - Self::alloc_with( - module, - infos.base2k(), - infos.k(), - infos.dnum(), - infos.dsize(), - infos.rank(), - ) - } - - pub fn alloc_with(module: &Module, base2k: Base2K, k: TorusPrecision, dnum: Dnum, dsize: Dsize, rank: Rank) -> Self - where - Module: VmpPMatAlloc, - { - let size: usize = k.0.div_ceil(base2k.0) as usize; - debug_assert!( - size as u32 > dsize.0, - "invalid ggsw: ceil(k/base2k): {size} <= dsize: {}", - dsize.0 - ); - - assert!( - dnum.0 * dsize.0 <= size as u32, - "invalid ggsw: dnum: {} * dsize:{} > ceil(k/base2k): {size}", - dnum.0, - dsize.0, - ); - - Self { - data: module.vmp_pmat_alloc( - dnum.into(), - (rank + 1).into(), - (rank + 1).into(), - k.0.div_ceil(base2k.0) as usize, - ), - k, - base2k, - dsize, - } - } - - pub fn alloc_bytes(module: &Module, infos: &A) -> usize - where - A: GGSWInfos, - Module: VmpPMatAllocBytes, - { - Self::alloc_bytes_with( - module, - infos.base2k(), - infos.k(), - infos.dnum(), - infos.dsize(), - infos.rank(), - ) - } - - pub fn alloc_bytes_with(module: &Module, base2k: Base2K, k: TorusPrecision, dnum: Dnum, dsize: Dsize, rank: Rank) -> usize - where - Module: VmpPMatAllocBytes, - { - let size: usize = k.0.div_ceil(base2k.0) as usize; - debug_assert!( - size as u32 > dsize.0, - "invalid ggsw: ceil(k/base2k): {size} <= dsize: {}", - dsize.0 - ); - - assert!( - dnum.0 * dsize.0 <= size as u32, - "invalid ggsw: dnum: {} * dsize:{} > ceil(k/base2k): {size}", - dnum.0, - dsize.0, - ); - - module.vmp_pmat_alloc_bytes(dnum.into(), (rank + 1).into(), (rank + 1).into(), size) - } -} - -impl GGSWCiphertextPrepared { - pub fn data(&self) -> &VmpPMat { - &self.data - } -} - -impl PrepareScratchSpace for GGSWCiphertextPrepared, B> -where - Module: VmpPrepareTmpBytes, -{ - fn prepare_scratch_space(module: &Module, infos: &A) -> usize { - module.vmp_prepare_tmp_bytes( - infos.dnum().into(), - (infos.rank() + 1).into(), - (infos.rank() + 1).into(), - infos.size(), - ) - } -} - -impl Prepare> for GGSWCiphertextPrepared -where - Module: VmpPrepare, -{ - fn prepare(&mut self, module: &Module, other: &GGSWCiphertext, scratch: &mut Scratch) { - module.vmp_prepare(&mut self.data, &other.data, scratch); - self.k = other.k; - self.base2k = other.base2k; - self.dsize = other.dsize; - } -} - -impl PrepareAlloc, B>> for GGSWCiphertext -where - Module: VmpPMatAlloc + VmpPrepare, -{ - fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> GGSWCiphertextPrepared, B> { - let mut ggsw_prepared: GGSWCiphertextPrepared, B> = GGSWCiphertextPrepared::alloc(module, self); - ggsw_prepared.prepare(module, self, scratch); - ggsw_prepared - } -} - -pub trait GGSWCiphertextPreparedToRef { - fn to_ref(&self) -> GGSWCiphertextPrepared<&[u8], B>; -} - -impl GGSWCiphertextPreparedToRef for GGSWCiphertextPrepared { - fn to_ref(&self) -> GGSWCiphertextPrepared<&[u8], B> { - GGSWCiphertextPrepared::builder() - .base2k(self.base2k()) - .dsize(self.dsize()) - .k(self.k()) - .data(self.data.to_ref()) - .build() - .unwrap() - } -} diff --git a/poulpy-core/src/layouts/prepared/glwe.rs b/poulpy-core/src/layouts/prepared/glwe.rs new file mode 100644 index 0000000..cf7e564 --- /dev/null +++ b/poulpy-core/src/layouts/prepared/glwe.rs @@ -0,0 +1,161 @@ +use poulpy_hal::{ + api::{VecZnxDftAlloc, VecZnxDftApply, VecZnxDftBytesOf}, + layouts::{Backend, Data, DataMut, DataRef, Module, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxInfos}, +}; + +use crate::layouts::{Base2K, Degree, GLWE, GLWEInfos, GLWEToRef, GetDegree, LWEInfos, Rank, TorusPrecision}; + +#[derive(PartialEq, Eq)] +pub struct GLWEPrepared { + pub(crate) data: VecZnxDft, + pub(crate) base2k: Base2K, + pub(crate) k: TorusPrecision, +} + +impl LWEInfos for GLWEPrepared { + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn size(&self) -> usize { + self.data.size() + } + + fn n(&self) -> Degree { + Degree(self.data.n() as u32) + } +} + +impl GLWEInfos for GLWEPrepared { + fn rank(&self) -> Rank { + Rank(self.data.cols() as u32 - 1) + } +} + +pub trait GLWEPreparedFactory +where + Self: GetDegree + VecZnxDftAlloc + VecZnxDftBytesOf + VecZnxDftApply, +{ + fn alloc_glwe_prepared(&self, base2k: Base2K, k: TorusPrecision, rank: Rank) -> GLWEPrepared, B> { + GLWEPrepared { + data: self.vec_znx_dft_alloc((rank + 1).into(), k.0.div_ceil(base2k.0) as usize), + base2k, + k, + } + } + + fn alloc_glwe_prepared_from_infos(&self, infos: &A) -> GLWEPrepared, B> + where + A: GLWEInfos, + { + self.alloc_glwe_prepared(infos.base2k(), infos.k(), infos.rank()) + } + + fn bytes_of_glwe_prepared(&self, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize { + self.bytes_of_vec_znx_dft((rank + 1).into(), k.0.div_ceil(base2k.0) as usize) + } + + fn bytes_of_glwe_prepared_from_infos(&self, infos: &A) -> usize + where + A: GLWEInfos, + { + self.bytes_of_glwe_prepared(infos.base2k(), infos.k(), infos.rank()) + } + + fn prepare_glwe(&self, res: &mut R, other: &O) + where + R: GLWEPreparedToMut, + O: GLWEToRef, + { + { + let mut res: GLWEPrepared<&mut [u8], B> = res.to_mut(); + let other: GLWE<&[u8]> = other.to_ref(); + + assert_eq!(res.n(), self.ring_degree()); + assert_eq!(other.n(), self.ring_degree()); + assert_eq!(res.size(), other.size()); + assert_eq!(res.k(), other.k()); + assert_eq!(res.base2k(), other.base2k()); + + for i in 0..(res.rank() + 1).into() { + self.vec_znx_dft_apply(1, 0, &mut res.data, i, &other.data, i); + } + } + } +} + +impl GLWEPreparedFactory for Module where Self: VecZnxDftAlloc + VecZnxDftBytesOf + VecZnxDftApply {} + +impl GLWEPrepared, B> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GLWEInfos, + M: GLWEPreparedFactory, + { + module.alloc_glwe_prepared_from_infos(infos) + } + + pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self + where + M: GLWEPreparedFactory, + { + module.alloc_glwe_prepared(base2k, k, rank) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GLWEInfos, + M: GLWEPreparedFactory, + { + module.bytes_of_glwe_prepared_from_infos(infos) + } + + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize + where + M: GLWEPreparedFactory, + { + module.bytes_of_glwe_prepared(base2k, k, rank) + } +} + +impl GLWEPrepared { + pub fn prepare(&mut self, module: &M, other: &O) + where + O: GLWEToRef, + M: GLWEPreparedFactory, + { + module.prepare_glwe(self, other); + } +} + +pub trait GLWEPreparedToMut { + fn to_mut(&mut self) -> GLWEPrepared<&mut [u8], B>; +} + +impl GLWEPreparedToMut for GLWEPrepared { + fn to_mut(&mut self) -> GLWEPrepared<&mut [u8], B> { + GLWEPrepared { + k: self.k, + base2k: self.base2k, + data: self.data.to_mut(), + } + } +} + +pub trait GLWEPreparedToRef { + fn to_ref(&self) -> GLWEPrepared<&[u8], B>; +} + +impl GLWEPreparedToRef for GLWEPrepared { + fn to_ref(&self) -> GLWEPrepared<&[u8], B> { + GLWEPrepared { + data: self.data.to_ref(), + k: self.k, + base2k: self.base2k, + } + } +} diff --git a/poulpy-core/src/layouts/prepared/glwe_automorphism_key.rs b/poulpy-core/src/layouts/prepared/glwe_automorphism_key.rs new file mode 100644 index 0000000..adf54eb --- /dev/null +++ b/poulpy-core/src/layouts/prepared/glwe_automorphism_key.rs @@ -0,0 +1,213 @@ +use poulpy_hal::layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}; + +use crate::layouts::{ + Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWEPrepared, GGLWEPreparedFactory, GGLWEPreparedToMut, GGLWEPreparedToRef, + GGLWEToRef, GLWEInfos, GetGaloisElement, LWEInfos, Rank, SetGaloisElement, TorusPrecision, +}; + +#[derive(PartialEq, Eq)] +pub struct GLWEAutomorphismKeyPrepared { + pub(crate) key: GGLWEPrepared, + pub(crate) p: i64, +} + +impl LWEInfos for GLWEAutomorphismKeyPrepared { + fn n(&self) -> Degree { + self.key.n() + } + + fn base2k(&self) -> Base2K { + self.key.base2k() + } + + fn k(&self) -> TorusPrecision { + self.key.k() + } + + fn size(&self) -> usize { + self.key.size() + } +} + +impl GetGaloisElement for GLWEAutomorphismKeyPrepared { + fn p(&self) -> i64 { + self.p + } +} + +impl SetGaloisElement for GLWEAutomorphismKeyPrepared { + fn set_p(&mut self, p: i64) { + self.p = p + } +} + +impl GLWEInfos for GLWEAutomorphismKeyPrepared { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWEInfos for GLWEAutomorphismKeyPrepared { + fn rank_in(&self) -> Rank { + self.key.rank_in() + } + + fn rank_out(&self) -> Rank { + self.key.rank_out() + } + + fn dsize(&self) -> Dsize { + self.key.dsize() + } + + fn dnum(&self) -> Dnum { + self.key.dnum() + } +} + +pub trait GLWEAutomorphismKeyPreparedFactory +where + Self: GGLWEPreparedFactory, +{ + fn alloc_glwe_automorphism_key_prepared( + &self, + base2k: Base2K, + k: TorusPrecision, + rank: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> GLWEAutomorphismKeyPrepared, B> { + GLWEAutomorphismKeyPrepared::, B> { + key: self.alloc_gglwe_prepared(base2k, k, rank, rank, dnum, dsize), + p: 0, + } + } + + fn alloc_glwe_automorphism_key_prepared_from_infos(&self, infos: &A) -> GLWEAutomorphismKeyPrepared, B> + where + A: GGLWEInfos, + { + assert_eq!( + infos.rank_in(), + infos.rank_out(), + "rank_in != rank_out is not supported for AutomorphismKeyPrepared" + ); + self.alloc_glwe_automorphism_key_prepared( + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) + } + + fn bytes_of_glwe_automorphism_key_prepared( + &self, + base2k: Base2K, + k: TorusPrecision, + rank: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> usize { + self.bytes_of_gglwe_prepared(base2k, k, rank, rank, dnum, dsize) + } + + fn bytes_of_glwe_automorphism_key_prepared_from_infos(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + assert_eq!( + infos.rank_in(), + infos.rank_out(), + "rank_in != rank_out is not supported for AutomorphismKeyPrepared" + ); + self.bytes_of_glwe_automorphism_key_prepared( + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) + } + + fn prepare_glwe_automorphism_key_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + self.prepare_gglwe_tmp_bytes(infos) + } + + fn prepare_glwe_automorphism_key(&self, res: &mut R, other: &O, scratch: &mut Scratch) + where + R: GGLWEPreparedToMut + SetGaloisElement, + O: GGLWEToRef + GetGaloisElement, + { + self.prepare_gglwe(res, other, scratch); + res.set_p(other.p()); + } +} + +impl GLWEAutomorphismKeyPreparedFactory for Module where Module: GGLWEPreparedFactory {} + +impl GLWEAutomorphismKeyPrepared, B> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GGLWEInfos, + M: GLWEAutomorphismKeyPreparedFactory, + { + module.alloc_glwe_automorphism_key_prepared_from_infos(infos) + } + + pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self + where + M: GLWEAutomorphismKeyPreparedFactory, + { + module.alloc_glwe_automorphism_key_prepared(base2k, k, rank, dnum, dsize) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: GLWEAutomorphismKeyPreparedFactory, + { + module.bytes_of_glwe_automorphism_key_prepared_from_infos(infos) + } + + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize + where + M: GLWEAutomorphismKeyPreparedFactory, + { + module.bytes_of_glwe_automorphism_key_prepared(base2k, k, rank, dnum, dsize) + } +} + +impl GLWEAutomorphismKeyPrepared, B> { + pub fn prepare_tmp_bytes(&self, module: &M) -> usize + where + M: GLWEAutomorphismKeyPreparedFactory, + { + module.prepare_glwe_automorphism_key_tmp_bytes(self) + } +} + +impl GLWEAutomorphismKeyPrepared { + pub fn prepare(&mut self, module: &M, other: &O, scratch: &mut Scratch) + where + O: GGLWEToRef + GetGaloisElement, + M: GLWEAutomorphismKeyPreparedFactory, + { + module.prepare_glwe_automorphism_key(self, other, scratch); + } +} + +impl GGLWEPreparedToMut for GLWEAutomorphismKeyPrepared { + fn to_mut(&mut self) -> GGLWEPrepared<&mut [u8], B> { + self.key.to_mut() + } +} + +impl GGLWEPreparedToRef for GLWEAutomorphismKeyPrepared { + fn to_ref(&self) -> GGLWEPrepared<&[u8], BE> { + self.key.to_ref() + } +} diff --git a/poulpy-core/src/layouts/prepared/glwe_pk.rs b/poulpy-core/src/layouts/prepared/glwe_pk.rs deleted file mode 100644 index 6834f58..0000000 --- a/poulpy-core/src/layouts/prepared/glwe_pk.rs +++ /dev/null @@ -1,207 +0,0 @@ -use poulpy_hal::{ - api::{VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply}, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VecZnxDft, ZnxInfos}, - oep::VecZnxDftAllocBytesImpl, -}; - -use crate::{ - dist::Distribution, - layouts::{ - Base2K, BuildError, Degree, GLWEInfos, GLWEPublicKey, LWEInfos, Rank, TorusPrecision, - prepared::{Prepare, PrepareAlloc, PrepareScratchSpace}, - }, -}; - -#[derive(PartialEq, Eq)] -pub struct GLWEPublicKeyPrepared { - pub(crate) data: VecZnxDft, - pub(crate) base2k: Base2K, - pub(crate) k: TorusPrecision, - pub(crate) dist: Distribution, -} - -impl LWEInfos for GLWEPublicKeyPrepared { - fn base2k(&self) -> Base2K { - self.base2k - } - - fn k(&self) -> TorusPrecision { - self.k - } - - fn size(&self) -> usize { - self.data.size() - } - - fn n(&self) -> Degree { - Degree(self.data.n() as u32) - } -} - -impl GLWEInfos for GLWEPublicKeyPrepared { - fn rank(&self) -> Rank { - Rank(self.data.cols() as u32 - 1) - } -} - -pub struct GLWEPublicKeyPreparedBuilder { - data: Option>, - base2k: Option, - k: Option, -} - -impl GLWEPublicKeyPrepared { - #[inline] - pub fn builder() -> GLWEPublicKeyPreparedBuilder { - GLWEPublicKeyPreparedBuilder { - data: None, - base2k: None, - k: None, - } - } -} - -impl GLWEPublicKeyPreparedBuilder, B> { - #[inline] - pub fn layout(mut self, layout: &A) -> Self - where - A: GLWEInfos, - B: VecZnxDftAllocBytesImpl, - { - self.data = Some(VecZnxDft::alloc( - layout.n().into(), - (layout.rank() + 1).into(), - layout.size(), - )); - self.base2k = Some(layout.base2k()); - self.k = Some(layout.k()); - self - } -} - -impl GLWEPublicKeyPreparedBuilder { - #[inline] - pub fn data(mut self, data: VecZnxDft) -> Self { - self.data = Some(data); - self - } - #[inline] - pub fn base2k(mut self, base2k: Base2K) -> Self { - self.base2k = Some(base2k); - self - } - #[inline] - pub fn k(mut self, k: TorusPrecision) -> Self { - self.k = Some(k); - self - } - - pub fn build(self) -> Result, BuildError> { - let data: VecZnxDft = self.data.ok_or(BuildError::MissingData)?; - let base2k: Base2K = self.base2k.ok_or(BuildError::MissingBase2K)?; - let k: TorusPrecision = self.k.ok_or(BuildError::MissingK)?; - - if base2k == 0_u32 { - return Err(BuildError::ZeroBase2K); - } - - if k == 0_u32 { - return Err(BuildError::ZeroTorusPrecision); - } - - if data.n() == 0 { - return Err(BuildError::ZeroDegree); - } - - if data.cols() == 0 { - return Err(BuildError::ZeroCols); - } - - if data.size() == 0 { - return Err(BuildError::ZeroLimbs); - } - - Ok(GLWEPublicKeyPrepared { - data, - base2k, - k, - dist: Distribution::NONE, - }) - } -} - -impl GLWEPublicKeyPrepared, B> { - pub fn alloc(module: &Module, infos: &A) -> Self - where - A: GLWEInfos, - Module: VecZnxDftAlloc, - { - debug_assert_eq!(module.n(), infos.n().0 as usize, "module.n() != infos.n()"); - Self::alloc_with(module, infos.base2k(), infos.k(), infos.rank()) - } - - pub fn alloc_with(module: &Module, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self - where - Module: VecZnxDftAlloc, - { - Self { - data: module.vec_znx_dft_alloc((rank + 1).into(), k.0.div_ceil(base2k.0) as usize), - base2k, - k, - dist: Distribution::NONE, - } - } - - pub fn alloc_bytes(module: &Module, infos: &A) -> usize - where - A: GLWEInfos, - Module: VecZnxDftAllocBytes, - { - debug_assert_eq!(module.n(), infos.n().0 as usize, "module.n() != infos.n()"); - Self::alloc_bytes_with(module, infos.base2k(), infos.k(), infos.rank()) - } - - pub fn alloc_bytes_with(module: &Module, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize - where - Module: VecZnxDftAllocBytes, - { - module.vec_znx_dft_alloc_bytes((rank + 1).into(), k.0.div_ceil(base2k.0) as usize) - } -} - -impl PrepareAlloc, B>> for GLWEPublicKey -where - Module: VecZnxDftAlloc + VecZnxDftApply, -{ - fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> GLWEPublicKeyPrepared, B> { - let mut pk_prepared: GLWEPublicKeyPrepared, B> = GLWEPublicKeyPrepared::alloc(module, self); - pk_prepared.prepare(module, self, scratch); - pk_prepared - } -} - -impl PrepareScratchSpace for GLWEPublicKeyPrepared, B> { - fn prepare_scratch_space(_module: &Module, _infos: &A) -> usize { - 0 - } -} - -impl Prepare> for GLWEPublicKeyPrepared -where - Module: VecZnxDftApply, -{ - fn prepare(&mut self, module: &Module, other: &GLWEPublicKey, _scratch: &mut Scratch) { - #[cfg(debug_assertions)] - { - assert_eq!(self.n(), other.n()); - assert_eq!(self.size(), other.size()); - } - - (0..(self.rank() + 1).into()).for_each(|i| { - module.vec_znx_dft_apply(1, 0, &mut self.data, i, &other.data, i); - }); - self.k = other.k(); - self.base2k = other.base2k(); - self.dist = other.dist; - } -} diff --git a/poulpy-core/src/layouts/prepared/glwe_public_key.rs b/poulpy-core/src/layouts/prepared/glwe_public_key.rs new file mode 100644 index 0000000..fab30bb --- /dev/null +++ b/poulpy-core/src/layouts/prepared/glwe_public_key.rs @@ -0,0 +1,157 @@ +use poulpy_hal::{ + api::{VecZnxDftAlloc, VecZnxDftApply, VecZnxDftBytesOf}, + layouts::{Backend, Data, DataMut, DataRef, Module}, +}; + +use crate::{ + GetDistribution, GetDistributionMut, + dist::Distribution, + layouts::{ + Base2K, Degree, GLWEInfos, GLWEPrepared, GLWEPreparedFactory, GLWEPreparedToMut, GLWEPreparedToRef, GLWEToRef, GetDegree, + LWEInfos, Rank, TorusPrecision, + }, +}; + +#[derive(PartialEq, Eq)] +pub struct GLWEPublicKeyPrepared { + pub(crate) key: GLWEPrepared, + pub(crate) dist: Distribution, +} + +impl GetDistribution for GLWEPublicKeyPrepared { + fn dist(&self) -> &Distribution { + &self.dist + } +} + +impl GetDistributionMut for GLWEPublicKeyPrepared { + fn dist_mut(&mut self) -> &mut Distribution { + &mut self.dist + } +} + +impl LWEInfos for GLWEPublicKeyPrepared { + fn base2k(&self) -> Base2K { + self.key.base2k() + } + + fn k(&self) -> TorusPrecision { + self.key.k() + } + + fn size(&self) -> usize { + self.key.size() + } + + fn n(&self) -> Degree { + self.key.n() + } +} + +impl GLWEInfos for GLWEPublicKeyPrepared { + fn rank(&self) -> Rank { + self.key.rank() + } +} + +pub trait GLWEPublicKeyPreparedFactory +where + Self: GetDegree + GLWEPreparedFactory, +{ + fn alloc_glwe_public_key_prepared(&self, base2k: Base2K, k: TorusPrecision, rank: Rank) -> GLWEPublicKeyPrepared, B> { + GLWEPublicKeyPrepared { + key: self.alloc_glwe_prepared(base2k, k, rank), + dist: Distribution::NONE, + } + } + + fn alloc_glwe_public_key_prepared_from_infos(&self, infos: &A) -> GLWEPublicKeyPrepared, B> + where + A: GLWEInfos, + { + self.alloc_glwe_public_key_prepared(infos.base2k(), infos.k(), infos.rank()) + } + + fn bytes_of_glwe_public_key_prepared(&self, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize { + self.bytes_of_glwe_prepared(base2k, k, rank) + } + + fn bytes_of_glwe_public_key_prepared_from_infos(&self, infos: &A) -> usize + where + A: GLWEInfos, + { + self.bytes_of_glwe_public_key_prepared(infos.base2k(), infos.k(), infos.rank()) + } + + fn prepare_glwe_public_key(&self, res: &mut R, other: &O) + where + R: GLWEPreparedToMut + GetDistributionMut, + O: GLWEToRef + GetDistribution, + { + self.prepare_glwe(res, other); + *res.dist_mut() = *other.dist(); + } +} + +impl GLWEPublicKeyPreparedFactory for Module where Self: VecZnxDftAlloc + VecZnxDftBytesOf + VecZnxDftApply +{} + +impl GLWEPublicKeyPrepared, B> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GLWEInfos, + M: GLWEPublicKeyPreparedFactory, + { + module.alloc_glwe_public_key_prepared_from_infos(infos) + } + + pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self + where + M: GLWEPublicKeyPreparedFactory, + { + module.alloc_glwe_public_key_prepared(base2k, k, rank) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GLWEInfos, + M: GLWEPublicKeyPreparedFactory, + { + module.bytes_of_glwe_public_key_prepared_from_infos(infos) + } + + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize + where + M: GLWEPublicKeyPreparedFactory, + { + module.bytes_of_glwe_public_key_prepared(base2k, k, rank) + } +} + +impl GLWEPublicKeyPrepared { + pub fn prepare(&mut self, module: &M, other: &O) + where + O: GLWEToRef + GetDistribution, + M: GLWEPublicKeyPreparedFactory, + { + module.prepare_glwe_public_key(self, other); + } +} + +impl GLWEPreparedToMut for GLWEPublicKeyPrepared +where + GLWEPrepared: GLWEPreparedToMut, +{ + fn to_mut(&mut self) -> GLWEPrepared<&mut [u8], B> { + self.key.to_mut() + } +} + +impl GLWEPreparedToRef for GLWEPublicKeyPrepared +where + GLWEPrepared: GLWEPreparedToRef, +{ + fn to_ref(&self) -> GLWEPrepared<&[u8], B> { + self.key.to_ref() + } +} diff --git a/poulpy-core/src/layouts/prepared/glwe_secret.rs b/poulpy-core/src/layouts/prepared/glwe_secret.rs new file mode 100644 index 0000000..8f4917d --- /dev/null +++ b/poulpy-core/src/layouts/prepared/glwe_secret.rs @@ -0,0 +1,183 @@ +use poulpy_hal::{ + api::{SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare}, + layouts::{Backend, Data, DataMut, DataRef, Module, SvpPPol, SvpPPolToMut, SvpPPolToRef, ZnxInfos}, +}; + +use crate::{ + GetDistribution, GetDistributionMut, + dist::Distribution, + layouts::{Base2K, Degree, GLWEInfos, GLWESecret, GLWESecretToRef, GetDegree, LWEInfos, Rank, TorusPrecision}, +}; + +pub struct GLWESecretPrepared { + pub(crate) data: SvpPPol, + pub(crate) dist: Distribution, +} + +impl GetDistribution for GLWESecretPrepared { + fn dist(&self) -> &Distribution { + &self.dist + } +} + +impl GetDistributionMut for GLWESecretPrepared { + fn dist_mut(&mut self) -> &mut Distribution { + &mut self.dist + } +} + +impl LWEInfos for GLWESecretPrepared { + fn base2k(&self) -> Base2K { + Base2K(0) + } + + fn k(&self) -> TorusPrecision { + TorusPrecision(0) + } + + fn n(&self) -> Degree { + Degree(self.data.n() as u32) + } + + fn size(&self) -> usize { + self.data.size() + } +} +impl GLWEInfos for GLWESecretPrepared { + fn rank(&self) -> Rank { + Rank(self.data.cols() as u32) + } +} + +pub trait GLWESecretPreparedFactory +where + Self: GetDegree + SvpPPolBytesOf + SvpPPolAlloc + SvpPrepare, +{ + fn alloc_glwe_secret_prepared(&self, rank: Rank) -> GLWESecretPrepared, B> { + GLWESecretPrepared { + data: self.svp_ppol_alloc(rank.into()), + dist: Distribution::NONE, + } + } + fn alloc_glwe_secret_prepared_from_infos(&self, infos: &A) -> GLWESecretPrepared, B> + where + A: GLWEInfos, + { + assert_eq!(self.ring_degree(), infos.n()); + self.alloc_glwe_secret_prepared(infos.rank()) + } + + fn bytes_of_glwe_secret_prepared(&self, rank: Rank) -> usize { + self.bytes_of_svp_ppol(rank.into()) + } + fn bytes_of_glwe_secret_prepared_from_infos(&self, infos: &A) -> usize + where + A: GLWEInfos, + { + assert_eq!(self.ring_degree(), infos.n()); + self.bytes_of_glwe_secret_prepared(infos.rank()) + } + + fn prepare_glwe_secret(&self, res: &mut R, other: &O) + where + R: GLWESecretPreparedToMut + GetDistributionMut, + O: GLWESecretToRef + GetDistribution, + { + { + let mut res: GLWESecretPrepared<&mut [u8], _> = res.to_mut(); + let other: GLWESecret<&[u8]> = other.to_ref(); + + for i in 0..res.rank().into() { + self.svp_prepare(&mut res.data, i, &other.data, i); + } + } + + *res.dist_mut() = *other.dist(); + } +} + +impl GLWESecretPreparedFactory for Module where + Self: GetDegree + SvpPPolBytesOf + SvpPPolAlloc + SvpPrepare +{ +} + +impl GLWESecretPrepared, B> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GLWEInfos, + M: GLWESecretPreparedFactory, + { + module.alloc_glwe_secret_prepared_from_infos(infos) + } + + pub fn alloc(module: &M, rank: Rank) -> Self + where + M: GLWESecretPreparedFactory, + { + module.alloc_glwe_secret_prepared(rank) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GLWEInfos, + M: GLWESecretPreparedFactory, + { + module.bytes_of_glwe_secret_prepared_from_infos(infos) + } + + pub fn bytes_of(module: &M, rank: Rank) -> usize + where + M: GLWESecretPreparedFactory, + { + module.bytes_of_glwe_secret_prepared(rank) + } +} + +impl GLWESecretPrepared { + pub fn n(&self) -> Degree { + Degree(self.data.n() as u32) + } + + pub fn rank(&self) -> Rank { + Rank(self.data.cols() as u32) + } +} + +impl GLWESecretPrepared { + pub fn prepare(&mut self, module: &M, other: &O) + where + M: GLWESecretPreparedFactory, + O: GLWESecretToRef + GetDistribution, + { + module.prepare_glwe_secret(self, other); + } +} + +pub trait GLWESecretPreparedToRef { + fn to_ref(&self) -> GLWESecretPrepared<&[u8], B>; +} + +impl GLWESecretPreparedToRef for GLWESecretPrepared { + fn to_ref(&self) -> GLWESecretPrepared<&[u8], B> { + GLWESecretPrepared { + data: self.data.to_ref(), + dist: self.dist, + } + } +} + +pub trait GLWESecretPreparedToMut +where + Self: GLWESecretPreparedToRef, +{ + fn to_mut(&mut self) -> GLWESecretPrepared<&mut [u8], B>; +} + +impl GLWESecretPreparedToMut for GLWESecretPrepared { + fn to_mut(&mut self) -> GLWESecretPrepared<&mut [u8], B> { + GLWESecretPrepared { + dist: self.dist, + data: self.data.to_mut(), + } + } +} diff --git a/poulpy-core/src/layouts/prepared/glwe_sk.rs b/poulpy-core/src/layouts/prepared/glwe_sk.rs deleted file mode 100644 index d3f638b..0000000 --- a/poulpy-core/src/layouts/prepared/glwe_sk.rs +++ /dev/null @@ -1,115 +0,0 @@ -use poulpy_hal::{ - api::{SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare}, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, SvpPPol, ZnxInfos}, -}; - -use crate::{ - dist::Distribution, - layouts::{ - Base2K, Degree, GLWEInfos, GLWESecret, LWEInfos, Rank, TorusPrecision, - prepared::{Prepare, PrepareAlloc, PrepareScratchSpace}, - }, -}; - -pub struct GLWESecretPrepared { - pub(crate) data: SvpPPol, - pub(crate) dist: Distribution, -} - -impl LWEInfos for GLWESecretPrepared { - fn base2k(&self) -> Base2K { - Base2K(0) - } - - fn k(&self) -> TorusPrecision { - TorusPrecision(0) - } - - fn n(&self) -> Degree { - Degree(self.data.n() as u32) - } - - fn size(&self) -> usize { - self.data.size() - } -} -impl GLWEInfos for GLWESecretPrepared { - fn rank(&self) -> Rank { - Rank(self.data.cols() as u32) - } -} -impl GLWESecretPrepared, B> { - pub fn alloc(module: &Module, infos: &A) -> Self - where - A: GLWEInfos, - Module: SvpPPolAlloc, - { - assert_eq!(module.n() as u32, infos.n()); - Self::alloc_with(module, infos.rank()) - } - - pub fn alloc_with(module: &Module, rank: Rank) -> Self - where - Module: SvpPPolAlloc, - { - Self { - data: module.svp_ppol_alloc(rank.into()), - dist: Distribution::NONE, - } - } - - pub fn alloc_bytes(module: &Module, infos: &A) -> usize - where - A: GLWEInfos, - Module: SvpPPolAllocBytes, - { - assert_eq!(module.n() as u32, infos.n()); - Self::alloc_bytes_with(module, infos.rank()) - } - - pub fn alloc_bytes_with(module: &Module, rank: Rank) -> usize - where - Module: SvpPPolAllocBytes, - { - module.svp_ppol_alloc_bytes(rank.into()) - } -} - -impl GLWESecretPrepared { - pub fn n(&self) -> Degree { - Degree(self.data.n() as u32) - } - - pub fn rank(&self) -> Rank { - Rank(self.data.cols() as u32) - } -} - -impl PrepareScratchSpace for GLWESecretPrepared, B> { - fn prepare_scratch_space(_module: &Module, _infos: &A) -> usize { - 0 - } -} - -impl PrepareAlloc, B>> for GLWESecret -where - Module: SvpPrepare + SvpPPolAlloc, -{ - fn prepare_alloc(&self, module: &Module, _scratch: &mut Scratch) -> GLWESecretPrepared, B> { - let mut sk_dft: GLWESecretPrepared, B> = GLWESecretPrepared::alloc(module, self); - sk_dft.prepare(module, self, _scratch); - sk_dft - } -} - -impl Prepare> for GLWESecretPrepared -where - Module: SvpPrepare, -{ - fn prepare(&mut self, module: &Module, other: &GLWESecret, _scratch: &mut Scratch) { - (0..self.rank().into()).for_each(|i| { - module.svp_prepare(&mut self.data, i, &other.data, i); - }); - self.dist = other.dist - } -} diff --git a/poulpy-core/src/layouts/prepared/glwe_switching_key.rs b/poulpy-core/src/layouts/prepared/glwe_switching_key.rs new file mode 100644 index 0000000..d73d17d --- /dev/null +++ b/poulpy-core/src/layouts/prepared/glwe_switching_key.rs @@ -0,0 +1,241 @@ +use poulpy_hal::layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}; + +use crate::layouts::{ + Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWEToRef, GLWEInfos, GLWESwitchingKeyDegrees, GLWESwitchingKeyDegreesMut, + LWEInfos, Rank, TorusPrecision, + prepared::{GGLWEPrepared, GGLWEPreparedFactory, GGLWEPreparedToMut, GGLWEPreparedToRef}, +}; + +#[derive(PartialEq, Eq)] +pub struct GLWESwitchingKeyPrepared { + pub(crate) key: GGLWEPrepared, + pub(crate) input_degree: Degree, // Degree of sk_in + pub(crate) output_degree: Degree, // Degree of sk_out +} + +impl GLWESwitchingKeyDegrees for GLWESwitchingKeyPrepared { + fn output_degree(&self) -> &Degree { + &self.output_degree + } + + fn input_degree(&self) -> &Degree { + &self.input_degree + } +} + +impl GLWESwitchingKeyDegreesMut for GLWESwitchingKeyPrepared { + fn output_degree(&mut self) -> &mut Degree { + &mut self.output_degree + } + + fn input_degree(&mut self) -> &mut Degree { + &mut self.input_degree + } +} + +impl LWEInfos for GLWESwitchingKeyPrepared { + fn n(&self) -> Degree { + self.key.n() + } + + fn base2k(&self) -> Base2K { + self.key.base2k() + } + + fn k(&self) -> TorusPrecision { + self.key.k() + } + + fn size(&self) -> usize { + self.key.size() + } +} + +impl GLWEInfos for GLWESwitchingKeyPrepared { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWEInfos for GLWESwitchingKeyPrepared { + fn rank_in(&self) -> Rank { + self.key.rank_in() + } + + fn rank_out(&self) -> Rank { + self.key.rank_out() + } + + fn dsize(&self) -> Dsize { + self.key.dsize() + } + + fn dnum(&self) -> Dnum { + self.key.dnum() + } +} + +pub trait GLWESwitchingKeyPreparedFactory +where + Self: GGLWEPreparedFactory, +{ + fn alloc_glwe_switching_key_prepared( + &self, + base2k: Base2K, + k: TorusPrecision, + rank_in: Rank, + rank_out: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> GLWESwitchingKeyPrepared, B> { + GLWESwitchingKeyPrepared::, B> { + key: self.alloc_gglwe_prepared(base2k, k, rank_in, rank_out, dnum, dsize), + input_degree: Degree(0), + output_degree: Degree(0), + } + } + + fn alloc_glwe_switching_key_prepared_from_infos(&self, infos: &A) -> GLWESwitchingKeyPrepared, B> + where + A: GGLWEInfos, + { + self.alloc_glwe_switching_key_prepared( + infos.base2k(), + infos.k(), + infos.rank_in(), + infos.rank_out(), + infos.dnum(), + infos.dsize(), + ) + } + + fn bytes_of_glwe_switching_key_prepared( + &self, + base2k: Base2K, + k: TorusPrecision, + rank_in: Rank, + rank_out: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> usize { + self.bytes_of_gglwe_prepared(base2k, k, rank_in, rank_out, dnum, dsize) + } + + fn bytes_of_glwe_switching_key_prepared_from_infos(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + self.bytes_of_glwe_switching_key_prepared( + infos.base2k(), + infos.k(), + infos.rank_in(), + infos.rank_out(), + infos.dnum(), + infos.dsize(), + ) + } + + fn prepare_glwe_switching_key_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + self.prepare_gglwe_tmp_bytes(infos) + } + + fn prepare_glwe_switching(&self, res: &mut R, other: &O, scratch: &mut Scratch) + where + R: GGLWEPreparedToMut + GLWESwitchingKeyDegreesMut, + O: GGLWEToRef + GLWESwitchingKeyDegrees, + { + self.prepare_gglwe(res, other, scratch); + *res.input_degree() = *other.input_degree(); + *res.output_degree() = *other.output_degree(); + } +} + +impl GLWESwitchingKeyPreparedFactory for Module where Self: GGLWEPreparedFactory {} + +impl GLWESwitchingKeyPrepared, B> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GGLWEInfos, + M: GLWESwitchingKeyPreparedFactory, + { + module.alloc_glwe_switching_key_prepared_from_infos(infos) + } + + pub fn alloc( + module: &M, + base2k: Base2K, + k: TorusPrecision, + rank_in: Rank, + rank_out: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> Self + where + M: GLWESwitchingKeyPreparedFactory, + { + module.alloc_glwe_switching_key_prepared(base2k, k, rank_in, rank_out, dnum, dsize) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: GLWESwitchingKeyPreparedFactory, + { + module.bytes_of_glwe_switching_key_prepared_from_infos(infos) + } + + pub fn bytes_of( + module: &M, + base2k: Base2K, + k: TorusPrecision, + rank_in: Rank, + rank_out: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> usize + where + M: GLWESwitchingKeyPreparedFactory, + { + module.bytes_of_glwe_switching_key_prepared(base2k, k, rank_in, rank_out, dnum, dsize) + } +} + +impl GLWESwitchingKeyPrepared { + pub fn prepare(&mut self, module: &M, other: &O, scratch: &mut Scratch) + where + O: GGLWEToRef + GLWESwitchingKeyDegrees, + M: GLWESwitchingKeyPreparedFactory, + { + module.prepare_glwe_switching(self, other, scratch); + } +} + +impl GLWESwitchingKeyPrepared, B> { + pub fn prepare_tmp_bytes(&self, module: &M) -> usize + where + M: GLWESwitchingKeyPreparedFactory, + { + module.prepare_glwe_switching_key_tmp_bytes(self) + } +} + +impl GGLWEPreparedToRef for GLWESwitchingKeyPrepared +where + GGLWEPrepared: GGLWEPreparedToRef, +{ + fn to_ref(&self) -> GGLWEPrepared<&[u8], BE> { + self.key.to_ref() + } +} + +impl GGLWEPreparedToMut for GLWESwitchingKeyPrepared +where + GGLWEPrepared: GGLWEPreparedToMut, +{ + fn to_mut(&mut self) -> GGLWEPrepared<&mut [u8], BE> { + self.key.to_mut() + } +} diff --git a/poulpy-core/src/layouts/prepared/glwe_tensor_key.rs b/poulpy-core/src/layouts/prepared/glwe_tensor_key.rs new file mode 100644 index 0000000..bd63c75 --- /dev/null +++ b/poulpy-core/src/layouts/prepared/glwe_tensor_key.rs @@ -0,0 +1,238 @@ +use poulpy_hal::layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}; + +use crate::layouts::{ + Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWEPrepared, GGLWEPreparedFactory, GGLWEPreparedToMut, GGLWEPreparedToRef, + GLWEInfos, GLWETensorKey, GLWETensorKeyToRef, LWEInfos, Rank, TorusPrecision, +}; + +#[derive(PartialEq, Eq)] +pub struct GLWETensorKeyPrepared { + pub(crate) keys: Vec>, +} + +impl LWEInfos for GLWETensorKeyPrepared { + fn n(&self) -> Degree { + self.keys[0].n() + } + + fn base2k(&self) -> Base2K { + self.keys[0].base2k() + } + + fn k(&self) -> TorusPrecision { + self.keys[0].k() + } + + fn size(&self) -> usize { + self.keys[0].size() + } +} + +impl GLWEInfos for GLWETensorKeyPrepared { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWEInfos for GLWETensorKeyPrepared { + fn rank_in(&self) -> Rank { + self.rank_out() + } + + fn rank_out(&self) -> Rank { + self.keys[0].rank_out() + } + + fn dsize(&self) -> Dsize { + self.keys[0].dsize() + } + + fn dnum(&self) -> Dnum { + self.keys[0].dnum() + } +} + +pub trait GLWETensorKeyPreparedFactory +where + Self: GGLWEPreparedFactory, +{ + fn alloc_tensor_key_prepared( + &self, + base2k: Base2K, + k: TorusPrecision, + dnum: Dnum, + dsize: Dsize, + rank: Rank, + ) -> GLWETensorKeyPrepared, B> { + let pairs: u32 = (((rank.as_u32() + 1) * rank.as_u32()) >> 1).max(1); + GLWETensorKeyPrepared { + keys: (0..pairs) + .map(|_| self.alloc_gglwe_prepared(base2k, k, Rank(1), rank, dnum, dsize)) + .collect(), + } + } + + fn alloc_tensor_key_prepared_from_infos(&self, infos: &A) -> GLWETensorKeyPrepared, B> + where + A: GGLWEInfos, + { + assert_eq!( + infos.rank_in(), + infos.rank_out(), + "rank_in != rank_out is not supported for TensorKeyPrepared" + ); + self.alloc_tensor_key_prepared( + infos.base2k(), + infos.k(), + infos.dnum(), + infos.dsize(), + infos.rank_out(), + ) + } + + fn bytes_of_tensor_key_prepared(&self, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { + let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize; + pairs * self.bytes_of_gglwe_prepared(base2k, k, Rank(1), rank, dnum, dsize) + } + + fn bytes_of_tensor_key_prepared_from_infos(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + self.bytes_of_tensor_key_prepared( + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) + } + + fn prepare_tensor_key_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + self.prepare_gglwe_tmp_bytes(infos) + } + + fn prepare_tensor_key(&self, res: &mut R, other: &O, scratch: &mut Scratch) + where + R: GLWETensorKeyPreparedToMut, + O: GLWETensorKeyToRef, + { + let mut res: GLWETensorKeyPrepared<&mut [u8], B> = res.to_mut(); + let other: GLWETensorKey<&[u8]> = other.to_ref(); + + assert_eq!(res.keys.len(), other.keys.len()); + + for (a, b) in res.keys.iter_mut().zip(other.keys.iter()) { + self.prepare_gglwe(a, b, scratch); + } + } +} + +impl GLWETensorKeyPreparedFactory for Module where Module: GGLWEPreparedFactory {} + +impl GLWETensorKeyPrepared, B> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GGLWEInfos, + M: GLWETensorKeyPreparedFactory, + { + module.alloc_tensor_key_prepared_from_infos(infos) + } + + pub fn alloc_with(module: &M, base2k: Base2K, k: TorusPrecision, dnum: Dnum, dsize: Dsize, rank: Rank) -> Self + where + M: GLWETensorKeyPreparedFactory, + { + module.alloc_tensor_key_prepared(base2k, k, dnum, dsize, rank) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: GLWETensorKeyPreparedFactory, + { + module.bytes_of_tensor_key_prepared_from_infos(infos) + } + + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize + where + M: GLWETensorKeyPreparedFactory, + { + module.bytes_of_tensor_key_prepared(base2k, k, rank, dnum, dsize) + } +} + +impl GLWETensorKeyPrepared { + // Returns a mutable reference to GGLWE_{s}(s[i] * s[j]) + pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GGLWEPrepared { + if i > j { + std::mem::swap(&mut i, &mut j); + }; + let rank: usize = self.rank_out().into(); + &mut self.keys[i * rank + j - (i * (i + 1) / 2)] + } +} + +impl GLWETensorKeyPrepared { + // Returns a reference to GGLWE_{s}(s[i] * s[j]) + pub fn at(&self, mut i: usize, mut j: usize) -> &GGLWEPrepared { + if i > j { + std::mem::swap(&mut i, &mut j); + }; + let rank: usize = self.rank_out().into(); + &self.keys[i * rank + j - (i * (i + 1) / 2)] + } +} + +impl GLWETensorKeyPrepared, B> { + pub fn prepare_tmp_bytes(&self, module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: GLWETensorKeyPreparedFactory, + { + module.prepare_tensor_key_tmp_bytes(infos) + } +} + +impl GLWETensorKeyPrepared { + pub fn prepare(&mut self, module: &M, other: &O, scratch: &mut Scratch) + where + O: GLWETensorKeyToRef, + M: GLWETensorKeyPreparedFactory, + { + module.prepare_tensor_key(self, other, scratch); + } +} + +pub trait GLWETensorKeyPreparedToMut { + fn to_mut(&mut self) -> GLWETensorKeyPrepared<&mut [u8], B>; +} + +impl GLWETensorKeyPreparedToMut for GLWETensorKeyPrepared +where + GGLWEPrepared: GGLWEPreparedToMut, +{ + fn to_mut(&mut self) -> GLWETensorKeyPrepared<&mut [u8], B> { + GLWETensorKeyPrepared { + keys: self.keys.iter_mut().map(|c| c.to_mut()).collect(), + } + } +} + +pub trait GLWETensorKeyPreparedToRef { + fn to_ref(&self) -> GLWETensorKeyPrepared<&[u8], B>; +} + +impl GLWETensorKeyPreparedToRef for GLWETensorKeyPrepared +where + GGLWEPrepared: GGLWEPreparedToRef, +{ + fn to_ref(&self) -> GLWETensorKeyPrepared<&[u8], B> { + GLWETensorKeyPrepared { + keys: self.keys.iter().map(|c| c.to_ref()).collect(), + } + } +} diff --git a/poulpy-core/src/layouts/prepared/glwe_to_lwe_ksk.rs b/poulpy-core/src/layouts/prepared/glwe_to_lwe_ksk.rs deleted file mode 100644 index f241c6d..0000000 --- a/poulpy-core/src/layouts/prepared/glwe_to_lwe_ksk.rs +++ /dev/null @@ -1,143 +0,0 @@ -use poulpy_hal::{ - api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare}, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}, -}; - -use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEInfos, GLWEInfos, GLWEToLWEKey, LWEInfos, Rank, TorusPrecision, - prepared::{GGLWESwitchingKeyPrepared, Prepare, PrepareAlloc, PrepareScratchSpace}, -}; - -#[derive(PartialEq, Eq)] -pub struct GLWEToLWESwitchingKeyPrepared(pub(crate) GGLWESwitchingKeyPrepared); - -impl LWEInfos for GLWEToLWESwitchingKeyPrepared { - fn base2k(&self) -> Base2K { - self.0.base2k() - } - - fn k(&self) -> TorusPrecision { - self.0.k() - } - - fn n(&self) -> Degree { - self.0.n() - } - - fn size(&self) -> usize { - self.0.size() - } -} - -impl GLWEInfos for GLWEToLWESwitchingKeyPrepared { - fn rank(&self) -> Rank { - self.rank_out() - } -} - -impl GGLWEInfos for GLWEToLWESwitchingKeyPrepared { - fn rank_in(&self) -> Rank { - self.0.rank_in() - } - - fn dsize(&self) -> Dsize { - self.0.dsize() - } - - fn rank_out(&self) -> Rank { - self.0.rank_out() - } - - fn dnum(&self) -> Dnum { - self.0.dnum() - } -} - -impl GLWEToLWESwitchingKeyPrepared, B> { - pub fn alloc(module: &Module, infos: &A) -> Self - where - A: GGLWEInfos, - Module: VmpPMatAlloc, - { - debug_assert_eq!( - infos.rank_out().0, - 1, - "rank_out > 1 is not supported for GLWEToLWESwitchingKeyPrepared" - ); - debug_assert_eq!( - infos.dsize().0, - 1, - "dsize > 1 is not supported for GLWEToLWESwitchingKeyPrepared" - ); - Self(GGLWESwitchingKeyPrepared::alloc(module, infos)) - } - - pub fn alloc_with(module: &Module, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> Self - where - Module: VmpPMatAlloc, - { - Self(GGLWESwitchingKeyPrepared::alloc_with( - module, - base2k, - k, - rank_in, - Rank(1), - dnum, - Dsize(1), - )) - } - - pub fn alloc_bytes(module: &Module, infos: &A) -> usize - where - A: GGLWEInfos, - Module: VmpPMatAllocBytes, - { - debug_assert_eq!( - infos.rank_out().0, - 1, - "rank_out > 1 is not supported for GLWEToLWESwitchingKeyPrepared" - ); - debug_assert_eq!( - infos.dsize().0, - 1, - "dsize > 1 is not supported for GLWEToLWESwitchingKeyPrepared" - ); - GGLWESwitchingKeyPrepared::alloc_bytes(module, infos) - } - - pub fn alloc_bytes_with(module: &Module, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> usize - where - Module: VmpPMatAllocBytes, - { - GGLWESwitchingKeyPrepared::alloc_bytes_with(module, base2k, k, rank_in, Rank(1), dnum, Dsize(1)) - } -} - -impl PrepareScratchSpace for GLWEToLWESwitchingKeyPrepared, B> -where - GGLWESwitchingKeyPrepared, B>: PrepareScratchSpace, -{ - fn prepare_scratch_space(module: &Module, infos: &A) -> usize { - GGLWESwitchingKeyPrepared::prepare_scratch_space(module, infos) - } -} - -impl PrepareAlloc, B>> for GLWEToLWEKey -where - Module: VmpPrepare + VmpPMatAlloc, -{ - fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> GLWEToLWESwitchingKeyPrepared, B> { - let mut ksk_prepared: GLWEToLWESwitchingKeyPrepared, B> = GLWEToLWESwitchingKeyPrepared::alloc(module, self); - ksk_prepared.prepare(module, self, scratch); - ksk_prepared - } -} - -impl Prepare> for GLWEToLWESwitchingKeyPrepared -where - Module: VmpPrepare, -{ - fn prepare(&mut self, module: &Module, other: &GLWEToLWEKey, scratch: &mut Scratch) { - self.0.prepare(module, &other.0, scratch); - } -} diff --git a/poulpy-core/src/layouts/prepared/glwe_to_lwe_switching_key.rs b/poulpy-core/src/layouts/prepared/glwe_to_lwe_switching_key.rs new file mode 100644 index 0000000..6edac5e --- /dev/null +++ b/poulpy-core/src/layouts/prepared/glwe_to_lwe_switching_key.rs @@ -0,0 +1,211 @@ +use poulpy_hal::layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}; + +use crate::layouts::{ + Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWEPrepared, GGLWEPreparedToMut, GGLWEPreparedToRef, GGLWEToRef, GLWEInfos, + GLWESwitchingKeyDegrees, GLWESwitchingKeyDegreesMut, LWEInfos, Rank, TorusPrecision, + prepared::{GLWESwitchingKeyPrepared, GLWESwitchingKeyPreparedFactory}, +}; + +#[derive(PartialEq, Eq)] +pub struct GLWEToLWESwitchingKeyPrepared(pub(crate) GLWESwitchingKeyPrepared); + +impl LWEInfos for GLWEToLWESwitchingKeyPrepared { + fn base2k(&self) -> Base2K { + self.0.base2k() + } + + fn k(&self) -> TorusPrecision { + self.0.k() + } + + fn n(&self) -> Degree { + self.0.n() + } + + fn size(&self) -> usize { + self.0.size() + } +} + +impl GLWEInfos for GLWEToLWESwitchingKeyPrepared { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWEInfos for GLWEToLWESwitchingKeyPrepared { + fn rank_in(&self) -> Rank { + self.0.rank_in() + } + + fn dsize(&self) -> Dsize { + self.0.dsize() + } + + fn rank_out(&self) -> Rank { + self.0.rank_out() + } + + fn dnum(&self) -> Dnum { + self.0.dnum() + } +} + +pub trait GLWEToLWESwitchingKeyPreparedFactory +where + Self: GLWESwitchingKeyPreparedFactory, +{ + fn alloc_glwe_to_lwe_switching_key_prepared( + &self, + base2k: Base2K, + k: TorusPrecision, + rank_in: Rank, + dnum: Dnum, + ) -> GLWEToLWESwitchingKeyPrepared, B> { + GLWEToLWESwitchingKeyPrepared(self.alloc_glwe_switching_key_prepared(base2k, k, rank_in, Rank(1), dnum, Dsize(1))) + } + fn alloc_glwe_to_lwe_switching_key_prepared_from_infos(&self, infos: &A) -> GLWEToLWESwitchingKeyPrepared, B> + where + A: GGLWEInfos, + { + debug_assert_eq!( + infos.rank_out().0, + 1, + "rank_out > 1 is not supported for GLWEToLWESwitchingKeyPrepared" + ); + debug_assert_eq!( + infos.dsize().0, + 1, + "dsize > 1 is not supported for GLWEToLWESwitchingKeyPrepared" + ); + self.alloc_glwe_to_lwe_switching_key_prepared(infos.base2k(), infos.k(), infos.rank_in(), infos.dnum()) + } + + fn bytes_of_glwe_to_lwe_switching_key_prepared(&self, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> usize { + self.bytes_of_glwe_switching_key_prepared(base2k, k, rank_in, Rank(1), dnum, Dsize(1)) + } + + fn bytes_of_glwe_to_lwe_switching_key_prepared_from_infos(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + debug_assert_eq!( + infos.rank_out().0, + 1, + "rank_out > 1 is not supported for GLWEToLWESwitchingKeyPrepared" + ); + debug_assert_eq!( + infos.dsize().0, + 1, + "dsize > 1 is not supported for GLWEToLWESwitchingKeyPrepared" + ); + self.bytes_of_glwe_to_lwe_switching_key_prepared(infos.base2k(), infos.k(), infos.rank_in(), infos.dnum()) + } + + fn prepare_glwe_to_lwe_switching_key_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + self.prepare_glwe_switching_key_tmp_bytes(infos) + } + + fn prepare_glwe_to_lwe_switching_key(&self, res: &mut R, other: &O, scratch: &mut Scratch) + where + R: GGLWEPreparedToMut + GLWESwitchingKeyDegreesMut, + O: GGLWEToRef + GLWESwitchingKeyDegrees, + { + self.prepare_glwe_switching(res, other, scratch); + } +} + +impl GLWEToLWESwitchingKeyPreparedFactory for Module where Self: GLWESwitchingKeyPreparedFactory {} + +impl GLWEToLWESwitchingKeyPrepared, B> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GGLWEInfos, + M: GLWEToLWESwitchingKeyPreparedFactory, + { + module.alloc_glwe_to_lwe_switching_key_prepared_from_infos(infos) + } + + pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> Self + where + M: GLWEToLWESwitchingKeyPreparedFactory, + { + module.alloc_glwe_to_lwe_switching_key_prepared(base2k, k, rank_in, dnum) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: GLWEToLWESwitchingKeyPreparedFactory, + { + module.bytes_of_glwe_to_lwe_switching_key_prepared_from_infos(infos) + } + + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> usize + where + M: GLWEToLWESwitchingKeyPreparedFactory, + { + module.bytes_of_glwe_to_lwe_switching_key_prepared(base2k, k, rank_in, dnum) + } +} + +impl GLWEToLWESwitchingKeyPrepared, B> { + pub fn prepare_tmp_bytes(&self, module: &M, infos: &A) + where + A: GGLWEInfos, + M: GLWEToLWESwitchingKeyPreparedFactory, + { + module.prepare_glwe_to_lwe_switching_key_tmp_bytes(infos); + } +} + +impl GLWEToLWESwitchingKeyPrepared { + pub fn prepare(&mut self, module: &M, other: &O, scratch: &mut Scratch) + where + O: GGLWEToRef + GLWESwitchingKeyDegrees, + M: GLWEToLWESwitchingKeyPreparedFactory, + { + module.prepare_glwe_to_lwe_switching_key(self, other, scratch); + } +} + +impl GGLWEPreparedToRef for GLWEToLWESwitchingKeyPrepared +where + GLWESwitchingKeyPrepared: GGLWEPreparedToRef, +{ + fn to_ref(&self) -> GGLWEPrepared<&[u8], B> { + self.0.to_ref() + } +} + +impl GGLWEPreparedToMut for GLWEToLWESwitchingKeyPrepared +where + GLWESwitchingKeyPrepared: GGLWEPreparedToRef, +{ + fn to_mut(&mut self) -> GGLWEPrepared<&mut [u8], B> { + self.0.to_mut() + } +} + +impl GLWESwitchingKeyDegreesMut for GLWEToLWESwitchingKeyPrepared { + fn input_degree(&mut self) -> &mut Degree { + &mut self.0.input_degree + } + + fn output_degree(&mut self) -> &mut Degree { + &mut self.0.output_degree + } +} + +impl GLWESwitchingKeyDegrees for GLWEToLWESwitchingKeyPrepared { + fn input_degree(&self) -> &Degree { + &self.0.input_degree + } + + fn output_degree(&self) -> &Degree { + &self.0.output_degree + } +} diff --git a/poulpy-core/src/layouts/prepared/lwe_ksk.rs b/poulpy-core/src/layouts/prepared/lwe_ksk.rs deleted file mode 100644 index 5f0cf14..0000000 --- a/poulpy-core/src/layouts/prepared/lwe_ksk.rs +++ /dev/null @@ -1,152 +0,0 @@ -use poulpy_hal::{ - api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare}, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}, -}; - -use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEInfos, GLWEInfos, LWEInfos, LWESwitchingKey, Rank, TorusPrecision, - prepared::{GGLWESwitchingKeyPrepared, Prepare, PrepareAlloc, PrepareScratchSpace}, -}; - -#[derive(PartialEq, Eq)] -pub struct LWESwitchingKeyPrepared(pub(crate) GGLWESwitchingKeyPrepared); - -impl LWEInfos for LWESwitchingKeyPrepared { - fn base2k(&self) -> Base2K { - self.0.base2k() - } - - fn k(&self) -> TorusPrecision { - self.0.k() - } - - fn n(&self) -> Degree { - self.0.n() - } - - fn size(&self) -> usize { - self.0.size() - } -} -impl GLWEInfos for LWESwitchingKeyPrepared { - fn rank(&self) -> Rank { - self.rank_out() - } -} - -impl GGLWEInfos for LWESwitchingKeyPrepared { - fn dsize(&self) -> Dsize { - self.0.dsize() - } - - fn rank_in(&self) -> Rank { - self.0.rank_in() - } - - fn rank_out(&self) -> Rank { - self.0.rank_out() - } - - fn dnum(&self) -> Dnum { - self.0.dnum() - } -} - -impl LWESwitchingKeyPrepared, B> { - pub fn alloc(module: &Module, infos: &A) -> Self - where - A: GGLWEInfos, - Module: VmpPMatAlloc, - { - debug_assert_eq!( - infos.dsize().0, - 1, - "dsize > 1 is not supported for LWESwitchingKey" - ); - debug_assert_eq!( - infos.rank_in().0, - 1, - "rank_in > 1 is not supported for LWESwitchingKey" - ); - debug_assert_eq!( - infos.rank_out().0, - 1, - "rank_out > 1 is not supported for LWESwitchingKey" - ); - Self(GGLWESwitchingKeyPrepared::alloc(module, infos)) - } - - pub fn alloc_with(module: &Module, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> Self - where - Module: VmpPMatAlloc, - { - Self(GGLWESwitchingKeyPrepared::alloc_with( - module, - base2k, - k, - Rank(1), - Rank(1), - dnum, - Dsize(1), - )) - } - - pub fn alloc_bytes(module: &Module, infos: &A) -> usize - where - A: GGLWEInfos, - Module: VmpPMatAllocBytes, - { - debug_assert_eq!( - infos.dsize().0, - 1, - "dsize > 1 is not supported for LWESwitchingKey" - ); - debug_assert_eq!( - infos.rank_in().0, - 1, - "rank_in > 1 is not supported for LWESwitchingKey" - ); - debug_assert_eq!( - infos.rank_out().0, - 1, - "rank_out > 1 is not supported for LWESwitchingKey" - ); - GGLWESwitchingKeyPrepared::alloc_bytes(module, infos) - } - - pub fn alloc_bytes_with(module: &Module, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> usize - where - Module: VmpPMatAllocBytes, - { - GGLWESwitchingKeyPrepared::alloc_bytes_with(module, base2k, k, Rank(1), Rank(1), dnum, Dsize(1)) - } -} - -impl PrepareScratchSpace for LWESwitchingKeyPrepared, B> -where - GGLWESwitchingKeyPrepared, B>: PrepareScratchSpace, -{ - fn prepare_scratch_space(module: &Module, infos: &A) -> usize { - GGLWESwitchingKeyPrepared::prepare_scratch_space(module, infos) - } -} - -impl PrepareAlloc, B>> for LWESwitchingKey -where - Module: VmpPrepare + VmpPMatAlloc, -{ - fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> LWESwitchingKeyPrepared, B> { - let mut ksk_prepared: LWESwitchingKeyPrepared, B> = LWESwitchingKeyPrepared::alloc(module, self); - ksk_prepared.prepare(module, self, scratch); - ksk_prepared - } -} - -impl Prepare> for LWESwitchingKeyPrepared -where - Module: VmpPrepare, -{ - fn prepare(&mut self, module: &Module, other: &LWESwitchingKey, scratch: &mut Scratch) { - self.0.prepare(module, &other.0, scratch); - } -} diff --git a/poulpy-core/src/layouts/prepared/lwe_switching_key.rs b/poulpy-core/src/layouts/prepared/lwe_switching_key.rs new file mode 100644 index 0000000..327d001 --- /dev/null +++ b/poulpy-core/src/layouts/prepared/lwe_switching_key.rs @@ -0,0 +1,209 @@ +use poulpy_hal::layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}; + +use crate::layouts::{ + Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWEPrepared, GGLWEPreparedToMut, GGLWEPreparedToRef, GGLWEToRef, GLWEInfos, + GLWESwitchingKeyDegrees, GLWESwitchingKeyDegreesMut, LWEInfos, Rank, TorusPrecision, + prepared::{GLWESwitchingKeyPrepared, GLWESwitchingKeyPreparedFactory}, +}; + +#[derive(PartialEq, Eq)] +pub struct LWESwitchingKeyPrepared(pub(crate) GLWESwitchingKeyPrepared); + +impl LWEInfos for LWESwitchingKeyPrepared { + fn base2k(&self) -> Base2K { + self.0.base2k() + } + + fn k(&self) -> TorusPrecision { + self.0.k() + } + + fn n(&self) -> Degree { + self.0.n() + } + + fn size(&self) -> usize { + self.0.size() + } +} +impl GLWEInfos for LWESwitchingKeyPrepared { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWEInfos for LWESwitchingKeyPrepared { + fn dsize(&self) -> Dsize { + self.0.dsize() + } + + fn rank_in(&self) -> Rank { + self.0.rank_in() + } + + fn rank_out(&self) -> Rank { + self.0.rank_out() + } + + fn dnum(&self) -> Dnum { + self.0.dnum() + } +} + +pub trait LWESwitchingKeyPreparedFactory +where + Self: GLWESwitchingKeyPreparedFactory, +{ + fn alloc_lwe_switching_key_prepared( + &self, + base2k: Base2K, + k: TorusPrecision, + dnum: Dnum, + ) -> LWESwitchingKeyPrepared, B> { + LWESwitchingKeyPrepared(self.alloc_glwe_switching_key_prepared(base2k, k, Rank(1), Rank(1), dnum, Dsize(1))) + } + + fn alloc_lwe_switching_key_prepared_from_infos(&self, infos: &A) -> LWESwitchingKeyPrepared, B> + where + A: GGLWEInfos, + { + debug_assert_eq!( + infos.dsize().0, + 1, + "dsize > 1 is not supported for LWESwitchingKey" + ); + debug_assert_eq!( + infos.rank_in().0, + 1, + "rank_in > 1 is not supported for LWESwitchingKey" + ); + debug_assert_eq!( + infos.rank_out().0, + 1, + "rank_out > 1 is not supported for LWESwitchingKey" + ); + self.alloc_lwe_switching_key_prepared(infos.base2k(), infos.k(), infos.dnum()) + } + + fn bytes_of_lwe_switching_key_prepared(&self, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> usize { + self.bytes_of_glwe_switching_key_prepared(base2k, k, Rank(1), Rank(1), dnum, Dsize(1)) + } + + fn bytes_of_lwe_switching_key_prepared_from_infos(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + debug_assert_eq!( + infos.dsize().0, + 1, + "dsize > 1 is not supported for LWESwitchingKey" + ); + debug_assert_eq!( + infos.rank_in().0, + 1, + "rank_in > 1 is not supported for LWESwitchingKey" + ); + debug_assert_eq!( + infos.rank_out().0, + 1, + "rank_out > 1 is not supported for LWESwitchingKey" + ); + self.bytes_of_lwe_switching_key_prepared(infos.base2k(), infos.k(), infos.dnum()) + } + + fn prepare_lwe_switching_key_tmp_bytes(&self, infos: &A) + where + A: GGLWEInfos, + { + self.prepare_glwe_switching_key_tmp_bytes(infos); + } + fn prepare_lwe_switching_key(&self, res: &mut R, other: &O, scratch: &mut Scratch) + where + R: GGLWEPreparedToMut + GLWESwitchingKeyDegreesMut, + O: GGLWEToRef + GLWESwitchingKeyDegrees, + { + self.prepare_glwe_switching(res, other, scratch); + } +} + +impl LWESwitchingKeyPreparedFactory for Module where Self: GLWESwitchingKeyPreparedFactory {} + +impl LWESwitchingKeyPrepared, B> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GGLWEInfos, + M: LWESwitchingKeyPreparedFactory, + { + module.alloc_lwe_switching_key_prepared_from_infos(infos) + } + + pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> Self + where + M: LWESwitchingKeyPreparedFactory, + { + module.alloc_lwe_switching_key_prepared(base2k, k, dnum) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: LWESwitchingKeyPreparedFactory, + { + module.bytes_of_lwe_switching_key_prepared_from_infos(infos) + } + + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> usize + where + M: LWESwitchingKeyPreparedFactory, + { + module.bytes_of_lwe_switching_key_prepared(base2k, k, dnum) + } +} + +impl LWESwitchingKeyPrepared, B> { + pub fn prepare_tmp_bytes(&self, module: &M, infos: &A) + where + A: GGLWEInfos, + M: LWESwitchingKeyPreparedFactory, + { + module.prepare_lwe_switching_key_tmp_bytes(infos); + } +} + +impl LWESwitchingKeyPrepared { + pub fn prepare(&mut self, module: &M, other: &O, scratch: &mut Scratch) + where + O: GGLWEToRef + GLWESwitchingKeyDegrees, + M: LWESwitchingKeyPreparedFactory, + { + module.prepare_lwe_switching_key(self, other, scratch); + } +} + +impl GGLWEPreparedToRef for LWESwitchingKeyPrepared +where + GGLWEPrepared: GGLWEPreparedToRef, +{ + fn to_ref(&self) -> GGLWEPrepared<&[u8], B> { + self.0.to_ref() + } +} + +impl GGLWEPreparedToMut for LWESwitchingKeyPrepared +where + GGLWEPrepared: GGLWEPreparedToMut, +{ + fn to_mut(&mut self) -> GGLWEPrepared<&mut [u8], B> { + self.0.to_mut() + } +} + +impl GLWESwitchingKeyDegreesMut for LWESwitchingKeyPrepared { + fn input_degree(&mut self) -> &mut Degree { + &mut self.0.input_degree + } + + fn output_degree(&mut self) -> &mut Degree { + &mut self.0.output_degree + } +} diff --git a/poulpy-core/src/layouts/prepared/lwe_to_glwe_ksk.rs b/poulpy-core/src/layouts/prepared/lwe_to_glwe_ksk.rs deleted file mode 100644 index 7c2023a..0000000 --- a/poulpy-core/src/layouts/prepared/lwe_to_glwe_ksk.rs +++ /dev/null @@ -1,144 +0,0 @@ -use poulpy_hal::{ - api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare}, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}, -}; - -use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEInfos, GLWEInfos, LWEInfos, LWEToGLWESwitchingKey, Rank, TorusPrecision, - prepared::{GGLWESwitchingKeyPrepared, Prepare, PrepareAlloc, PrepareScratchSpace}, -}; - -/// A special [GLWESwitchingKey] required to for the conversion from [LWECiphertext] to [GLWECiphertext]. -#[derive(PartialEq, Eq)] -pub struct LWEToGLWESwitchingKeyPrepared(pub(crate) GGLWESwitchingKeyPrepared); - -impl LWEInfos for LWEToGLWESwitchingKeyPrepared { - fn base2k(&self) -> Base2K { - self.0.base2k() - } - - fn k(&self) -> TorusPrecision { - self.0.k() - } - - fn n(&self) -> Degree { - self.0.n() - } - - fn size(&self) -> usize { - self.0.size() - } -} - -impl GLWEInfos for LWEToGLWESwitchingKeyPrepared { - fn rank(&self) -> Rank { - self.rank_out() - } -} - -impl GGLWEInfos for LWEToGLWESwitchingKeyPrepared { - fn dsize(&self) -> Dsize { - self.0.dsize() - } - - fn rank_in(&self) -> Rank { - self.0.rank_in() - } - - fn rank_out(&self) -> Rank { - self.0.rank_out() - } - - fn dnum(&self) -> Dnum { - self.0.dnum() - } -} - -impl LWEToGLWESwitchingKeyPrepared, B> { - pub fn alloc(module: &Module, infos: &A) -> Self - where - A: GGLWEInfos, - Module: VmpPMatAlloc, - { - debug_assert_eq!( - infos.rank_in().0, - 1, - "rank_in > 1 is not supported for LWEToGLWESwitchingKey" - ); - debug_assert_eq!( - infos.dsize().0, - 1, - "dsize > 1 is not supported for LWEToGLWESwitchingKey" - ); - Self(GGLWESwitchingKeyPrepared::alloc(module, infos)) - } - - pub fn alloc_with(module: &Module, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self - where - Module: VmpPMatAlloc, - { - Self(GGLWESwitchingKeyPrepared::alloc_with( - module, - base2k, - k, - Rank(1), - rank_out, - dnum, - Dsize(1), - )) - } - - pub fn alloc_bytes(module: &Module, infos: &A) -> usize - where - A: GGLWEInfos, - Module: VmpPMatAllocBytes, - { - debug_assert_eq!( - infos.rank_in().0, - 1, - "rank_in > 1 is not supported for LWEToGLWESwitchingKey" - ); - debug_assert_eq!( - infos.dsize().0, - 1, - "dsize > 1 is not supported for LWEToGLWESwitchingKey" - ); - GGLWESwitchingKeyPrepared::alloc_bytes(module, infos) - } - - pub fn alloc_bytes_with(module: &Module, base2k: Base2K, k: TorusPrecision, dnum: Dnum, rank_out: Rank) -> usize - where - Module: VmpPMatAllocBytes, - { - GGLWESwitchingKeyPrepared::alloc_bytes_with(module, base2k, k, Rank(1), rank_out, dnum, Dsize(1)) - } -} - -impl PrepareScratchSpace for LWEToGLWESwitchingKeyPrepared, B> -where - GGLWESwitchingKeyPrepared, B>: PrepareScratchSpace, -{ - fn prepare_scratch_space(module: &Module, infos: &A) -> usize { - GGLWESwitchingKeyPrepared::prepare_scratch_space(module, infos) - } -} - -impl PrepareAlloc, B>> for LWEToGLWESwitchingKey -where - Module: VmpPrepare + VmpPMatAlloc, -{ - fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> LWEToGLWESwitchingKeyPrepared, B> { - let mut ksk_prepared: LWEToGLWESwitchingKeyPrepared, B> = LWEToGLWESwitchingKeyPrepared::alloc(module, self); - ksk_prepared.prepare(module, self, scratch); - ksk_prepared - } -} - -impl Prepare> for LWEToGLWESwitchingKeyPrepared -where - Module: VmpPrepare, -{ - fn prepare(&mut self, module: &Module, other: &LWEToGLWESwitchingKey, scratch: &mut Scratch) { - self.0.prepare(module, &other.0, scratch); - } -} diff --git a/poulpy-core/src/layouts/prepared/lwe_to_glwe_switching_key.rs b/poulpy-core/src/layouts/prepared/lwe_to_glwe_switching_key.rs new file mode 100644 index 0000000..30ed131 --- /dev/null +++ b/poulpy-core/src/layouts/prepared/lwe_to_glwe_switching_key.rs @@ -0,0 +1,208 @@ +use poulpy_hal::layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}; + +use crate::layouts::{ + Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWEPrepared, GGLWEPreparedToMut, GGLWEPreparedToRef, GGLWEToRef, GLWEInfos, + GLWESwitchingKeyDegrees, GLWESwitchingKeyDegreesMut, LWEInfos, Rank, TorusPrecision, + prepared::{GLWESwitchingKeyPrepared, GLWESwitchingKeyPreparedFactory}, +}; + +/// A special [GLWESwitchingKey] required to for the conversion from [LWE] to [GLWE]. +#[derive(PartialEq, Eq)] +pub struct LWEToGLWESwitchingKeyPrepared(pub(crate) GLWESwitchingKeyPrepared); + +impl LWEInfos for LWEToGLWESwitchingKeyPrepared { + fn base2k(&self) -> Base2K { + self.0.base2k() + } + + fn k(&self) -> TorusPrecision { + self.0.k() + } + + fn n(&self) -> Degree { + self.0.n() + } + + fn size(&self) -> usize { + self.0.size() + } +} + +impl GLWEInfos for LWEToGLWESwitchingKeyPrepared { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWEInfos for LWEToGLWESwitchingKeyPrepared { + fn dsize(&self) -> Dsize { + self.0.dsize() + } + + fn rank_in(&self) -> Rank { + self.0.rank_in() + } + + fn rank_out(&self) -> Rank { + self.0.rank_out() + } + + fn dnum(&self) -> Dnum { + self.0.dnum() + } +} + +pub trait LWEToGLWESwitchingKeyPreparedFactory +where + Self: GLWESwitchingKeyPreparedFactory, +{ + fn alloc_lwe_to_glwe_switching_key_prepared( + &self, + base2k: Base2K, + k: TorusPrecision, + rank_out: Rank, + dnum: Dnum, + ) -> LWEToGLWESwitchingKeyPrepared, B> { + LWEToGLWESwitchingKeyPrepared(self.alloc_glwe_switching_key_prepared(base2k, k, Rank(1), rank_out, dnum, Dsize(1))) + } + fn alloc_lwe_to_glwe_switching_key_prepared_from_infos(&self, infos: &A) -> LWEToGLWESwitchingKeyPrepared, B> + where + A: GGLWEInfos, + { + debug_assert_eq!( + infos.rank_in().0, + 1, + "rank_in > 1 is not supported for LWEToGLWESwitchingKey" + ); + debug_assert_eq!( + infos.dsize().0, + 1, + "dsize > 1 is not supported for LWEToGLWESwitchingKey" + ); + self.alloc_lwe_to_glwe_switching_key_prepared(infos.base2k(), infos.k(), infos.rank_out(), infos.dnum()) + } + + fn bytes_of_lwe_to_glwe_switching_key_prepared( + &self, + base2k: Base2K, + k: TorusPrecision, + rank_out: Rank, + dnum: Dnum, + ) -> usize { + self.bytes_of_glwe_switching_key_prepared(base2k, k, Rank(1), rank_out, dnum, Dsize(1)) + } + + fn bytes_of_lwe_to_glwe_switching_key_prepared_from_infos(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + debug_assert_eq!( + infos.rank_in().0, + 1, + "rank_in > 1 is not supported for LWEToGLWESwitchingKey" + ); + debug_assert_eq!( + infos.dsize().0, + 1, + "dsize > 1 is not supported for LWEToGLWESwitchingKey" + ); + self.bytes_of_lwe_to_glwe_switching_key_prepared(infos.base2k(), infos.k(), infos.rank_out(), infos.dnum()) + } + + fn prepare_lwe_to_glwe_switching_key_tmp_bytes(&self, infos: &A) + where + A: GGLWEInfos, + { + self.prepare_glwe_switching_key_tmp_bytes(infos); + } + + fn prepare_lwe_to_glwe_switching_key(&self, res: &mut R, other: &O, scratch: &mut Scratch) + where + R: GGLWEPreparedToMut + GLWESwitchingKeyDegreesMut, + O: GGLWEToRef + GLWESwitchingKeyDegrees, + { + self.prepare_glwe_switching(res, other, scratch); + } +} + +impl LWEToGLWESwitchingKeyPreparedFactory for Module where Self: GLWESwitchingKeyPreparedFactory {} + +impl LWEToGLWESwitchingKeyPrepared, B> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GGLWEInfos, + M: LWEToGLWESwitchingKeyPreparedFactory, + { + module.alloc_lwe_to_glwe_switching_key_prepared_from_infos(infos) + } + + pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self + where + M: LWEToGLWESwitchingKeyPreparedFactory, + { + module.alloc_lwe_to_glwe_switching_key_prepared(base2k, k, rank_out, dnum) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: LWEToGLWESwitchingKeyPreparedFactory, + { + module.bytes_of_lwe_to_glwe_switching_key_prepared_from_infos(infos) + } + + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> usize + where + M: LWEToGLWESwitchingKeyPreparedFactory, + { + module.bytes_of_lwe_to_glwe_switching_key_prepared(base2k, k, rank_out, dnum) + } +} + +impl LWEToGLWESwitchingKeyPrepared, B> { + pub fn prepare_tmp_bytes(&self, module: &M, infos: &A) + where + A: GGLWEInfos, + M: LWEToGLWESwitchingKeyPreparedFactory, + { + module.prepare_lwe_to_glwe_switching_key_tmp_bytes(infos); + } +} + +impl LWEToGLWESwitchingKeyPrepared { + pub fn prepare(&mut self, module: &M, other: &O, scratch: &mut Scratch) + where + O: GGLWEToRef + GLWESwitchingKeyDegrees, + M: LWEToGLWESwitchingKeyPreparedFactory, + { + module.prepare_lwe_to_glwe_switching_key(self, other, scratch); + } +} + +impl GGLWEPreparedToRef for LWEToGLWESwitchingKeyPrepared +where + GLWESwitchingKeyPrepared: GGLWEPreparedToRef, +{ + fn to_ref(&self) -> GGLWEPrepared<&[u8], B> { + self.0.to_ref() + } +} + +impl GGLWEPreparedToMut for LWEToGLWESwitchingKeyPrepared +where + GLWESwitchingKeyPrepared: GGLWEPreparedToMut, +{ + fn to_mut(&mut self) -> GGLWEPrepared<&mut [u8], B> { + self.0.to_mut() + } +} + +impl GLWESwitchingKeyDegreesMut for LWEToGLWESwitchingKeyPrepared { + fn input_degree(&mut self) -> &mut Degree { + &mut self.0.input_degree + } + + fn output_degree(&mut self) -> &mut Degree { + &mut self.0.output_degree + } +} diff --git a/poulpy-core/src/layouts/prepared/mod.rs b/poulpy-core/src/layouts/prepared/mod.rs index eb47848..8944b97 100644 --- a/poulpy-core/src/layouts/prepared/mod.rs +++ b/poulpy-core/src/layouts/prepared/mod.rs @@ -1,34 +1,23 @@ -mod gglwe_atk; -mod gglwe_ct; -mod gglwe_ksk; -mod gglwe_tsk; -mod ggsw_ct; -mod glwe_pk; -mod glwe_sk; -mod glwe_to_lwe_ksk; -mod lwe_ksk; -mod lwe_to_glwe_ksk; +mod gglwe; +mod ggsw; +mod glwe; +mod glwe_automorphism_key; +mod glwe_public_key; +mod glwe_secret; +mod glwe_switching_key; +mod glwe_tensor_key; +mod glwe_to_lwe_switching_key; +mod lwe_switching_key; +mod lwe_to_glwe_switching_key; -pub use gglwe_atk::*; -pub use gglwe_ct::*; -pub use gglwe_ksk::*; -pub use gglwe_tsk::*; -pub use ggsw_ct::*; -pub use glwe_pk::*; -pub use glwe_sk::*; -pub use glwe_to_lwe_ksk::*; -pub use lwe_ksk::*; -pub use lwe_to_glwe_ksk::*; -use poulpy_hal::layouts::{Backend, Module, Scratch}; - -pub trait PrepareScratchSpace { - fn prepare_scratch_space(module: &Module, infos: &T) -> usize; -} - -pub trait PrepareAlloc { - fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> T; -} - -pub trait Prepare { - fn prepare(&mut self, module: &Module, other: &T, scratch: &mut Scratch); -} +pub use gglwe::*; +pub use ggsw::*; +pub use glwe::*; +pub use glwe_automorphism_key::*; +pub use glwe_public_key::*; +pub use glwe_secret::*; +pub use glwe_switching_key::*; +pub use glwe_tensor_key::*; +pub use glwe_to_lwe_switching_key::*; +pub use lwe_switching_key::*; +pub use lwe_to_glwe_switching_key::*; diff --git a/poulpy-core/src/lib.rs b/poulpy-core/src/lib.rs index 70035af..ccad084 100644 --- a/poulpy-core/src/lib.rs +++ b/poulpy-core/src/lib.rs @@ -14,12 +14,18 @@ mod utils; pub use operations::*; pub mod layouts; +pub use automorphism::*; +pub use conversion::*; +pub use decryption::*; pub use dist::*; +pub use encryption::*; pub use external_product::*; pub use glwe_packing::*; +pub use glwe_trace::*; +pub use keyswitching::*; +pub use noise::*; +pub use scratch::*; pub use encryption::SIGMA; -pub use scratch::*; - pub mod tests; diff --git a/poulpy-core/src/noise/gglwe.rs b/poulpy-core/src/noise/gglwe.rs new file mode 100644 index 0000000..dc32d57 --- /dev/null +++ b/poulpy-core/src/noise/gglwe.rs @@ -0,0 +1,76 @@ +use poulpy_hal::{ + api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, ScratchTakeBasic, VecZnxFillUniform, VecZnxSubScalarInplace}, + layouts::{Backend, DataRef, Module, ScalarZnxToRef, Scratch, ScratchOwned, ZnxZero}, +}; + +use crate::decryption::GLWEDecrypt; +use crate::layouts::{GGLWE, GGLWEInfos, GGLWEToRef, GLWEPlaintext, LWEInfos, prepared::GLWESecretPreparedToRef}; + +impl GGLWE { + pub fn assert_noise(&self, module: &M, sk_prepared: &S, pt_want: &P, max_noise: f64) + where + S: GLWESecretPreparedToRef, + P: ScalarZnxToRef, + M: GGLWENoise, + Scratch: ScratchTakeBasic, + { + module.gglwe_assert_noise(self, sk_prepared, pt_want, max_noise); + } +} + +pub trait GGLWENoise { + fn gglwe_assert_noise(&self, res: &R, sk_prepared: &S, pt_want: &P, max_noise: f64) + where + R: GGLWEToRef, + S: GLWESecretPreparedToRef, + P: ScalarZnxToRef, + Scratch: ScratchTakeBasic; +} + +impl GGLWENoise for Module +where + Module: GLWEDecrypt + VecZnxFillUniform + VecZnxSubScalarInplace, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeBasic, +{ + fn gglwe_assert_noise(&self, res: &R, sk_prepared: &S, pt_want: &P, max_noise: f64) + where + R: GGLWEToRef, + S: GLWESecretPreparedToRef, + P: ScalarZnxToRef, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeBasic, + { + let res: &GGLWE<&[u8]> = &res.to_ref(); + + let dsize: usize = res.dsize().into(); + let base2k: usize = res.base2k().into(); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(self.glwe_decrypt_tmp_bytes(res)); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(res); + + (0..res.rank_in().into()).for_each(|col_i| { + (0..res.dnum().into()).for_each(|row_i| { + self.glwe_decrypt( + &res.at(row_i, col_i), + &mut pt, + sk_prepared, + scratch.borrow(), + ); + + self.vec_znx_sub_scalar_inplace(&mut pt.data, 0, (dsize - 1) + row_i * dsize, pt_want, col_i); + + let noise_have: f64 = pt.data.std(base2k, 0).log2(); + + // println!("noise_have: {noise_have}"); + + assert!( + noise_have <= max_noise, + "noise_have: {noise_have} > max_noise: {max_noise}" + ); + + pt.data.zero(); + }); + }); + } +} diff --git a/poulpy-core/src/noise/gglwe_ct.rs b/poulpy-core/src/noise/gglwe_ct.rs deleted file mode 100644 index 0712b7f..0000000 --- a/poulpy-core/src/noise/gglwe_ct.rs +++ /dev/null @@ -1,61 +0,0 @@ -use poulpy_hal::{ - api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, - VecZnxNormalizeTmpBytes, VecZnxSubScalarInplace, - }, - layouts::{Backend, DataRef, Module, ScalarZnx, ScratchOwned, ZnxZero}, - oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl}, -}; - -use crate::layouts::{GGLWECiphertext, GGLWEInfos, GLWECiphertext, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared}; - -impl GGLWECiphertext { - pub fn assert_noise( - &self, - module: &Module, - sk: &GLWESecretPrepared, - pt_want: &ScalarZnx, - max_noise: f64, - ) where - DataSk: DataRef, - DataWant: DataRef, - Module: VecZnxDftAllocBytes - + VecZnxBigAllocBytes - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalizeTmpBytes - + VecZnxSubScalarInplace, - B: Backend + TakeVecZnxDftImpl + TakeVecZnxBigImpl + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, - { - let dsize: usize = self.dsize().into(); - let base2k: usize = self.base2k().into(); - - let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, self)); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(self); - - (0..self.rank_in().into()).for_each(|col_i| { - (0..self.dnum().into()).for_each(|row_i| { - self.at(row_i, col_i) - .decrypt(module, &mut pt, sk, scratch.borrow()); - - module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, (dsize - 1) + row_i * dsize, pt_want, col_i); - - let noise_have: f64 = pt.data.std(base2k, 0).log2(); - - println!("noise_have: {noise_have}"); - - assert!( - noise_have <= max_noise, - "noise_have: {noise_have} > max_noise: {max_noise}" - ); - - pt.data.zero(); - }); - }); - } -} diff --git a/poulpy-core/src/noise/ggsw.rs b/poulpy-core/src/noise/ggsw.rs new file mode 100644 index 0000000..7adeeca --- /dev/null +++ b/poulpy-core/src/noise/ggsw.rs @@ -0,0 +1,173 @@ +use poulpy_hal::{ + api::{ + ScratchOwnedAlloc, ScratchOwnedBorrow, ScratchTakeBasic, SvpApplyDftToDftInplace, VecZnxAddScalarInplace, VecZnxBigAlloc, + VecZnxBigNormalize, VecZnxDftAlloc, VecZnxDftApply, VecZnxIdftApplyTmpA, VecZnxNormalizeTmpBytes, VecZnxSubInplace, + }, + layouts::{Backend, DataRef, Module, ScalarZnxToRef, Scratch, ScratchOwned, VecZnxBig, VecZnxDft, ZnxZero}, +}; + +use crate::decryption::GLWEDecrypt; +use crate::layouts::prepared::GLWESecretPreparedToRef; +use crate::layouts::{GGSW, GGSWInfos, GGSWToRef, GLWEInfos, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared}; + +impl GGSW { + pub fn assert_noise(&self, module: &M, sk_prepared: &S, pt_want: &P, max_noise: F) + where + S: GLWESecretPreparedToRef, + P: ScalarZnxToRef, + M: GGSWNoise, + F: Fn(usize) -> f64, + { + module.ggsw_assert_noise(self, sk_prepared, pt_want, max_noise); + } + + pub fn print_noise(&self, module: &M, sk_prepared: &S, pt_want: &P) + where + S: GLWESecretPreparedToRef, + P: ScalarZnxToRef, + M: GGSWNoise, + { + module.ggsw_print_noise(self, sk_prepared, pt_want); + } +} + +pub trait GGSWNoise { + fn ggsw_assert_noise(&self, res: &R, sk_prepared: &S, pt_want: &P, max_noise: F) + where + R: GGSWToRef, + S: GLWESecretPreparedToRef, + P: ScalarZnxToRef, + F: Fn(usize) -> f64; + + fn ggsw_print_noise(&self, res: &R, sk_prepared: &S, pt_want: &P) + where + R: GGSWToRef, + S: GLWESecretPreparedToRef, + P: ScalarZnxToRef; +} + +impl GGSWNoise for Module +where + Module: GLWEDecrypt + + VecZnxDftAlloc + + VecZnxBigAlloc + + VecZnxAddScalarInplace + + VecZnxIdftApplyTmpA + + VecZnxSubInplace, + Scratch: ScratchTakeBasic, + ScratchOwned: ScratchOwnedBorrow + ScratchOwnedAlloc, +{ + fn ggsw_assert_noise(&self, res: &R, sk_prepared: &S, pt_want: &P, max_noise: F) + where + R: GGSWToRef, + S: GLWESecretPreparedToRef, + P: ScalarZnxToRef, + F: Fn(usize) -> f64, + { + let res: &GGSW<&[u8]> = &res.to_ref(); + let sk_prepared: &GLWESecretPrepared<&[u8], BE> = &sk_prepared.to_ref(); + + let base2k: usize = res.base2k().into(); + let dsize: usize = res.dsize().into(); + + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(res); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(res); + let mut pt_dft: VecZnxDft, BE> = self.vec_znx_dft_alloc(1, res.size()); + let mut pt_big: VecZnxBig, BE> = self.vec_znx_big_alloc(1, res.size()); + + let mut scratch: ScratchOwned = + ScratchOwned::alloc(self.glwe_decrypt_tmp_bytes(res) | self.vec_znx_normalize_tmp_bytes()); + + (0..(res.rank() + 1).into()).for_each(|col_j| { + (0..res.dnum().into()).for_each(|row_i| { + self.vec_znx_add_scalar_inplace(&mut pt.data, 0, (dsize - 1) + row_i * dsize, pt_want, 0); + + // mul with sk[col_j-1] + if col_j > 0 { + self.vec_znx_dft_apply(1, 0, &mut pt_dft, 0, &pt.data, 0); + self.svp_apply_dft_to_dft_inplace(&mut pt_dft, 0, &sk_prepared.data, col_j - 1); + self.vec_znx_idft_apply_tmpa(&mut pt_big, 0, &mut pt_dft, 0); + self.vec_znx_big_normalize( + base2k, + &mut pt.data, + 0, + base2k, + &pt_big, + 0, + scratch.borrow(), + ); + } + + self.glwe_decrypt( + &res.at(row_i, col_j), + &mut pt_have, + sk_prepared, + scratch.borrow(), + ); + + self.vec_znx_sub_inplace(&mut pt_have.data, 0, &pt.data, 0); + + let std_pt: f64 = pt_have.data.std(base2k, 0).log2(); + let noise: f64 = max_noise(col_j); + assert!(std_pt <= noise, "{std_pt} > {noise}"); + + pt.data.zero(); + }); + }); + } + + fn ggsw_print_noise(&self, res: &R, sk_prepared: &S, pt_want: &P) + where + R: GGSWToRef, + S: GLWESecretPreparedToRef, + P: ScalarZnxToRef, + { + let res: &GGSW<&[u8]> = &res.to_ref(); + let sk_prepared: &GLWESecretPrepared<&[u8], BE> = &sk_prepared.to_ref(); + + let base2k: usize = res.base2k().into(); + let dsize: usize = res.dsize().into(); + + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(res); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(res); + let mut pt_dft: VecZnxDft, BE> = self.vec_znx_dft_alloc(1, res.size()); + let mut pt_big: VecZnxBig, BE> = self.vec_znx_big_alloc(1, res.size()); + + let mut scratch: ScratchOwned = + ScratchOwned::alloc(self.glwe_decrypt_tmp_bytes(res) | self.vec_znx_normalize_tmp_bytes()); + + for col_j in 0..(res.rank() + 1).into() { + for row_i in 0..res.dnum().into() { + self.vec_znx_add_scalar_inplace(&mut pt.data, 0, (dsize - 1) + row_i * dsize, pt_want, 0); + + // mul with sk[col_j-1] + if col_j > 0 { + self.vec_znx_dft_apply(1, 0, &mut pt_dft, 0, &pt.data, 0); + self.svp_apply_dft_to_dft_inplace(&mut pt_dft, 0, &sk_prepared.data, col_j - 1); + self.vec_znx_idft_apply_tmpa(&mut pt_big, 0, &mut pt_dft, 0); + self.vec_znx_big_normalize( + base2k, + &mut pt.data, + 0, + base2k, + &pt_big, + 0, + scratch.borrow(), + ); + } + + self.glwe_decrypt( + &res.at(row_i, col_j), + &mut pt_have, + sk_prepared, + scratch.borrow(), + ); + self.vec_znx_sub_inplace(&mut pt_have.data, 0, &pt.data, 0); + + let std_pt: f64 = pt_have.data.std(base2k, 0).log2(); + println!("col: {col_j} row: {row_i}: {std_pt}"); + pt.data.zero(); + } + } + } +} diff --git a/poulpy-core/src/noise/ggsw_ct.rs b/poulpy-core/src/noise/ggsw_ct.rs deleted file mode 100644 index 03bb0c0..0000000 --- a/poulpy-core/src/noise/ggsw_ct.rs +++ /dev/null @@ -1,158 +0,0 @@ -use poulpy_hal::{ - api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, VecZnxAddScalarInplace, VecZnxBigAddInplace, - VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, - VecZnxNormalizeTmpBytes, VecZnxSubInplace, - }, - layouts::{Backend, DataRef, Module, ScalarZnx, ScratchOwned, VecZnxBig, VecZnxDft, ZnxZero}, - oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl}, -}; - -use crate::layouts::{ - GGSWCiphertext, GGSWInfos, GLWECiphertext, GLWEInfos, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared, -}; - -impl GGSWCiphertext { - pub fn assert_noise( - &self, - module: &Module, - sk_prepared: &GLWESecretPrepared, - pt_want: &ScalarZnx, - max_noise: F, - ) where - DataSk: DataRef, - DataScalar: DataRef, - Module: VecZnxDftAllocBytes - + VecZnxBigAllocBytes - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalizeTmpBytes - + VecZnxBigAlloc - + VecZnxDftAlloc - + VecZnxBigNormalizeTmpBytes - + VecZnxIdftApplyTmpA - + VecZnxAddScalarInplace - + VecZnxSubInplace, - B: Backend + TakeVecZnxDftImpl + TakeVecZnxBigImpl + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, - F: Fn(usize) -> f64, - { - let base2k: usize = self.base2k().into(); - let dsize: usize = self.dsize().into(); - - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(self); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(self); - let mut pt_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(1, self.size()); - let mut pt_big: VecZnxBig, B> = module.vec_znx_big_alloc(1, self.size()); - - let mut scratch: ScratchOwned = - ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, self) | module.vec_znx_normalize_tmp_bytes()); - - (0..(self.rank() + 1).into()).for_each(|col_j| { - (0..self.dnum().into()).for_each(|row_i| { - module.vec_znx_add_scalar_inplace(&mut pt.data, 0, (dsize - 1) + row_i * dsize, pt_want, 0); - - // mul with sk[col_j-1] - if col_j > 0 { - module.vec_znx_dft_apply(1, 0, &mut pt_dft, 0, &pt.data, 0); - module.svp_apply_dft_to_dft_inplace(&mut pt_dft, 0, &sk_prepared.data, col_j - 1); - module.vec_znx_idft_apply_tmpa(&mut pt_big, 0, &mut pt_dft, 0); - module.vec_znx_big_normalize( - base2k, - &mut pt.data, - 0, - base2k, - &pt_big, - 0, - scratch.borrow(), - ); - } - - self.at(row_i, col_j) - .decrypt(module, &mut pt_have, sk_prepared, scratch.borrow()); - - module.vec_znx_sub_inplace(&mut pt_have.data, 0, &pt.data, 0); - - let std_pt: f64 = pt_have.data.std(base2k, 0).log2(); - let noise: f64 = max_noise(col_j); - assert!(std_pt <= noise, "{std_pt} > {noise}"); - - pt.data.zero(); - }); - }); - } -} - -impl GGSWCiphertext { - pub fn print_noise( - &self, - module: &Module, - sk_prepared: &GLWESecretPrepared, - pt_want: &ScalarZnx, - ) where - DataSk: DataRef, - DataScalar: DataRef, - Module: VecZnxDftAllocBytes - + VecZnxBigAllocBytes - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalizeTmpBytes - + VecZnxBigAlloc - + VecZnxDftAlloc - + VecZnxBigNormalizeTmpBytes - + VecZnxIdftApplyTmpA - + VecZnxAddScalarInplace - + VecZnxSubInplace, - B: Backend + TakeVecZnxDftImpl + TakeVecZnxBigImpl + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, - { - let base2k: usize = self.base2k().into(); - let dsize: usize = self.dsize().into(); - - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(self); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(self); - let mut pt_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(1, self.size()); - let mut pt_big: VecZnxBig, B> = module.vec_znx_big_alloc(1, self.size()); - - let mut scratch: ScratchOwned = - ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, self) | module.vec_znx_normalize_tmp_bytes()); - - (0..(self.rank() + 1).into()).for_each(|col_j| { - (0..self.dnum().into()).for_each(|row_i| { - module.vec_znx_add_scalar_inplace(&mut pt.data, 0, (dsize - 1) + row_i * dsize, pt_want, 0); - - // mul with sk[col_j-1] - if col_j > 0 { - module.vec_znx_dft_apply(1, 0, &mut pt_dft, 0, &pt.data, 0); - module.svp_apply_dft_to_dft_inplace(&mut pt_dft, 0, &sk_prepared.data, col_j - 1); - module.vec_znx_idft_apply_tmpa(&mut pt_big, 0, &mut pt_dft, 0); - module.vec_znx_big_normalize( - base2k, - &mut pt.data, - 0, - base2k, - &pt_big, - 0, - scratch.borrow(), - ); - } - - self.at(row_i, col_j) - .decrypt(module, &mut pt_have, sk_prepared, scratch.borrow()); - module.vec_znx_sub_inplace(&mut pt_have.data, 0, &pt.data, 0); - - let std_pt: f64 = pt_have.data.std(base2k, 0).log2(); - println!("col: {col_j} row: {row_i}: {std_pt}"); - pt.data.zero(); - // println!(">>>>>>>>>>>>>>>>"); - }); - }); - } -} diff --git a/poulpy-core/src/noise/glwe.rs b/poulpy-core/src/noise/glwe.rs new file mode 100644 index 0000000..7c6979e --- /dev/null +++ b/poulpy-core/src/noise/glwe.rs @@ -0,0 +1,80 @@ +use poulpy_hal::{ + api::{ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxNormalizeInplace, VecZnxSubInplace}, + layouts::{Backend, DataRef, Module, Scratch, ScratchOwned}, +}; + +use crate::{ + ScratchTakeCore, + decryption::GLWEDecrypt, + layouts::{GLWE, GLWEPlaintext, GLWEPlaintextToRef, GLWEToRef, LWEInfos, prepared::GLWESecretPreparedToRef}, +}; + +impl GLWE { + pub fn noise(&self, module: &M, sk_prepared: &S, pt_want: &P, scratch: &mut Scratch) -> f64 + where + M: GLWENoise, + S: GLWESecretPreparedToRef, + P: GLWEPlaintextToRef, + { + module.glwe_noise(self, sk_prepared, pt_want, scratch) + } + + pub fn assert_noise(&self, module: &M, sk_prepared: &S, pt_want: &P, max_noise: f64) + where + S: GLWESecretPreparedToRef, + P: GLWEPlaintextToRef, + M: GLWENoise, + { + module.glwe_assert_noise(self, sk_prepared, pt_want, max_noise); + } +} + +pub trait GLWENoise { + fn glwe_noise(&self, res: &R, sk_prepared: &S, pt_want: &P, scratch: &mut Scratch) -> f64 + where + R: GLWEToRef, + S: GLWESecretPreparedToRef, + P: GLWEPlaintextToRef; + + fn glwe_assert_noise(&self, res: &R, sk_prepared: &S, pt_want: &P, max_noise: f64) + where + R: GLWEToRef, + S: GLWESecretPreparedToRef, + P: GLWEPlaintextToRef; +} + +impl GLWENoise for Module +where + Module: GLWEDecrypt + VecZnxSubInplace + VecZnxNormalizeInplace, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchTakeCore, +{ + fn glwe_noise(&self, res: &R, sk_prepared: &S, pt_want: &P, scratch: &mut Scratch) -> f64 + where + R: GLWEToRef, + S: GLWESecretPreparedToRef, + P: GLWEPlaintextToRef, + { + let res_ref: &GLWE<&[u8]> = &res.to_ref(); + + let pt_want: &GLWEPlaintext<&[u8]> = &pt_want.to_ref(); + + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(res_ref); + self.glwe_decrypt(res, &mut pt_have, sk_prepared, scratch); + self.vec_znx_sub_inplace(&mut pt_have.data, 0, &pt_want.data, 0); + self.vec_znx_normalize_inplace(res_ref.base2k().into(), &mut pt_have.data, 0, scratch); + pt_have.data.std(res_ref.base2k().into(), 0).log2() + } + + fn glwe_assert_noise(&self, res: &R, sk_prepared: &S, pt_want: &P, max_noise: f64) + where + R: GLWEToRef, + S: GLWESecretPreparedToRef, + P: GLWEPlaintextToRef, + { + let res: &GLWE<&[u8]> = &res.to_ref(); + let mut scratch: ScratchOwned = ScratchOwned::alloc(self.glwe_decrypt_tmp_bytes(res)); + let noise_have: f64 = self.glwe_noise(res, sk_prepared, pt_want, scratch.borrow()); + assert!(noise_have <= max_noise, "{noise_have} {max_noise}"); + } +} diff --git a/poulpy-core/src/noise/glwe_ct.rs b/poulpy-core/src/noise/glwe_ct.rs deleted file mode 100644 index f7af2a1..0000000 --- a/poulpy-core/src/noise/glwe_ct.rs +++ /dev/null @@ -1,68 +0,0 @@ -use poulpy_hal::{ - api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddInplace, - VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, - VecZnxIdftApplyConsume, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSubInplace, - }, - layouts::{Backend, DataRef, Module, Scratch, ScratchOwned}, - oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl}, -}; - -use crate::layouts::{GLWECiphertext, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared}; - -impl GLWECiphertext { - pub fn noise( - &self, - module: &Module, - sk_prepared: &GLWESecretPrepared, - pt_want: &GLWEPlaintext, - scratch: &mut Scratch, - ) -> f64 - where - DataSk: DataRef, - DataPt: DataRef, - B: Backend, - Module: VecZnxDftApply - + VecZnxSubInplace - + VecZnxNormalizeInplace - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize, - Scratch: TakeVecZnxDft + TakeVecZnxBig, - { - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(self); - self.decrypt(module, &mut pt_have, sk_prepared, scratch); - module.vec_znx_sub_inplace(&mut pt_have.data, 0, &pt_want.data, 0); - module.vec_znx_normalize_inplace(self.base2k().into(), &mut pt_have.data, 0, scratch); - pt_have.data.std(self.base2k().into(), 0).log2() - } - - pub fn assert_noise( - &self, - module: &Module, - sk_prepared: &GLWESecretPrepared, - pt_want: &GLWEPlaintext, - max_noise: f64, - ) where - DataSk: DataRef, - DataPt: DataRef, - Module: VecZnxDftAllocBytes - + VecZnxBigAllocBytes - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalizeTmpBytes - + VecZnxSubInplace - + VecZnxNormalizeInplace, - B: Backend + TakeVecZnxDftImpl + TakeVecZnxBigImpl + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, - { - let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, self)); - let noise_have: f64 = self.noise(module, sk_prepared, pt_want, scratch.borrow()); - assert!(noise_have <= max_noise, "{noise_have} {max_noise}"); - } -} diff --git a/poulpy-core/src/noise/mod.rs b/poulpy-core/src/noise/mod.rs index 25e2b9b..6f8882f 100644 --- a/poulpy-core/src/noise/mod.rs +++ b/poulpy-core/src/noise/mod.rs @@ -1,6 +1,10 @@ -mod gglwe_ct; -mod ggsw_ct; -mod glwe_ct; +mod gglwe; +mod ggsw; +mod glwe; + +pub use gglwe::*; +pub use ggsw::*; +pub use glwe::*; #[allow(clippy::too_many_arguments)] #[allow(dead_code)] diff --git a/poulpy-core/src/operations/glwe.rs b/poulpy-core/src/operations/glwe.rs index b0ee4f6..c6f1818 100644 --- a/poulpy-core/src/operations/glwe.rs +++ b/poulpy-core/src/operations/glwe.rs @@ -1,328 +1,364 @@ use poulpy_hal::{ api::{ - VecZnxAdd, VecZnxAddInplace, VecZnxCopy, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegateInplace, + ModuleN, VecZnxAdd, VecZnxAddInplace, VecZnxCopy, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace, VecZnxSubNegateInplace, }, - layouts::{Backend, DataMut, Module, Scratch, VecZnx, ZnxZero}, + layouts::{Backend, Module, Scratch, VecZnx, ZnxZero}, }; -use crate::layouts::{ - GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEInfos, GLWELayoutSet, GLWEPlaintext, LWEInfos, TorusPrecision, +use crate::{ + ScratchTakeCore, + layouts::{GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, SetGLWEInfos, TorusPrecision}, }; -impl GLWEOperations for GLWEPlaintext +pub trait GLWEAdd where - D: DataMut, - GLWEPlaintext: GLWECiphertextToMut + GLWEInfos, + Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace, +{ + fn glwe_add(&self, res: &mut R, a: &A, b: &B) + where + R: GLWEToMut, + A: GLWEToRef, + B: GLWEToRef, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &mut GLWE<&[u8]> = &mut a.to_ref(); + let b: &GLWE<&[u8]> = &b.to_ref(); + + assert_eq!(a.n(), self.n() as u32); + assert_eq!(b.n(), self.n() as u32); + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.base2k(), b.base2k()); + assert!(res.rank() >= a.rank().max(b.rank())); + + let min_col: usize = (a.rank().min(b.rank()) + 1).into(); + let max_col: usize = (a.rank().max(b.rank() + 1)).into(); + let self_col: usize = (res.rank() + 1).into(); + + (0..min_col).for_each(|i| { + self.vec_znx_add(res.data_mut(), i, a.data(), i, b.data(), i); + }); + + if a.rank() > b.rank() { + (min_col..max_col).for_each(|i| { + self.vec_znx_copy(res.data_mut(), i, a.data(), i); + }); + } else { + (min_col..max_col).for_each(|i| { + self.vec_znx_copy(res.data_mut(), i, b.data(), i); + }); + } + + let size: usize = res.size(); + (max_col..self_col).for_each(|i| { + (0..size).for_each(|j| { + res.data.zero_at(i, j); + }); + }); + + res.set_base2k(a.base2k()); + res.set_k(set_k_binary(res, a, b)); + } + + fn glwe_add_inplace(&self, res: &mut R, a: &A) + where + R: GLWEToMut, + A: GLWEToRef, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); + + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.n(), self.n() as u32); + assert_eq!(res.base2k(), a.base2k()); + assert!(res.rank() >= a.rank()); + + (0..(a.rank() + 1).into()).for_each(|i| { + self.vec_znx_add_inplace(res.data_mut(), i, a.data(), i); + }); + + res.set_k(set_k_unary(res, a)) + } +} + +impl GLWEAdd for Module where Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace {} + +impl GLWESub for Module where + Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegateInplace + VecZnxSubInplace + VecZnxSubNegateInplace { } -impl GLWEOperations for GLWECiphertext where GLWECiphertext: GLWECiphertextToMut + GLWEInfos {} - -pub trait GLWEOperations: GLWECiphertextToMut + GLWEInfos + GLWELayoutSet + Sized { - fn add(&mut self, module: &Module, a: &A, b: &B) +pub trait GLWESub +where + Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegateInplace + VecZnxSubInplace + VecZnxSubNegateInplace, +{ + fn glwe_sub(&self, res: &mut R, a: &A, b: &B) where - A: GLWECiphertextToRef + GLWEInfos, - B: GLWECiphertextToRef + GLWEInfos, - Module: VecZnxAdd + VecZnxCopy, + R: GLWEToMut, + A: GLWEToRef, + B: GLWEToRef, { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(b.n(), self.n()); - assert_eq!(a.base2k(), b.base2k()); - assert!(self.rank() >= a.rank().max(b.rank())); - } + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); + let b: &GLWE<&[u8]> = &b.to_ref(); + + assert_eq!(a.n(), self.n() as u32); + assert_eq!(b.n(), self.n() as u32); + assert_eq!(a.base2k(), b.base2k()); + assert!(res.rank() >= a.rank().max(b.rank())); let min_col: usize = (a.rank().min(b.rank()) + 1).into(); let max_col: usize = (a.rank().max(b.rank() + 1)).into(); - let self_col: usize = (self.rank() + 1).into(); - - let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref(); - let b_ref: &GLWECiphertext<&[u8]> = &b.to_ref(); + let self_col: usize = (res.rank() + 1).into(); (0..min_col).for_each(|i| { - module.vec_znx_add(&mut self_mut.data, i, &a_ref.data, i, &b_ref.data, i); + self.vec_znx_sub(res.data_mut(), i, a.data(), i, b.data(), i); }); if a.rank() > b.rank() { (min_col..max_col).for_each(|i| { - module.vec_znx_copy(&mut self_mut.data, i, &a_ref.data, i); + self.vec_znx_copy(res.data_mut(), i, a.data(), i); }); } else { (min_col..max_col).for_each(|i| { - module.vec_znx_copy(&mut self_mut.data, i, &b_ref.data, i); + self.vec_znx_copy(res.data_mut(), i, b.data(), i); + self.vec_znx_negate_inplace(res.data_mut(), i); }); } - let size: usize = self_mut.size(); + let size: usize = res.size(); (max_col..self_col).for_each(|i| { (0..size).for_each(|j| { - self_mut.data.zero_at(i, j); + res.data.zero_at(i, j); }); }); - self.set_basek(a.base2k()); - self.set_k(set_k_binary(self, a, b)); + res.set_base2k(a.base2k()); + res.set_k(set_k_binary(res, a, b)); } - fn add_inplace(&mut self, module: &Module, a: &A) + fn glwe_sub_inplace(&self, res: &mut R, a: &A) where - A: GLWECiphertextToRef + GLWEInfos, - Module: VecZnxAddInplace, + R: GLWEToMut, + A: GLWEToRef, { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(self.base2k(), a.base2k()); - assert!(self.rank() >= a.rank()) - } + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); - let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref(); + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.n(), self.n() as u32); + assert_eq!(res.base2k(), a.base2k()); + assert!(res.rank() >= a.rank()); (0..(a.rank() + 1).into()).for_each(|i| { - module.vec_znx_add_inplace(&mut self_mut.data, i, &a_ref.data, i); + self.vec_znx_sub_inplace(res.data_mut(), i, a.data(), i); }); - self.set_k(set_k_unary(self, a)) + res.set_k(set_k_unary(res, a)) } - fn sub(&mut self, module: &Module, a: &A, b: &B) + fn glwe_sub_negate_inplace(&self, res: &mut R, a: &A) where - A: GLWECiphertextToRef + GLWEInfos, - B: GLWECiphertextToRef + GLWEInfos, - Module: VecZnxSub + VecZnxCopy + VecZnxNegateInplace, + R: GLWEToMut, + A: GLWEToRef, { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(b.n(), self.n()); - assert_eq!(a.base2k(), b.base2k()); - assert!(self.rank() >= a.rank().max(b.rank())); - } + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); - let min_col: usize = (a.rank().min(b.rank()) + 1).into(); - let max_col: usize = (a.rank().max(b.rank() + 1)).into(); - let self_col: usize = (self.rank() + 1).into(); - - let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref(); - let b_ref: &GLWECiphertext<&[u8]> = &b.to_ref(); - - (0..min_col).for_each(|i| { - module.vec_znx_sub(&mut self_mut.data, i, &a_ref.data, i, &b_ref.data, i); - }); - - if a.rank() > b.rank() { - (min_col..max_col).for_each(|i| { - module.vec_znx_copy(&mut self_mut.data, i, &a_ref.data, i); - }); - } else { - (min_col..max_col).for_each(|i| { - module.vec_znx_copy(&mut self_mut.data, i, &b_ref.data, i); - module.vec_znx_negate_inplace(&mut self_mut.data, i); - }); - } - - let size: usize = self_mut.size(); - (max_col..self_col).for_each(|i| { - (0..size).for_each(|j| { - self_mut.data.zero_at(i, j); - }); - }); - - self.set_basek(a.base2k()); - self.set_k(set_k_binary(self, a, b)); - } - - fn sub_inplace_ab(&mut self, module: &Module, a: &A) - where - A: GLWECiphertextToRef + GLWEInfos, - Module: VecZnxSubInplace, - { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(self.base2k(), a.base2k()); - assert!(self.rank() >= a.rank()) - } - - let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref(); + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.n(), self.n() as u32); + assert_eq!(res.base2k(), a.base2k()); + assert!(res.rank() >= a.rank()); (0..(a.rank() + 1).into()).for_each(|i| { - module.vec_znx_sub_inplace(&mut self_mut.data, i, &a_ref.data, i); + self.vec_znx_sub_negate_inplace(res.data_mut(), i, a.data(), i); }); - self.set_k(set_k_unary(self, a)) + res.set_k(set_k_unary(res, a)) } +} - fn sub_inplace_ba(&mut self, module: &Module, a: &A) +impl GLWERotate for Module where Self: ModuleN + VecZnxRotate + VecZnxRotateInplace {} + +pub trait GLWERotate +where + Self: ModuleN + VecZnxRotate + VecZnxRotateInplace, +{ + fn glwe_rotate(&self, k: i64, res: &mut R, a: &A) where - A: GLWECiphertextToRef + GLWEInfos, - Module: VecZnxSubNegateInplace, + R: GLWEToMut, + A: GLWEToRef, { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(self.base2k(), a.base2k()); - assert!(self.rank() >= a.rank()) - } + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); - let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref(); + assert_eq!(a.n(), self.n() as u32); + assert_eq!(res.rank(), a.rank()); (0..(a.rank() + 1).into()).for_each(|i| { - module.vec_znx_sub_negate_inplace(&mut self_mut.data, i, &a_ref.data, i); + self.vec_znx_rotate(k, res.data_mut(), i, a.data(), i); }); - self.set_k(set_k_unary(self, a)) + res.set_base2k(a.base2k()); + res.set_k(set_k_unary(res, a)) } - fn rotate(&mut self, module: &Module, k: i64, a: &A) + fn glwe_rotate_inplace(&self, k: i64, res: &mut R, scratch: &mut Scratch) where - A: GLWECiphertextToRef + GLWEInfos, - Module: VecZnxRotate, + R: GLWEToMut, + Scratch: ScratchTakeCore, { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(self.rank(), a.rank()) - } + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); - let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref(); - - (0..(a.rank() + 1).into()).for_each(|i| { - module.vec_znx_rotate(k, &mut self_mut.data, i, &a_ref.data, i); - }); - - self.set_basek(a.base2k()); - self.set_k(set_k_unary(self, a)) - } - - fn rotate_inplace(&mut self, module: &Module, k: i64, scratch: &mut Scratch) - where - Module: VecZnxRotateInplace, - { - let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); - - (0..(self_mut.rank() + 1).into()).for_each(|i| { - module.vec_znx_rotate_inplace(k, &mut self_mut.data, i, scratch); - }); - } - - fn mul_xp_minus_one(&mut self, module: &Module, k: i64, a: &A) - where - A: GLWECiphertextToRef + GLWEInfos, - Module: VecZnxMulXpMinusOne, - { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(self.rank(), a.rank()) - } - - let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref(); - - (0..(a.rank() + 1).into()).for_each(|i| { - module.vec_znx_mul_xp_minus_one(k, &mut self_mut.data, i, &a_ref.data, i); - }); - - self.set_basek(a.base2k()); - self.set_k(set_k_unary(self, a)) - } - - fn mul_xp_minus_one_inplace(&mut self, module: &Module, k: i64, scratch: &mut Scratch) - where - Module: VecZnxMulXpMinusOneInplace, - { - let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); - - (0..(self_mut.rank() + 1).into()).for_each(|i| { - module.vec_znx_mul_xp_minus_one_inplace(k, &mut self_mut.data, i, scratch); - }); - } - - fn copy(&mut self, module: &Module, a: &A) - where - A: GLWECiphertextToRef + GLWEInfos, - Module: VecZnxCopy, - { - #[cfg(debug_assertions)] - { - assert_eq!(self.n(), a.n()); - assert_eq!(self.rank(), a.rank()); - } - - let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref(); - - (0..(self_mut.rank() + 1).into()).for_each(|i| { - module.vec_znx_copy(&mut self_mut.data, i, &a_ref.data, i); - }); - - self.set_k(a.k().min(self.max_k())); - self.set_basek(a.base2k()); - } - - fn rsh(&mut self, module: &Module, k: usize, scratch: &mut Scratch) - where - Module: VecZnxRshInplace, - { - let base2k: usize = self.base2k().into(); - (0..(self.rank() + 1).into()).for_each(|i| { - module.vec_znx_rsh_inplace(base2k, k, &mut self.to_mut().data, i, scratch); - }) - } - - fn normalize(&mut self, module: &Module, a: &A, scratch: &mut Scratch) - where - A: GLWECiphertextToRef + GLWEInfos, - Module: VecZnxNormalize, - { - #[cfg(debug_assertions)] - { - assert_eq!(self.n(), a.n()); - assert_eq!(self.rank(), a.rank()); - } - - let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref(); - - (0..(self_mut.rank() + 1).into()).for_each(|i| { - module.vec_znx_normalize( - a.base2k().into(), - &mut self_mut.data, - i, - a.base2k().into(), - &a_ref.data, - i, - scratch, - ); - }); - self.set_basek(a.base2k()); - self.set_k(a.k().min(self.k())); - } - - fn normalize_inplace(&mut self, module: &Module, scratch: &mut Scratch) - where - Module: VecZnxNormalizeInplace, - { - let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); - (0..(self_mut.rank() + 1).into()).for_each(|i| { - module.vec_znx_normalize_inplace(self_mut.base2k().into(), &mut self_mut.data, i, scratch); + (0..(res.rank() + 1).into()).for_each(|i| { + self.vec_znx_rotate_inplace(k, res.data_mut(), i, scratch); }); } } -impl GLWECiphertext> { - pub fn rsh_scratch_space(n: usize) -> usize { - VecZnx::rsh_scratch_space(n) +impl GLWEMulXpMinusOne for Module where Self: ModuleN + VecZnxMulXpMinusOne + VecZnxMulXpMinusOneInplace {} + +pub trait GLWEMulXpMinusOne +where + Self: ModuleN + VecZnxMulXpMinusOne + VecZnxMulXpMinusOneInplace, +{ + fn glwe_mul_xp_minus_one(&self, k: i64, res: &mut R, a: &A) + where + R: GLWEToMut, + A: GLWEToRef, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); + + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.n(), self.n() as u32); + assert_eq!(res.rank(), a.rank()); + + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_mul_xp_minus_one(k, res.data_mut(), i, a.data(), i); + } + + res.set_base2k(a.base2k()); + res.set_k(set_k_unary(res, a)) + } + + fn glwe_mul_xp_minus_one_inplace(&self, k: i64, res: &mut R, scratch: &mut Scratch) + where + R: GLWEToMut, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + + assert_eq!(res.n(), self.n() as u32); + + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_mul_xp_minus_one_inplace(k, res.data_mut(), i, scratch); + } + } +} + +impl GLWECopy for Module where Self: ModuleN + VecZnxCopy {} + +pub trait GLWECopy +where + Self: ModuleN + VecZnxCopy, +{ + fn glwe_copy(&self, res: &mut R, a: &A) + where + R: GLWEToMut, + A: GLWEToRef, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); + + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.n(), self.n() as u32); + assert_eq!(res.rank(), a.rank()); + + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_copy(res.data_mut(), i, a.data(), i); + } + + res.set_k(a.k().min(res.max_k())); + res.set_base2k(a.base2k()); + } +} + +impl GLWEShift for Module where Self: ModuleN + VecZnxRshInplace {} + +pub trait GLWEShift +where + Self: ModuleN + VecZnxRshInplace, +{ + fn glwe_rsh_tmp_byte(&self) -> usize { + VecZnx::rsh_tmp_bytes(self.n()) + } + + fn glwe_rsh(&self, k: usize, res: &mut R, scratch: &mut Scratch) + where + R: GLWEToMut, + Scratch: ScratchTakeCore, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let base2k: usize = res.base2k().into(); + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_rsh_inplace(base2k, k, res.data_mut(), i, scratch); + } + } +} + +impl GLWE> { + pub fn rsh_tmp_bytes(module: &M) -> usize + where + M: GLWEShift, + { + module.glwe_rsh_tmp_byte() + } +} + +impl GLWENormalize for Module where Self: ModuleN + VecZnxNormalize + VecZnxNormalizeInplace {} + +pub trait GLWENormalize +where + Self: ModuleN + VecZnxNormalize + VecZnxNormalizeInplace, +{ + fn glwe_normalize(&self, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: GLWEToMut, + A: GLWEToRef, + Scratch: ScratchTakeCore, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); + + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.n(), self.n() as u32); + assert_eq!(res.rank(), a.rank()); + + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_normalize( + res.base2k().into(), + res.data_mut(), + i, + a.base2k().into(), + a.data(), + i, + scratch, + ); + } + + res.set_k(a.k().min(res.k())); + } + + fn glwe_normalize_inplace(&self, res: &mut R, scratch: &mut Scratch) + where + R: GLWEToMut, + Scratch: ScratchTakeCore, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_normalize_inplace(res.base2k().into(), res.data_mut(), i, scratch); + } } } diff --git a/poulpy-core/src/scratch.rs b/poulpy-core/src/scratch.rs index 1a5a6ce..2220dc4 100644 --- a/poulpy-core/src/scratch.rs +++ b/poulpy-core/src/scratch.rs @@ -1,190 +1,69 @@ use poulpy_hal::{ - api::{TakeMatZnx, TakeScalarZnx, TakeSvpPPol, TakeVecZnx, TakeVecZnxDft, TakeVmpPMat}, + api::{ModuleN, ScratchAvailable, ScratchTakeBasic, SvpPPolBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf}, layouts::{Backend, Scratch}, }; use crate::{ dist::Distribution, layouts::{ - Degree, GGLWEAutomorphismKey, GGLWECiphertext, GGLWEInfos, GGLWESwitchingKey, GGLWETensorKey, GGSWCiphertext, GGSWInfos, - GLWECiphertext, GLWEInfos, GLWEPlaintext, GLWEPublicKey, GLWESecret, Rank, + Degree, GGLWE, GGLWEInfos, GGLWELayout, GGSW, GGSWInfos, GLWE, GLWEAutomorphismKey, GLWEInfos, GLWEPlaintext, + GLWEPrepared, GLWEPublicKey, GLWESecret, GLWESwitchingKey, GLWETensorKey, Rank, prepared::{ - GGLWEAutomorphismKeyPrepared, GGLWECiphertextPrepared, GGLWESwitchingKeyPrepared, GGLWETensorKeyPrepared, - GGSWCiphertextPrepared, GLWEPublicKeyPrepared, GLWESecretPrepared, + GGLWEPrepared, GGSWPrepared, GLWEAutomorphismKeyPrepared, GLWEPublicKeyPrepared, GLWESecretPrepared, + GLWESwitchingKeyPrepared, GLWETensorKeyPrepared, }, }, }; -pub trait TakeGLWECt { - fn take_glwe_ct(&mut self, infos: &A) -> (GLWECiphertext<&mut [u8]>, &mut Self) - where - A: GLWEInfos; -} - -pub trait TakeGLWECtSlice { - fn take_glwe_ct_slice(&mut self, size: usize, infos: &A) -> (Vec>, &mut Self) - where - A: GLWEInfos; -} - -pub trait TakeGLWEPt { - fn take_glwe_pt(&mut self, infos: &A) -> (GLWEPlaintext<&mut [u8]>, &mut Self) - where - A: GLWEInfos; -} - -pub trait TakeGGLWE { - fn take_gglwe(&mut self, infos: &A) -> (GGLWECiphertext<&mut [u8]>, &mut Self) - where - A: GGLWEInfos; -} - -pub trait TakeGGLWEPrepared { - fn take_gglwe_prepared(&mut self, infos: &A) -> (GGLWECiphertextPrepared<&mut [u8], B>, &mut Self) - where - A: GGLWEInfos; -} - -pub trait TakeGGSW { - fn take_ggsw(&mut self, infos: &A) -> (GGSWCiphertext<&mut [u8]>, &mut Self) - where - A: GGSWInfos; -} - -pub trait TakeGGSWPrepared { - fn take_ggsw_prepared(&mut self, infos: &A) -> (GGSWCiphertextPrepared<&mut [u8], B>, &mut Self) - where - A: GGSWInfos; -} - -pub trait TakeGGSWPreparedSlice { - fn take_ggsw_prepared_slice(&mut self, size: usize, infos: &A) -> (Vec>, &mut Self) - where - A: GGSWInfos; -} - -pub trait TakeGLWESecret { - fn take_glwe_secret(&mut self, n: Degree, rank: Rank) -> (GLWESecret<&mut [u8]>, &mut Self); -} - -pub trait TakeGLWESecretPrepared { - fn take_glwe_secret_prepared(&mut self, n: Degree, rank: Rank) -> (GLWESecretPrepared<&mut [u8], B>, &mut Self); -} - -pub trait TakeGLWEPk { - fn take_glwe_pk(&mut self, infos: &A) -> (GLWEPublicKey<&mut [u8]>, &mut Self) - where - A: GLWEInfos; -} - -pub trait TakeGLWEPkPrepared { - fn take_glwe_pk_prepared(&mut self, infos: &A) -> (GLWEPublicKeyPrepared<&mut [u8], B>, &mut Self) - where - A: GLWEInfos; -} - -pub trait TakeGLWESwitchingKey { - fn take_glwe_switching_key(&mut self, infos: &A) -> (GGLWESwitchingKey<&mut [u8]>, &mut Self) - where - A: GGLWEInfos; -} - -pub trait TakeGGLWESwitchingKeyPrepared { - fn take_gglwe_switching_key_prepared(&mut self, infos: &A) -> (GGLWESwitchingKeyPrepared<&mut [u8], B>, &mut Self) - where - A: GGLWEInfos; -} - -pub trait TakeTensorKey { - fn take_tensor_key(&mut self, infos: &A) -> (GGLWETensorKey<&mut [u8]>, &mut Self) - where - A: GGLWEInfos; -} - -pub trait TakeGGLWETensorKeyPrepared { - fn take_gglwe_tensor_key_prepared(&mut self, infos: &A) -> (GGLWETensorKeyPrepared<&mut [u8], B>, &mut Self) - where - A: GGLWEInfos; -} - -pub trait TakeGGLWEAutomorphismKey { - fn take_gglwe_automorphism_key(&mut self, infos: &A) -> (GGLWEAutomorphismKey<&mut [u8]>, &mut Self) - where - A: GGLWEInfos; -} - -pub trait TakeGGLWEAutomorphismKeyPrepared { - fn take_gglwe_automorphism_key_prepared(&mut self, infos: &A) -> (GGLWEAutomorphismKeyPrepared<&mut [u8], B>, &mut Self) - where - A: GGLWEInfos; -} - -impl TakeGLWECt for Scratch +pub trait ScratchTakeCore where - Scratch: TakeVecZnx, + Self: ScratchTakeBasic + ScratchAvailable, { - fn take_glwe_ct(&mut self, infos: &A) -> (GLWECiphertext<&mut [u8]>, &mut Self) + fn take_glwe(&mut self, infos: &A) -> (GLWE<&mut [u8]>, &mut Self) where A: GLWEInfos, { let (data, scratch) = self.take_vec_znx(infos.n().into(), (infos.rank() + 1).into(), infos.size()); ( - GLWECiphertext::builder() - .base2k(infos.base2k()) - .k(infos.k()) - .data(data) - .build() - .unwrap(), + GLWE { + k: infos.k(), + base2k: infos.base2k(), + data, + }, scratch, ) } -} -impl TakeGLWECtSlice for Scratch -where - Scratch: TakeVecZnx, -{ - fn take_glwe_ct_slice(&mut self, size: usize, infos: &A) -> (Vec>, &mut Self) + fn take_glwe_slice(&mut self, size: usize, infos: &A) -> (Vec>, &mut Self) where A: GLWEInfos, { - let mut scratch: &mut Scratch = self; - let mut cts: Vec> = Vec::with_capacity(size); + let mut scratch: &mut Self = self; + let mut cts: Vec> = Vec::with_capacity(size); for _ in 0..size { - let (ct, new_scratch) = scratch.take_glwe_ct(infos); + let (ct, new_scratch) = scratch.take_glwe(infos); scratch = new_scratch; cts.push(ct); } (cts, scratch) } -} -impl TakeGLWEPt for Scratch -where - Scratch: TakeVecZnx, -{ - fn take_glwe_pt(&mut self, infos: &A) -> (GLWEPlaintext<&mut [u8]>, &mut Self) + fn take_glwe_plaintext(&mut self, infos: &A) -> (GLWEPlaintext<&mut [u8]>, &mut Self) where A: GLWEInfos, { let (data, scratch) = self.take_vec_znx(infos.n().into(), 1, infos.size()); ( - GLWEPlaintext::builder() - .base2k(infos.base2k()) - .k(infos.k()) - .data(data) - .build() - .unwrap(), + GLWEPlaintext { + k: infos.k(), + base2k: infos.base2k(), + data, + }, scratch, ) } -} -impl TakeGGLWE for Scratch -where - Scratch: TakeMatZnx, -{ - fn take_gglwe(&mut self, infos: &A) -> (GGLWECiphertext<&mut [u8]>, &mut Self) + fn take_gglwe(&mut self, infos: &A) -> (GGLWE<&mut [u8]>, &mut Self) where A: GGLWEInfos, { @@ -196,51 +75,41 @@ where infos.size(), ); ( - GGLWECiphertext::builder() - .base2k(infos.base2k()) - .k(infos.k()) - .dsize(infos.dsize()) - .data(data) - .build() - .unwrap(), + GGLWE { + k: infos.k(), + base2k: infos.base2k(), + dsize: infos.dsize(), + data, + }, scratch, ) } -} -impl TakeGGLWEPrepared for Scratch -where - Scratch: TakeVmpPMat, -{ - fn take_gglwe_prepared(&mut self, infos: &A) -> (GGLWECiphertextPrepared<&mut [u8], B>, &mut Self) + fn take_gglwe_prepared(&mut self, module: &M, infos: &A) -> (GGLWEPrepared<&mut [u8], B>, &mut Self) where A: GGLWEInfos, + M: ModuleN + VmpPMatBytesOf, { + assert_eq!(module.n() as u32, infos.n()); let (data, scratch) = self.take_vmp_pmat( - infos.n().into(), + module, infos.dnum().into(), infos.rank_in().into(), (infos.rank_out() + 1).into(), infos.size(), ); ( - GGLWECiphertextPrepared::builder() - .base2k(infos.base2k()) - .dsize(infos.dsize()) - .k(infos.k()) - .data(data) - .build() - .unwrap(), + GGLWEPrepared { + k: infos.k(), + base2k: infos.base2k(), + dsize: infos.dsize(), + data, + }, scratch, ) } -} -impl TakeGGSW for Scratch -where - Scratch: TakeMatZnx, -{ - fn take_ggsw(&mut self, infos: &A) -> (GGSWCiphertext<&mut [u8]>, &mut Self) + fn take_ggsw(&mut self, infos: &A) -> (GGSW<&mut [u8]>, &mut Self) where A: GGSWInfos, { @@ -252,112 +121,106 @@ where infos.size(), ); ( - GGSWCiphertext::builder() - .base2k(infos.base2k()) - .dsize(infos.dsize()) - .k(infos.k()) - .data(data) - .build() - .unwrap(), + GGSW { + k: infos.k(), + base2k: infos.base2k(), + dsize: infos.dsize(), + data, + }, scratch, ) } -} -impl TakeGGSWPrepared for Scratch -where - Scratch: TakeVmpPMat, -{ - fn take_ggsw_prepared(&mut self, infos: &A) -> (GGSWCiphertextPrepared<&mut [u8], B>, &mut Self) + fn take_ggsw_prepared(&mut self, module: &M, infos: &A) -> (GGSWPrepared<&mut [u8], B>, &mut Self) where A: GGSWInfos, + M: ModuleN + VmpPMatBytesOf, { + assert_eq!(module.n() as u32, infos.n()); let (data, scratch) = self.take_vmp_pmat( - infos.n().into(), + module, infos.dnum().into(), (infos.rank() + 1).into(), (infos.rank() + 1).into(), infos.size(), ); ( - GGSWCiphertextPrepared::builder() - .base2k(infos.base2k()) - .dsize(infos.dsize()) - .k(infos.k()) - .data(data) - .build() - .unwrap(), + GGSWPrepared { + k: infos.k(), + base2k: infos.base2k(), + dsize: infos.dsize(), + data, + }, scratch, ) } -} -impl TakeGGSWPreparedSlice for Scratch -where - Scratch: TakeGGSWPrepared, -{ - fn take_ggsw_prepared_slice(&mut self, size: usize, infos: &A) -> (Vec>, &mut Self) + fn take_ggsw_prepared_slice( + &mut self, + module: &M, + size: usize, + infos: &A, + ) -> (Vec>, &mut Self) where A: GGSWInfos, + M: ModuleN + VmpPMatBytesOf, { - let mut scratch: &mut Scratch = self; - let mut cts: Vec> = Vec::with_capacity(size); + let mut scratch: &mut Self = self; + let mut cts: Vec> = Vec::with_capacity(size); for _ in 0..size { - let (ct, new_scratch) = scratch.take_ggsw_prepared(infos); + let (ct, new_scratch) = scratch.take_ggsw_prepared(module, infos); scratch = new_scratch; cts.push(ct) } (cts, scratch) } -} -impl TakeGLWEPk for Scratch -where - Scratch: TakeVecZnx, -{ - fn take_glwe_pk(&mut self, infos: &A) -> (GLWEPublicKey<&mut [u8]>, &mut Self) + fn take_glwe_public_key(&mut self, infos: &A) -> (GLWEPublicKey<&mut [u8]>, &mut Self) where A: GLWEInfos, { - let (data, scratch) = self.take_vec_znx(infos.n().into(), (infos.rank() + 1).into(), infos.size()); + let (data, scratch) = self.take_glwe(infos); ( - GLWEPublicKey::builder() - .base2k(infos.base2k()) - .k(infos.k()) - .base2k(infos.base2k()) - .data(data) - .build() - .unwrap(), + GLWEPublicKey { + dist: Distribution::NONE, + key: data, + }, scratch, ) } -} -impl TakeGLWEPkPrepared for Scratch -where - Scratch: TakeVecZnxDft, -{ - fn take_glwe_pk_prepared(&mut self, infos: &A) -> (GLWEPublicKeyPrepared<&mut [u8], B>, &mut Self) + fn take_glwe_public_key_prepared(&mut self, module: &M, infos: &A) -> (GLWEPublicKeyPrepared<&mut [u8], B>, &mut Self) where A: GLWEInfos, + M: ModuleN + VecZnxDftBytesOf, { - let (data, scratch) = self.take_vec_znx_dft(infos.n().into(), (infos.rank() + 1).into(), infos.size()); + let (data, scratch) = self.take_glwe_prepared(module, infos); ( - GLWEPublicKeyPrepared::builder() - .base2k(infos.base2k()) - .k(infos.k()) - .data(data) - .build() - .unwrap(), + GLWEPublicKeyPrepared { + dist: Distribution::NONE, + key: data, + }, + scratch, + ) + } + + fn take_glwe_prepared(&mut self, module: &M, infos: &A) -> (GLWEPrepared<&mut [u8], B>, &mut Self) + where + A: GLWEInfos, + M: ModuleN + VecZnxDftBytesOf, + { + assert_eq!(module.n() as u32, infos.n()); + let (data, scratch) = self.take_vec_znx_dft(module, (infos.rank() + 1).into(), infos.size()); + ( + GLWEPrepared { + k: infos.k(), + base2k: infos.base2k(), + data, + }, scratch, ) } -} -impl TakeGLWESecret for Scratch -where - Scratch: TakeScalarZnx, -{ fn take_glwe_secret(&mut self, n: Degree, rank: Rank) -> (GLWESecret<&mut [u8]>, &mut Self) { let (data, scratch) = self.take_scalar_znx(n.into(), rank.into()); ( @@ -368,14 +231,12 @@ where scratch, ) } -} -impl TakeGLWESecretPrepared for Scratch -where - Scratch: TakeSvpPPol, -{ - fn take_glwe_secret_prepared(&mut self, n: Degree, rank: Rank) -> (GLWESecretPrepared<&mut [u8], B>, &mut Self) { - let (data, scratch) = self.take_svp_ppol(n.into(), rank.into()); + fn take_glwe_secret_prepared(&mut self, module: &M, rank: Rank) -> (GLWESecretPrepared<&mut [u8], B>, &mut Self) + where + M: ModuleN + SvpPPolBytesOf, + { + let (data, scratch) = self.take_svp_ppol(module, rank.into()); ( GLWESecretPrepared { data, @@ -384,141 +245,127 @@ where scratch, ) } -} -impl TakeGLWESwitchingKey for Scratch -where - Scratch: TakeMatZnx, -{ - fn take_glwe_switching_key(&mut self, infos: &A) -> (GGLWESwitchingKey<&mut [u8]>, &mut Self) + fn take_glwe_switching_key(&mut self, infos: &A) -> (GLWESwitchingKey<&mut [u8]>, &mut Self) where A: GGLWEInfos, { let (data, scratch) = self.take_gglwe(infos); ( - GGLWESwitchingKey { + GLWESwitchingKey { key: data, - sk_in_n: 0, - sk_out_n: 0, + input_degree: Degree(0), + output_degree: Degree(0), }, scratch, ) } -} -impl TakeGGLWESwitchingKeyPrepared for Scratch -where - Scratch: TakeGGLWEPrepared, -{ - fn take_gglwe_switching_key_prepared(&mut self, infos: &A) -> (GGLWESwitchingKeyPrepared<&mut [u8], B>, &mut Self) + fn take_glwe_switching_key_prepared( + &mut self, + module: &M, + infos: &A, + ) -> (GLWESwitchingKeyPrepared<&mut [u8], B>, &mut Self) where A: GGLWEInfos, + M: ModuleN + VmpPMatBytesOf, { - let (data, scratch) = self.take_gglwe_prepared(infos); + assert_eq!(module.n() as u32, infos.n()); + let (data, scratch) = self.take_gglwe_prepared(module, infos); ( - GGLWESwitchingKeyPrepared { + GLWESwitchingKeyPrepared { key: data, - sk_in_n: 0, - sk_out_n: 0, + input_degree: Degree(0), + output_degree: Degree(0), }, scratch, ) } -} -impl TakeGGLWEAutomorphismKey for Scratch -where - Scratch: TakeMatZnx, -{ - fn take_gglwe_automorphism_key(&mut self, infos: &A) -> (GGLWEAutomorphismKey<&mut [u8]>, &mut Self) + fn take_glwe_automorphism_key(&mut self, infos: &A) -> (GLWEAutomorphismKey<&mut [u8]>, &mut Self) where A: GGLWEInfos, { - let (data, scratch) = self.take_glwe_switching_key(infos); - (GGLWEAutomorphismKey { key: data, p: 0 }, scratch) + let (data, scratch) = self.take_gglwe(infos); + (GLWEAutomorphismKey { key: data, p: 0 }, scratch) } -} -impl TakeGGLWEAutomorphismKeyPrepared for Scratch -where - Scratch: TakeGGLWESwitchingKeyPrepared, -{ - fn take_gglwe_automorphism_key_prepared(&mut self, infos: &A) -> (GGLWEAutomorphismKeyPrepared<&mut [u8], B>, &mut Self) + fn take_glwe_automorphism_key_prepared( + &mut self, + module: &M, + infos: &A, + ) -> (GLWEAutomorphismKeyPrepared<&mut [u8], B>, &mut Self) where A: GGLWEInfos, + M: ModuleN + VmpPMatBytesOf, { - let (data, scratch) = self.take_gglwe_switching_key_prepared(infos); - (GGLWEAutomorphismKeyPrepared { key: data, p: 0 }, scratch) + assert_eq!(module.n() as u32, infos.n()); + let (data, scratch) = self.take_gglwe_prepared(module, infos); + (GLWEAutomorphismKeyPrepared { key: data, p: 0 }, scratch) } -} -impl TakeTensorKey for Scratch -where - Scratch: TakeMatZnx, -{ - fn take_tensor_key(&mut self, infos: &A) -> (GGLWETensorKey<&mut [u8]>, &mut Self) + fn take_glwe_tensor_key(&mut self, infos: &A) -> (GLWETensorKey<&mut [u8]>, &mut Self) where A: GGLWEInfos, { assert_eq!( infos.rank_in(), infos.rank_out(), - "rank_in != rank_out is not supported for GGLWETensorKey" + "rank_in != rank_out is not supported for GLWETensorKey" ); - let mut keys: Vec> = Vec::new(); + let mut keys: Vec> = Vec::new(); let pairs: usize = (((infos.rank_out().0 + 1) * infos.rank_out().0) >> 1).max(1) as usize; - let mut scratch: &mut Scratch = self; + let mut scratch: &mut Self = self; - let mut ksk_infos: crate::layouts::GGLWECiphertextLayout = infos.layout(); + let mut ksk_infos: GGLWELayout = infos.gglwe_layout(); ksk_infos.rank_in = Rank(1); if pairs != 0 { - let (gglwe, s) = scratch.take_glwe_switching_key(&ksk_infos); + let (gglwe, s) = scratch.take_gglwe(&ksk_infos); scratch = s; keys.push(gglwe); } for _ in 1..pairs { - let (gglwe, s) = scratch.take_glwe_switching_key(&ksk_infos); + let (gglwe, s) = scratch.take_gglwe(&ksk_infos); scratch = s; keys.push(gglwe); } - (GGLWETensorKey { keys }, scratch) + (GLWETensorKey { keys }, scratch) } -} -impl TakeGGLWETensorKeyPrepared for Scratch -where - Scratch: TakeVmpPMat, -{ - fn take_gglwe_tensor_key_prepared(&mut self, infos: &A) -> (GGLWETensorKeyPrepared<&mut [u8], B>, &mut Self) + fn take_glwe_tensor_key_prepared(&mut self, module: &M, infos: &A) -> (GLWETensorKeyPrepared<&mut [u8], B>, &mut Self) where A: GGLWEInfos, + M: ModuleN + VmpPMatBytesOf, { + assert_eq!(module.n() as u32, infos.n()); assert_eq!( infos.rank_in(), infos.rank_out(), "rank_in != rank_out is not supported for GGLWETensorKeyPrepared" ); - let mut keys: Vec> = Vec::new(); + let mut keys: Vec> = Vec::new(); let pairs: usize = (((infos.rank_out().0 + 1) * infos.rank_out().0) >> 1).max(1) as usize; - let mut scratch: &mut Scratch = self; + let mut scratch: &mut Self = self; - let mut ksk_infos: crate::layouts::GGLWECiphertextLayout = infos.layout(); + let mut ksk_infos: GGLWELayout = infos.gglwe_layout(); ksk_infos.rank_in = Rank(1); if pairs != 0 { - let (gglwe, s) = scratch.take_gglwe_switching_key_prepared(&ksk_infos); + let (gglwe, s) = scratch.take_gglwe_prepared(module, &ksk_infos); scratch = s; keys.push(gglwe); } for _ in 1..pairs { - let (gglwe, s) = scratch.take_gglwe_switching_key_prepared(&ksk_infos); + let (gglwe, s) = scratch.take_gglwe_prepared(module, &ksk_infos); scratch = s; keys.push(gglwe); } - (GGLWETensorKeyPrepared { keys }, scratch) + (GLWETensorKeyPrepared { keys }, scratch) } } + +impl ScratchTakeCore for Scratch where Self: ScratchTakeBasic + ScratchAvailable {} diff --git a/poulpy-core/src/tests/mod.rs b/poulpy-core/src/tests/mod.rs index 23b489b..dd16db0 100644 --- a/poulpy-core/src/tests/mod.rs +++ b/poulpy-core/src/tests/mod.rs @@ -8,175 +8,174 @@ use poulpy_hal::backend_test_suite; #[cfg(test)] backend_test_suite!( - mod cpu_spqlios, - backend = poulpy_backend::cpu_spqlios::FFT64Spqlios, - size = 1<<8, - tests = { - // GLWE Encryption - glwe_encrypt_sk => crate::tests::test_suite::encryption::test_glwe_encrypt_sk, - glwe_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_glwe_compressed_encrypt_sk, - glwe_encrypt_zero_sk => crate::tests::test_suite::encryption::test_glwe_encrypt_zero_sk, - glwe_encrypt_pk => crate::tests::test_suite::encryption::test_glwe_encrypt_pk, - // GLWE Keyswitch - glwe_keyswitch => crate::tests::test_suite::keyswitch::test_glwe_keyswitch, - glwe_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_glwe_keyswitch_inplace, - // GLWE Automorphism - glwe_automorphism => crate::tests::test_suite::automorphism::test_glwe_automorphism, - glwe_automorphism_inplace => crate::tests::test_suite::automorphism::test_glwe_automorphism_inplace, - // GLWE External Product - glwe_external_product => crate::tests::test_suite::external_product::test_glwe_external_product, - glwe_external_product_inplace => crate::tests::test_suite::external_product::test_glwe_external_product_inplace, - // GLWE Trace - glwe_trace_inplace => crate::tests::test_suite::test_glwe_trace_inplace, - glwe_packing => crate::tests::test_suite::test_glwe_packing, - // GGLWE Encryption - gglwe_switching_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_switching_key_encrypt_sk, - gglwe_switching_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_switching_key_compressed_encrypt_sk, - gglwe_automorphisk_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_automorphisk_key_encrypt_sk, - gglwe_automorphisk_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_automorphisk_key_compressed_encrypt_sk, - gglwe_tensor_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_encrypt_sk, - gglwe_tensor_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_compressed_encrypt_sk, - // GGLWE Keyswitching - gglwe_switching_key_keyswitch => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch, - gglwe_switching_key_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch_inplace, - // GGLWE External Product - gglwe_switching_key_external_product => crate::tests::test_suite::external_product::test_gglwe_switching_key_external_product, - gglwe_switching_key_external_product_inplace => crate::tests::test_suite::external_product::test_gglwe_switching_key_external_product_inplace, - // GGLWE Automorphism - gglwe_automorphism_key_automorphism => crate::tests::test_suite::automorphism::test_gglwe_automorphism_key_automorphism, - gglwe_automorphism_key_automorphism_inplace => crate::tests::test_suite::automorphism::test_gglwe_automorphism_key_automorphism_inplace, - // GGSW Encryption - ggsw_encrypt_sk => crate::tests::test_suite::encryption::test_ggsw_encrypt_sk, - ggsw_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_ggsw_compressed_encrypt_sk, - // GGSW Keyswitching - ggsw_keyswitch => crate::tests::test_suite::keyswitch::test_ggsw_keyswitch, - ggsw_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_ggsw_keyswitch_inplace, - // GGSW External Product - ggsw_external_product => crate::tests::test_suite::external_product::test_ggsw_external_product, - ggsw_external_product_inplace => crate::tests::test_suite::external_product::test_ggsw_external_product_inplace, - // GGSW Automorphism - ggsw_automorphism => crate::tests::test_suite::automorphism::test_ggsw_automorphism, - ggsw_automorphism_inplace => crate::tests::test_suite::automorphism::test_ggsw_automorphism_inplace, - // LWE - lwe_keyswitch => crate::tests::test_suite::keyswitch::test_lwe_keyswitch, - glwe_to_lwe => crate::tests::test_suite::test_glwe_to_lwe, - lwe_to_glwe => crate::tests::test_suite::test_lwe_to_glwe, - } + mod cpu_spqlios, + backend = poulpy_backend::cpu_spqlios::FFT64Spqlios, + size = 1<<8, + tests = { + //GLWE Encryption + glwe_encrypt_sk => crate::tests::test_suite::encryption::test_glwe_encrypt_sk, + glwe_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_glwe_compressed_encrypt_sk, + glwe_encrypt_zero_sk => crate::tests::test_suite::encryption::test_glwe_encrypt_zero_sk, + glwe_encrypt_pk => crate::tests::test_suite::encryption::test_glwe_encrypt_pk, +// GLWE Keyswitch +glwe_keyswitch => crate::tests::test_suite::keyswitch::test_glwe_keyswitch, +glwe_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_glwe_keyswitch_inplace, +// GLWE Automorphism +glwe_automorphism => crate::tests::test_suite::automorphism::test_glwe_automorphism, +glwe_automorphism_inplace => crate::tests::test_suite::automorphism::test_glwe_automorphism_inplace, +// GLWE External Product +glwe_external_product => crate::tests::test_suite::external_product::test_glwe_external_product, +glwe_external_product_inplace => crate::tests::test_suite::external_product::test_glwe_external_product_inplace, +// GLWE Trace +glwe_trace_inplace => crate::tests::test_suite::test_glwe_trace_inplace, +glwe_packing => crate::tests::test_suite::test_glwe_packing, +// GGLWE Encryption +gglwe_switching_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_switching_key_encrypt_sk, +gglwe_switching_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_switching_key_compressed_encrypt_sk, +gglwe_automorphism_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_automorphism_key_encrypt_sk, +gglwe_automorphism_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_automorphism_key_compressed_encrypt_sk, +gglwe_tensor_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_encrypt_sk, +gglwe_tensor_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_compressed_encrypt_sk, +// GGLWE Keyswitching +gglwe_switching_key_keyswitch => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch, +gglwe_switching_key_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch_inplace, +// GGLWE External Product +gglwe_switching_key_external_product => crate::tests::test_suite::external_product::test_gglwe_switching_key_external_product, +gglwe_switching_key_external_product_inplace => crate::tests::test_suite::external_product::test_gglwe_switching_key_external_product_inplace, +// GGLWE Automorphism +gglwe_automorphism_key_automorphism => crate::tests::test_suite::automorphism::test_gglwe_automorphism_key_automorphism, +gglwe_automorphism_key_automorphism_inplace => crate::tests::test_suite::automorphism::test_gglwe_automorphism_key_automorphism_inplace, +// GGSW Encryption +ggsw_encrypt_sk => crate::tests::test_suite::encryption::test_ggsw_encrypt_sk, +ggsw_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_ggsw_compressed_encrypt_sk, +// GGSW Keyswitching +ggsw_keyswitch => crate::tests::test_suite::keyswitch::test_ggsw_keyswitch, +ggsw_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_ggsw_keyswitch_inplace, +// GGSW External Product +ggsw_external_product => crate::tests::test_suite::external_product::test_ggsw_external_product, +ggsw_external_product_inplace => crate::tests::test_suite::external_product::test_ggsw_external_product_inplace, +// GGSW Automorphism +ggsw_automorphism => crate::tests::test_suite::automorphism::test_ggsw_automorphism, +ggsw_automorphism_inplace => crate::tests::test_suite::automorphism::test_ggsw_automorphism_inplace, +// LWE +lwe_keyswitch => crate::tests::test_suite::keyswitch::test_lwe_keyswitch, +glwe_to_lwe => crate::tests::test_suite::test_glwe_to_lwe, +lwe_to_glwe => crate::tests::test_suite::test_lwe_to_glwe, +} ); - #[cfg(test)] backend_test_suite!( - mod cpu_ref, - backend = poulpy_backend::cpu_fft64_ref::FFT64Ref, - size = 1<<8, - tests = { - // GLWE Encryption - glwe_encrypt_sk => crate::tests::test_suite::encryption::test_glwe_encrypt_sk, - glwe_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_glwe_compressed_encrypt_sk, - glwe_encrypt_zero_sk => crate::tests::test_suite::encryption::test_glwe_encrypt_zero_sk, - glwe_encrypt_pk => crate::tests::test_suite::encryption::test_glwe_encrypt_pk, - // GLWE Keyswitch - glwe_keyswitch => crate::tests::test_suite::keyswitch::test_glwe_keyswitch, - glwe_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_glwe_keyswitch_inplace, - // GLWE Automorphism - glwe_automorphism => crate::tests::test_suite::automorphism::test_glwe_automorphism, - glwe_automorphism_inplace => crate::tests::test_suite::automorphism::test_glwe_automorphism_inplace, - // GLWE External Product - glwe_external_product => crate::tests::test_suite::external_product::test_glwe_external_product, - glwe_external_product_inplace => crate::tests::test_suite::external_product::test_glwe_external_product_inplace, - // GLWE Trace - glwe_trace_inplace => crate::tests::test_suite::test_glwe_trace_inplace, - glwe_packing => crate::tests::test_suite::test_glwe_packing, - // GGLWE Encryption - gglwe_switching_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_switching_key_encrypt_sk, - gglwe_switching_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_switching_key_compressed_encrypt_sk, - gglwe_automorphisk_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_automorphisk_key_encrypt_sk, - gglwe_automorphisk_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_automorphisk_key_compressed_encrypt_sk, - gglwe_tensor_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_encrypt_sk, - gglwe_tensor_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_compressed_encrypt_sk, - // GGLWE Keyswitching - gglwe_switching_key_keyswitch => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch, - gglwe_switching_key_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch_inplace, - // GGLWE External Product - gglwe_switching_key_external_product => crate::tests::test_suite::external_product::test_gglwe_switching_key_external_product, - gglwe_switching_key_external_product_inplace => crate::tests::test_suite::external_product::test_gglwe_switching_key_external_product_inplace, - // GGLWE Automorphism - gglwe_automorphism_key_automorphism => crate::tests::test_suite::automorphism::test_gglwe_automorphism_key_automorphism, - gglwe_automorphism_key_automorphism_inplace => crate::tests::test_suite::automorphism::test_gglwe_automorphism_key_automorphism_inplace, - // GGSW Encryption - ggsw_encrypt_sk => crate::tests::test_suite::encryption::test_ggsw_encrypt_sk, - ggsw_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_ggsw_compressed_encrypt_sk, - // GGSW Keyswitching - ggsw_keyswitch => crate::tests::test_suite::keyswitch::test_ggsw_keyswitch, - ggsw_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_ggsw_keyswitch_inplace, - // GGSW External Product - ggsw_external_product => crate::tests::test_suite::external_product::test_ggsw_external_product, - ggsw_external_product_inplace => crate::tests::test_suite::external_product::test_ggsw_external_product_inplace, - // GGSW Automorphism - ggsw_automorphism => crate::tests::test_suite::automorphism::test_ggsw_automorphism, - ggsw_automorphism_inplace => crate::tests::test_suite::automorphism::test_ggsw_automorphism_inplace, - // LWE - lwe_keyswitch => crate::tests::test_suite::keyswitch::test_lwe_keyswitch, - glwe_to_lwe => crate::tests::test_suite::test_glwe_to_lwe, - lwe_to_glwe => crate::tests::test_suite::test_lwe_to_glwe, - } -); + mod cpu_ref, + backend = poulpy_backend::cpu_fft64_ref::FFT64Ref, + size = 1<<8, + tests = { + //GLWE Encryption + glwe_encrypt_sk => crate::tests::test_suite::encryption::test_glwe_encrypt_sk, + glwe_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_glwe_compressed_encrypt_sk, + glwe_encrypt_zero_sk => crate::tests::test_suite::encryption::test_glwe_encrypt_zero_sk, + glwe_encrypt_pk => crate::tests::test_suite::encryption::test_glwe_encrypt_pk, +// GLWE Keyswitch + glwe_keyswitch => crate::tests::test_suite::keyswitch::test_glwe_keyswitch, +glwe_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_glwe_keyswitch_inplace, +// GLWE Automorphism +glwe_automorphism => crate::tests::test_suite::automorphism::test_glwe_automorphism, +glwe_automorphism_inplace => crate::tests::test_suite::automorphism::test_glwe_automorphism_inplace, +// GLWE External Product +glwe_external_product => crate::tests::test_suite::external_product::test_glwe_external_product, +glwe_external_product_inplace => crate::tests::test_suite::external_product::test_glwe_external_product_inplace, +// GLWE Trace +glwe_trace_inplace => crate::tests::test_suite::test_glwe_trace_inplace, +glwe_packing => crate::tests::test_suite::test_glwe_packing, +// GGLWE Encryption +gglwe_switching_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_switching_key_encrypt_sk, +gglwe_switching_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_switching_key_compressed_encrypt_sk, +gglwe_automorphism_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_automorphism_key_encrypt_sk, +gglwe_automorphism_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_automorphism_key_compressed_encrypt_sk, +gglwe_tensor_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_encrypt_sk, +gglwe_tensor_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_compressed_encrypt_sk, +// GGLWE Keyswitching +gglwe_switching_key_keyswitch => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch, +gglwe_switching_key_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch_inplace, +// GGLWE External Product +gglwe_switching_key_external_product => crate::tests::test_suite::external_product::test_gglwe_switching_key_external_product, +gglwe_switching_key_external_product_inplace => crate::tests::test_suite::external_product::test_gglwe_switching_key_external_product_inplace, +// GGLWE Automorphism +gglwe_automorphism_key_automorphism => crate::tests::test_suite::automorphism::test_gglwe_automorphism_key_automorphism, +gglwe_automorphism_key_automorphism_inplace => crate::tests::test_suite::automorphism::test_gglwe_automorphism_key_automorphism_inplace, +// GGSW Encryption +ggsw_encrypt_sk => crate::tests::test_suite::encryption::test_ggsw_encrypt_sk, +ggsw_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_ggsw_compressed_encrypt_sk, +// GGSW Keyswitching +ggsw_keyswitch => crate::tests::test_suite::keyswitch::test_ggsw_keyswitch, +ggsw_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_ggsw_keyswitch_inplace, +// GGSW External Product +ggsw_external_product => crate::tests::test_suite::external_product::test_ggsw_external_product, +ggsw_external_product_inplace => crate::tests::test_suite::external_product::test_ggsw_external_product_inplace, +// GGSW Automorphism +ggsw_automorphism => crate::tests::test_suite::automorphism::test_ggsw_automorphism, +ggsw_automorphism_inplace => crate::tests::test_suite::automorphism::test_ggsw_automorphism_inplace, +// LWE +lwe_keyswitch => crate::tests::test_suite::keyswitch::test_lwe_keyswitch, +glwe_to_lwe => crate::tests::test_suite::test_glwe_to_lwe, +lwe_to_glwe => crate::tests::test_suite::test_lwe_to_glwe, +} + ); #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] #[cfg(test)] backend_test_suite!( - mod cpu_avx, - backend = poulpy_backend::cpu_fft64_avx::FFT64Avx, - size = 1<<8, - tests = { - // GLWE Encryption - glwe_encrypt_sk => crate::tests::test_suite::encryption::test_glwe_encrypt_sk, - glwe_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_glwe_compressed_encrypt_sk, - glwe_encrypt_zero_sk => crate::tests::test_suite::encryption::test_glwe_encrypt_zero_sk, - glwe_encrypt_pk => crate::tests::test_suite::encryption::test_glwe_encrypt_pk, - // GLWE Keyswitch - glwe_keyswitch => crate::tests::test_suite::keyswitch::test_glwe_keyswitch, - glwe_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_glwe_keyswitch_inplace, - // GLWE Automorphism - glwe_automorphism => crate::tests::test_suite::automorphism::test_glwe_automorphism, - glwe_automorphism_inplace => crate::tests::test_suite::automorphism::test_glwe_automorphism_inplace, - // GLWE External Product - glwe_external_product => crate::tests::test_suite::external_product::test_glwe_external_product, - glwe_external_product_inplace => crate::tests::test_suite::external_product::test_glwe_external_product_inplace, - // GLWE Trace - glwe_trace_inplace => crate::tests::test_suite::test_glwe_trace_inplace, - glwe_packing => crate::tests::test_suite::test_glwe_packing, - // GGLWE Encryption - gglwe_switching_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_switching_key_encrypt_sk, - gglwe_switching_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_switching_key_compressed_encrypt_sk, - gglwe_automorphisk_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_automorphisk_key_encrypt_sk, - gglwe_automorphisk_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_automorphisk_key_compressed_encrypt_sk, - gglwe_tensor_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_encrypt_sk, - gglwe_tensor_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_compressed_encrypt_sk, - // GGLWE Keyswitching - gglwe_switching_key_keyswitch => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch, - gglwe_switching_key_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch_inplace, - // GGLWE External Product - gglwe_switching_key_external_product => crate::tests::test_suite::external_product::test_gglwe_switching_key_external_product, - gglwe_switching_key_external_product_inplace => crate::tests::test_suite::external_product::test_gglwe_switching_key_external_product_inplace, - // GGLWE Automorphism - gglwe_automorphism_key_automorphism => crate::tests::test_suite::automorphism::test_gglwe_automorphism_key_automorphism, - gglwe_automorphism_key_automorphism_inplace => crate::tests::test_suite::automorphism::test_gglwe_automorphism_key_automorphism_inplace, - // GGSW Encryption - ggsw_encrypt_sk => crate::tests::test_suite::encryption::test_ggsw_encrypt_sk, - ggsw_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_ggsw_compressed_encrypt_sk, - // GGSW Keyswitching - ggsw_keyswitch => crate::tests::test_suite::keyswitch::test_ggsw_keyswitch, - ggsw_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_ggsw_keyswitch_inplace, - // GGSW External Product - ggsw_external_product => crate::tests::test_suite::external_product::test_ggsw_external_product, - ggsw_external_product_inplace => crate::tests::test_suite::external_product::test_ggsw_external_product_inplace, - // GGSW Automorphism - ggsw_automorphism => crate::tests::test_suite::automorphism::test_ggsw_automorphism, - ggsw_automorphism_inplace => crate::tests::test_suite::automorphism::test_ggsw_automorphism_inplace, - // LWE - lwe_keyswitch => crate::tests::test_suite::keyswitch::test_lwe_keyswitch, - glwe_to_lwe => crate::tests::test_suite::test_glwe_to_lwe, - lwe_to_glwe => crate::tests::test_suite::test_lwe_to_glwe, - } -); + mod cpu_avx, + backend = poulpy_backend::cpu_fft64_avx::FFT64Avx, + size = 1<<8, + tests = { + //GLWE Encryption + glwe_encrypt_sk => crate::tests::test_suite::encryption::test_glwe_encrypt_sk, + glwe_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_glwe_compressed_encrypt_sk, + glwe_encrypt_zero_sk => crate::tests::test_suite::encryption::test_glwe_encrypt_zero_sk, + glwe_encrypt_pk => crate::tests::test_suite::encryption::test_glwe_encrypt_pk, +// GLWE Keyswitch +glwe_keyswitch => crate::tests::test_suite::keyswitch::test_glwe_keyswitch, +glwe_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_glwe_keyswitch_inplace, +// GLWE Automorphism +glwe_automorphism => crate::tests::test_suite::automorphism::test_glwe_automorphism, +glwe_automorphism_inplace => crate::tests::test_suite::automorphism::test_glwe_automorphism_inplace, +// GLWE External Product +glwe_external_product => crate::tests::test_suite::external_product::test_glwe_external_product, +glwe_external_product_inplace => crate::tests::test_suite::external_product::test_glwe_external_product_inplace, +// GLWE Trace +glwe_trace_inplace => crate::tests::test_suite::test_glwe_trace_inplace, +glwe_packing => crate::tests::test_suite::test_glwe_packing, +// GGLWE Encryption +gglwe_switching_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_switching_key_encrypt_sk, +gglwe_switching_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_switching_key_compressed_encrypt_sk, +gglwe_automorphism_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_automorphism_key_encrypt_sk, +gglwe_automorphism_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_automorphism_key_compressed_encrypt_sk, +gglwe_tensor_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_encrypt_sk, +gglwe_tensor_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_compressed_encrypt_sk, +// GGLWE Keyswitching +gglwe_switching_key_keyswitch => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch, +gglwe_switching_key_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch_inplace, +// GGLWE External Product +gglwe_switching_key_external_product => crate::tests::test_suite::external_product::test_gglwe_switching_key_external_product, +gglwe_switching_key_external_product_inplace => crate::tests::test_suite::external_product::test_gglwe_switching_key_external_product_inplace, +// GGLWE Automorphism +gglwe_automorphism_key_automorphism => crate::tests::test_suite::automorphism::test_gglwe_automorphism_key_automorphism, +gglwe_automorphism_key_automorphism_inplace => crate::tests::test_suite::automorphism::test_gglwe_automorphism_key_automorphism_inplace, +// GGSW Encryption +ggsw_encrypt_sk => crate::tests::test_suite::encryption::test_ggsw_encrypt_sk, +ggsw_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_ggsw_compressed_encrypt_sk, +// GGSW Keyswitching +ggsw_keyswitch => crate::tests::test_suite::keyswitch::test_ggsw_keyswitch, +ggsw_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_ggsw_keyswitch_inplace, +// GGSW External Product +ggsw_external_product => crate::tests::test_suite::external_product::test_ggsw_external_product, +ggsw_external_product_inplace => crate::tests::test_suite::external_product::test_ggsw_external_product_inplace, +// GGSW Automorphism +ggsw_automorphism => crate::tests::test_suite::automorphism::test_ggsw_automorphism, +ggsw_automorphism_inplace => crate::tests::test_suite::automorphism::test_ggsw_automorphism_inplace, +// LWE +lwe_keyswitch => crate::tests::test_suite::keyswitch::test_lwe_keyswitch, +glwe_to_lwe => crate::tests::test_suite::test_glwe_to_lwe, +lwe_to_glwe => crate::tests::test_suite::test_lwe_to_glwe, +} + ); diff --git a/poulpy-core/src/tests/serialization.rs b/poulpy-core/src/tests/serialization.rs index 8fe477c..c67d87d 100644 --- a/poulpy-core/src/tests/serialization.rs +++ b/poulpy-core/src/tests/serialization.rs @@ -1,12 +1,12 @@ use poulpy_hal::test_suite::serialization::test_reader_writer_interface; use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEAutomorphismKey, GGLWECiphertext, GGLWESwitchingKey, GGLWETensorKey, GGSWCiphertext, - GLWECiphertext, GLWEToLWEKey, LWECiphertext, LWESwitchingKey, LWEToGLWESwitchingKey, Rank, TorusPrecision, + Base2K, Degree, Dnum, Dsize, GGLWE, GGSW, GLWE, GLWEAutomorphismKey, GLWESwitchingKey, GLWETensorKey, GLWEToLWESwitchingKey, + LWE, LWESwitchingKey, LWEToGLWESwitchingKey, Rank, TorusPrecision, compressed::{ - GGLWEAutomorphismKeyCompressed, GGLWECiphertextCompressed, GGLWESwitchingKeyCompressed, GGLWETensorKeyCompressed, - GGSWCiphertextCompressed, GLWECiphertextCompressed, GLWEToLWESwitchingKeyCompressed, LWECiphertextCompressed, - LWESwitchingKeyCompressed, LWEToGLWESwitchingKeyCompressed, + GGLWECompressed, GGSWCompressed, GLWEAutomorphismKeyCompressed, GLWECompressed, GLWESwitchingKeyCompressed, + GLWETensorKeyCompressed, GLWEToLWESwitchingKeyCompressed, LWECompressed, LWESwitchingKeyCompressed, + LWEToGLWESwitchingKeyCompressed, }, }; @@ -20,125 +20,124 @@ const DSIZE: Dsize = Dsize(1); #[test] fn glwe_serialization() { - let original: GLWECiphertext> = GLWECiphertext::alloc_with(N_GLWE, BASE2K, K, RANK); + let original: GLWE> = GLWE::alloc(N_GLWE, BASE2K, K, RANK); poulpy_hal::test_suite::serialization::test_reader_writer_interface(original); } #[test] fn glwe_compressed_serialization() { - let original: GLWECiphertextCompressed> = GLWECiphertextCompressed::alloc_with(N_GLWE, BASE2K, K, RANK); + let original: GLWECompressed> = GLWECompressed::alloc(N_GLWE, BASE2K, K, RANK); test_reader_writer_interface(original); } #[test] fn lwe_serialization() { - let original: LWECiphertext> = LWECiphertext::alloc_with(N_LWE, BASE2K, K); + let original: LWE> = LWE::alloc(N_LWE, BASE2K, K); test_reader_writer_interface(original); } #[test] fn lwe_compressed_serialization() { - let original: LWECiphertextCompressed> = LWECiphertextCompressed::alloc_with(BASE2K, K); + let original: LWECompressed> = LWECompressed::alloc(BASE2K, K); test_reader_writer_interface(original); } #[test] fn test_gglwe_serialization() { - let original: GGLWECiphertext> = GGLWECiphertext::alloc_with(N_GLWE, BASE2K, K, RANK, RANK, DNUM, DSIZE); + let original: GGLWE> = GGLWE::alloc(N_GLWE, BASE2K, K, RANK, RANK, DNUM, DSIZE); test_reader_writer_interface(original); } #[test] fn test_gglwe_compressed_serialization() { - let original: GGLWECiphertextCompressed> = - GGLWECiphertextCompressed::alloc_with(N_GLWE, BASE2K, K, RANK, RANK, DNUM, DSIZE); + let original: GGLWECompressed> = GGLWECompressed::alloc(N_GLWE, BASE2K, K, RANK, RANK, DNUM, DSIZE); test_reader_writer_interface(original); } #[test] fn test_glwe_switching_key_serialization() { - let original: GGLWESwitchingKey> = GGLWESwitchingKey::alloc_with(N_GLWE, BASE2K, K, RANK, RANK, DNUM, DSIZE); + let original: GLWESwitchingKey> = GLWESwitchingKey::alloc(N_GLWE, BASE2K, K, RANK, RANK, DNUM, DSIZE); test_reader_writer_interface(original); } #[test] fn test_glwe_switching_key_compressed_serialization() { - let original: GGLWESwitchingKeyCompressed> = - GGLWESwitchingKeyCompressed::alloc_with(N_GLWE, BASE2K, K, RANK, RANK, DNUM, DSIZE); + let original: GLWESwitchingKeyCompressed> = + GLWESwitchingKeyCompressed::alloc(N_GLWE, BASE2K, K, RANK, RANK, DNUM, DSIZE); test_reader_writer_interface(original); } #[test] fn test_automorphism_key_serialization() { - let original: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc_with(N_GLWE, BASE2K, K, RANK, DNUM, DSIZE); + let original: GLWEAutomorphismKey> = GLWEAutomorphismKey::alloc(N_GLWE, BASE2K, K, RANK, DNUM, DSIZE); test_reader_writer_interface(original); } #[test] fn test_automorphism_key_compressed_serialization() { - let original: GGLWEAutomorphismKeyCompressed> = - GGLWEAutomorphismKeyCompressed::alloc_with(N_GLWE, BASE2K, K, RANK, DNUM, DSIZE); + let original: GLWEAutomorphismKeyCompressed> = + GLWEAutomorphismKeyCompressed::alloc(N_GLWE, BASE2K, K, RANK, DNUM, DSIZE); test_reader_writer_interface(original); } #[test] fn test_tensor_key_serialization() { - let original: GGLWETensorKey> = GGLWETensorKey::alloc_with(N_GLWE, BASE2K, K, RANK, DNUM, DSIZE); + let original: GLWETensorKey> = GLWETensorKey::alloc(N_GLWE, BASE2K, K, RANK, DNUM, DSIZE); test_reader_writer_interface(original); } #[test] fn test_tensor_key_compressed_serialization() { - let original: GGLWETensorKeyCompressed> = GGLWETensorKeyCompressed::alloc_with(N_GLWE, BASE2K, K, RANK, DNUM, DSIZE); + let original: GLWETensorKeyCompressed> = GLWETensorKeyCompressed::alloc(N_GLWE, BASE2K, K, RANK, DNUM, DSIZE); test_reader_writer_interface(original); } #[test] fn glwe_to_lwe_switching_key_serialization() { - let original: GLWEToLWEKey> = GLWEToLWEKey::alloc_with(N_GLWE, BASE2K, K, RANK, DNUM); + let original: GLWEToLWESwitchingKey> = GLWEToLWESwitchingKey::alloc(N_GLWE, BASE2K, K, RANK, DNUM); test_reader_writer_interface(original); } #[test] fn glwe_to_lwe_switching_key_compressed_serialization() { let original: GLWEToLWESwitchingKeyCompressed> = - GLWEToLWESwitchingKeyCompressed::alloc_with(N_GLWE, BASE2K, K, RANK, DNUM); + GLWEToLWESwitchingKeyCompressed::alloc(N_GLWE, BASE2K, K, RANK, DNUM); test_reader_writer_interface(original); } #[test] fn lwe_to_glwe_switching_key_serialization() { - let original: LWEToGLWESwitchingKey> = LWEToGLWESwitchingKey::alloc_with(N_GLWE, BASE2K, K, RANK, DNUM); + let original: LWEToGLWESwitchingKey> = LWEToGLWESwitchingKey::alloc(N_GLWE, BASE2K, K, RANK, DNUM); test_reader_writer_interface(original); } #[test] fn lwe_to_glwe_switching_key_compressed_serialization() { let original: LWEToGLWESwitchingKeyCompressed> = - LWEToGLWESwitchingKeyCompressed::alloc_with(N_GLWE, BASE2K, K, RANK, DNUM); + LWEToGLWESwitchingKeyCompressed::alloc(N_GLWE, BASE2K, K, RANK, DNUM); test_reader_writer_interface(original); } #[test] fn lwe_switching_key_serialization() { - let original: LWESwitchingKey> = LWESwitchingKey::alloc_with(N_GLWE, BASE2K, K, DNUM); + let original: LWESwitchingKey> = LWESwitchingKey::alloc(N_GLWE, BASE2K, K, DNUM); test_reader_writer_interface(original); } #[test] fn lwe_switching_key_compressed_serialization() { - let original: LWESwitchingKeyCompressed> = LWESwitchingKeyCompressed::alloc_with(N_GLWE, BASE2K, K, DNUM); + let original: LWESwitchingKeyCompressed> = LWESwitchingKeyCompressed::alloc(N_GLWE, BASE2K, K, DNUM); test_reader_writer_interface(original); } #[test] fn ggsw_serialization() { - let original: GGSWCiphertext> = GGSWCiphertext::alloc_with(N_GLWE, BASE2K, K, RANK, DNUM, DSIZE); + let original: GGSW> = GGSW::alloc(N_GLWE, BASE2K, K, RANK, DNUM, DSIZE); test_reader_writer_interface(original); } #[test] fn ggsw_compressed_serialization() { - let original: GGSWCiphertextCompressed> = GGSWCiphertextCompressed::alloc_with(N_GLWE, BASE2K, K, RANK, DNUM, DSIZE); + let original: GGSWCompressed> = GGSWCompressed::alloc(N_GLWE, BASE2K, K, RANK, DNUM, DSIZE); test_reader_writer_interface(original); } diff --git a/poulpy-core/src/tests/test_suite/automorphism/gglwe_atk.rs b/poulpy-core/src/tests/test_suite/automorphism/gglwe_atk.rs index 1dd3e58..3f2b94a 100644 --- a/poulpy-core/src/tests/test_suite/automorphism/gglwe_atk.rs +++ b/poulpy-core/src/tests/test_suite/automorphism/gglwe_atk.rs @@ -1,71 +1,33 @@ use poulpy_hal::{ - api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, - VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, - VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSubScalarInplace, VecZnxSwitchRing, - VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, - }, - layouts::{Backend, Module, ScratchOwned}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, - }, + api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAutomorphism, VecZnxSubScalarInplace}, + layouts::{Backend, GaloisElement, Module, Scratch, ScratchOwned}, source::Source, }; use crate::{ + GLWEAutomorphismKeyAutomorphism, GLWEAutomorphismKeyEncryptSk, GLWEDecrypt, ScratchTakeCore, encryption::SIGMA, layouts::{ - GGLWEAutomorphismKey, GGLWEAutomorphismKeyLayout, GGLWEInfos, GLWEPlaintext, GLWESecret, - prepared::{GGLWEAutomorphismKeyPrepared, GLWESecretPrepared, Prepare, PrepareAlloc}, + GGLWEInfos, GLWEAutomorphismKey, GLWEAutomorphismKeyLayout, GLWEAutomorphismKeyPreparedFactory, GLWEPlaintext, + GLWESecret, GLWESecretPreparedFactory, + prepared::{GLWEAutomorphismKeyPrepared, GLWESecretPrepared}, }, noise::log2_std_noise_gglwe_product, }; #[allow(clippy::too_many_arguments)] -pub fn test_gglwe_automorphism_key_automorphism(module: &Module) +pub fn test_gglwe_automorphism_key_automorphism(module: &Module) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize + Module: GLWEAutomorphismKeyEncryptSk + + GLWEAutomorphismKeyPreparedFactory + + GLWEAutomorphismKeyAutomorphism + VecZnxAutomorphism - + VecZnxAutomorphismInplace - + SvpPPolAllocBytes - + VecZnxDftAllocBytes - + VecZnxNormalizeTmpBytes - + VmpPMatAlloc - + VmpPrepare - + SvpPrepare - + SvpApplyDftToDftInplace - + VecZnxAddScalarInplace - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + VecZnxSwitchRing - + SvpPPolAlloc - + VecZnxBigAddInplace - + VecZnxSubScalarInplace, - B: Backend - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxDftImpl - + TakeVecZnxImpl - + TakeSvpPPolImpl - + TakeVecZnxBigImpl, + + GaloisElement + + VecZnxSubScalarInplace + + GLWESecretPreparedFactory + + GLWEDecrypt, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { let base2k: usize = 12; let k_in: usize = 60; @@ -84,7 +46,7 @@ where let dnum_out: usize = k_out / (base2k * di); let dnum_apply: usize = k_in.div_ceil(base2k * di); - let auto_key_in_infos: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + let auto_key_in_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), base2k: base2k.into(), k: k_in.into(), @@ -93,7 +55,7 @@ where rank: rank.into(), }; - let auto_key_out_infos: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + let auto_key_out_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), base2k: base2k.into(), k: k_out.into(), @@ -102,7 +64,7 @@ where rank: rank.into(), }; - let auto_key_apply_infos: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + let auto_key_apply_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), base2k: base2k.into(), k: k_apply.into(), @@ -111,18 +73,18 @@ where rank: rank.into(), }; - let mut auto_key_in: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&auto_key_in_infos); - let mut auto_key_out: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&auto_key_out_infos); - let mut auto_key_apply: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&auto_key_apply_infos); + let mut auto_key_in: GLWEAutomorphismKey> = GLWEAutomorphismKey::alloc_from_infos(&auto_key_in_infos); + let mut auto_key_out: GLWEAutomorphismKey> = GLWEAutomorphismKey::alloc_from_infos(&auto_key_out_infos); + let mut auto_key_apply: GLWEAutomorphismKey> = GLWEAutomorphismKey::alloc_from_infos(&auto_key_apply_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, &auto_key_in_infos) - | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, &auto_key_apply_infos) - | GGLWEAutomorphismKey::automorphism_scratch_space( + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key_in_infos) + | GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key_apply_infos) + | GLWEAutomorphismKey::automorphism_tmp_bytes( module, &auto_key_out_infos, &auto_key_in_infos, @@ -130,7 +92,7 @@ where ), ); - let mut sk: GLWESecret> = GLWESecret::alloc(&auto_key_in); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&auto_key_in); sk.fill_ternary_prob(0.5, &mut source_xs); // gglwe_{s1}(s0) = s0 -> s1 @@ -153,8 +115,8 @@ where scratch.borrow(), ); - let mut auto_key_apply_prepared: GGLWEAutomorphismKeyPrepared, B> = - GGLWEAutomorphismKeyPrepared::alloc(module, &auto_key_apply_infos); + let mut auto_key_apply_prepared: GLWEAutomorphismKeyPrepared, BE> = + GLWEAutomorphismKeyPrepared::alloc_from_infos(module, &auto_key_apply_infos); auto_key_apply_prepared.prepare(module, &auto_key_apply, scratch.borrow()); @@ -166,9 +128,9 @@ where scratch.borrow(), ); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&auto_key_out_infos); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&auto_key_out_infos); - let mut sk_auto: GLWESecret> = GLWESecret::alloc(&auto_key_out_infos); + let mut sk_auto: GLWESecret> = GLWESecret::alloc_from_infos(&auto_key_out_infos); sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk for i in 0..rank { module.vec_znx_automorphism( @@ -180,7 +142,8 @@ where ); } - let sk_auto_dft: GLWESecretPrepared, B> = sk_auto.prepare_alloc(module, scratch.borrow()); + let mut sk_auto_dft: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc_from_infos(module, &sk_auto); + sk_auto_dft.prepare(module, &sk_auto); (0..auto_key_out.rank_in().into()).for_each(|col_i| { (0..auto_key_out.dnum().into()).for_each(|row_i| { @@ -222,61 +185,18 @@ where } #[allow(clippy::too_many_arguments)] -pub fn test_gglwe_automorphism_key_automorphism_inplace(module: &Module) +pub fn test_gglwe_automorphism_key_automorphism_inplace(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize + Module: GLWEAutomorphismKeyEncryptSk + + GLWEAutomorphismKeyPreparedFactory + + GLWEAutomorphismKeyAutomorphism + VecZnxAutomorphism - + VecZnxSwitchRing - + VecZnxAddScalarInplace - + VecZnxAutomorphism - + VecZnxAutomorphismInplace - + VecZnxDftAllocBytes - + VecZnxBigAllocBytes - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalizeTmpBytes + + GaloisElement + VecZnxSubScalarInplace - + VecZnxCopy - + VmpPMatAlloc - + VmpPrepare, - B: Backend - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxDftImpl - + TakeVecZnxImpl - + TakeSvpPPolImpl - + TakeVecZnxBigImpl, + + GLWESecretPreparedFactory + + GLWEDecrypt, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { let base2k: usize = 12; let k_in: usize = 60; @@ -293,7 +213,7 @@ where let dnum_in: usize = k_in / (base2k * di); let dnum_apply: usize = k_in.div_ceil(base2k * di); - let auto_key_layout: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + let auto_key_layout: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), base2k: base2k.into(), k: k_in.into(), @@ -302,7 +222,7 @@ where rank: rank.into(), }; - let auto_key_apply_layout: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + let auto_key_apply_layout: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), base2k: base2k.into(), k: k_apply.into(), @@ -311,20 +231,20 @@ where rank: rank.into(), }; - let mut auto_key: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&auto_key_layout); - let mut auto_key_apply: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&auto_key_apply_layout); + let mut auto_key: GLWEAutomorphismKey> = GLWEAutomorphismKey::alloc_from_infos(&auto_key_layout); + let mut auto_key_apply: GLWEAutomorphismKey> = GLWEAutomorphismKey::alloc_from_infos(&auto_key_apply_layout); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, &auto_key) - | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, &auto_key_apply) - | GGLWEAutomorphismKey::automorphism_inplace_scratch_space(module, &auto_key, &auto_key_apply), + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key) + | GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key_apply) + | GLWEAutomorphismKey::automorphism_tmp_bytes(module, &auto_key, &auto_key, &auto_key_apply), ); - let mut sk: GLWESecret> = GLWESecret::alloc(&auto_key); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&auto_key); sk.fill_ternary_prob(0.5, &mut source_xs); // gglwe_{s1}(s0) = s0 -> s1 @@ -347,17 +267,17 @@ where scratch.borrow(), ); - let mut auto_key_apply_prepared: GGLWEAutomorphismKeyPrepared, B> = - GGLWEAutomorphismKeyPrepared::alloc(module, &auto_key_apply_layout); + let mut auto_key_apply_prepared: GLWEAutomorphismKeyPrepared, BE> = + GLWEAutomorphismKeyPrepared::alloc_from_infos(module, &auto_key_apply_layout); auto_key_apply_prepared.prepare(module, &auto_key_apply, scratch.borrow()); // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) auto_key.automorphism_inplace(module, &auto_key_apply_prepared, scratch.borrow()); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&auto_key); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&auto_key); - let mut sk_auto: GLWESecret> = GLWESecret::alloc(&auto_key); + let mut sk_auto: GLWESecret> = GLWESecret::alloc_from_infos(&auto_key); sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk for i in 0..rank { @@ -370,7 +290,8 @@ where ); } - let sk_auto_dft: GLWESecretPrepared, B> = sk_auto.prepare_alloc(module, scratch.borrow()); + let mut sk_auto_dft: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc_from_infos(module, &sk_auto); + sk_auto_dft.prepare(module, &sk_auto); (0..auto_key.rank_in().into()).for_each(|col_i| { (0..auto_key.dnum().into()).for_each(|row_i| { diff --git a/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs index 2fd7151..c3aa3c3 100644 --- a/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs +++ b/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs @@ -1,79 +1,33 @@ use poulpy_hal::{ - api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, - SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, - VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, - VecZnxDftCopy, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, - }, - layouts::{Backend, Module, ScalarZnx, ScratchOwned}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl, - }, + api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAutomorphismInplace}, + layouts::{Backend, Module, ScalarZnx, Scratch, ScratchOwned}, source::Source, }; use crate::{ + GGSWAutomorphism, GGSWEncryptSk, GGSWNoise, GLWEAutomorphismKeyEncryptSk, GLWETensorKeyEncryptSk, ScratchTakeCore, encryption::SIGMA, layouts::{ - GGLWEAutomorphismKey, GGLWETensorKey, GGLWETensorKeyLayout, GGSWCiphertext, GGSWCiphertextLayout, GLWESecret, - prepared::{GGLWEAutomorphismKeyPrepared, GGLWETensorKeyPrepared, GLWESecretPrepared, Prepare, PrepareAlloc}, + GGSW, GGSWLayout, GLWEAutomorphismKey, GLWEAutomorphismKeyPreparedFactory, GLWESecret, GLWESecretPreparedFactory, + GLWETensorKey, GLWETensorKeyLayout, GLWETensorKeyPreparedFactory, + prepared::{GLWEAutomorphismKeyPrepared, GLWESecretPrepared, GLWETensorKeyPrepared}, }, noise::noise_ggsw_keyswitch, }; -pub fn test_ggsw_automorphism(module: &Module) +pub fn test_ggsw_automorphism(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigAllocBytes - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalizeTmpBytes - + VecZnxBigAlloc - + VecZnxDftAlloc - + VecZnxBigNormalizeTmpBytes - + VecZnxIdftApplyTmpA - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxAddScalarInplace - + VecZnxCopy - + VecZnxSubInplace - + VmpPMatAlloc - + VmpPrepare - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftCopy - + VecZnxDftAddInplace - + VecZnxFillUniform - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpApplyDftToDft - + VecZnxSwitchRing - + VecZnxAutomorphismInplace - + VecZnxAutomorphism, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl - + VecZnxDftAllocBytesImpl - + VecZnxBigAllocBytesImpl - + TakeSvpPPolImpl, + Module: GGSWEncryptSk + + GLWEAutomorphismKeyEncryptSk + + GLWEAutomorphismKeyPreparedFactory + + GGSWAutomorphism + + GLWETensorKeyPreparedFactory + + GLWETensorKeyEncryptSk + + GLWESecretPreparedFactory + + VecZnxAutomorphismInplace + + GGSWNoise, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { let base2k: usize = 12; let k_in: usize = 54; @@ -92,7 +46,7 @@ where let dsize_in: usize = 1; - let ggsw_in_layout: GGSWCiphertextLayout = GGSWCiphertextLayout { + let ggsw_in_layout: GGSWLayout = GGSWLayout { n: n.into(), base2k: base2k.into(), k: k_in.into(), @@ -101,7 +55,7 @@ where rank: rank.into(), }; - let ggsw_out_layout: GGSWCiphertextLayout = GGSWCiphertextLayout { + let ggsw_out_layout: GGSWLayout = GGSWLayout { n: n.into(), base2k: base2k.into(), k: k_out.into(), @@ -110,7 +64,7 @@ where rank: rank.into(), }; - let tensor_key_layout: GGLWETensorKeyLayout = GGLWETensorKeyLayout { + let tensor_key_layout: GLWETensorKeyLayout = GLWETensorKeyLayout { n: n.into(), base2k: base2k.into(), k: k_tsk.into(), @@ -119,7 +73,7 @@ where rank: rank.into(), }; - let auto_key_layout: GGLWETensorKeyLayout = GGLWETensorKeyLayout { + let auto_key_layout: GLWETensorKeyLayout = GLWETensorKeyLayout { n: n.into(), base2k: base2k.into(), k: k_ksk.into(), @@ -128,28 +82,30 @@ where rank: rank.into(), }; - let mut ct_in: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_in_layout); - let mut ct_out: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_out_layout); - let mut tensor_key: GGLWETensorKey> = GGLWETensorKey::alloc(&tensor_key_layout); - let mut auto_key: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&auto_key_layout); + let mut ct_in: GGSW> = GGSW::alloc_from_infos(&ggsw_in_layout); + let mut ct_out: GGSW> = GGSW::alloc_from_infos(&ggsw_out_layout); + let mut tensor_key: GLWETensorKey> = GLWETensorKey::alloc_from_infos(&tensor_key_layout); + let mut auto_key: GLWEAutomorphismKey> = GLWEAutomorphismKey::alloc_from_infos(&auto_key_layout); let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, &ct_in) - | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, &auto_key) - | GGLWETensorKey::encrypt_sk_scratch_space(module, &tensor_key) - | GGSWCiphertext::automorphism_scratch_space(module, &ct_out, &ct_in, &auto_key, &tensor_key), + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GGSW::encrypt_sk_tmp_bytes(module, &ct_in) + | GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key) + | GLWETensorKey::encrypt_sk_tmp_bytes(module, &tensor_key) + | GGSW::automorphism_tmp_bytes(module, &ct_out, &ct_in, &auto_key, &tensor_key), ); let var_xs: f64 = 0.5; - let mut sk: GLWESecret> = GLWESecret::alloc(&ct_out); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&ct_out); sk.fill_ternary_prob(var_xs, &mut source_xs); - let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); + + let mut sk_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc_from_infos(module, &sk); + sk_prepared.prepare(module, &sk); auto_key.encrypt_sk( module, @@ -178,11 +134,12 @@ where scratch.borrow(), ); - let mut auto_key_prepared: GGLWEAutomorphismKeyPrepared, B> = - GGLWEAutomorphismKeyPrepared::alloc(module, &auto_key_layout); + let mut auto_key_prepared: GLWEAutomorphismKeyPrepared, BE> = + GLWEAutomorphismKeyPrepared::alloc_from_infos(module, &auto_key_layout); auto_key_prepared.prepare(module, &auto_key, scratch.borrow()); - let mut tsk_prepared: GGLWETensorKeyPrepared, B> = GGLWETensorKeyPrepared::alloc(module, &tensor_key_layout); + let mut tsk_prepared: GLWETensorKeyPrepared, BE> = + GLWETensorKeyPrepared::alloc_from_infos(module, &tensor_key_layout); tsk_prepared.prepare(module, &tensor_key, scratch.borrow()); ct_out.automorphism( @@ -217,56 +174,19 @@ where } #[allow(clippy::too_many_arguments)] -pub fn test_ggsw_automorphism_inplace(module: &Module) +pub fn test_ggsw_automorphism_inplace(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigAllocBytes - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalizeTmpBytes - + VecZnxBigAlloc - + VecZnxDftAlloc - + VecZnxBigNormalizeTmpBytes - + VecZnxIdftApplyTmpA - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxAddScalarInplace - + VecZnxCopy - + VecZnxSubInplace - + VmpPMatAlloc - + VmpPrepare - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxBigAddSmallInplace - + VecZnxDftCopy - + VecZnxDftAddInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + VecZnxFillUniform - + SvpApplyDftToDft - + VecZnxSwitchRing - + VecZnxAutomorphismInplace - + VecZnxAutomorphism, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl - + VecZnxDftAllocBytesImpl - + VecZnxBigAllocBytesImpl - + TakeSvpPPolImpl, + Module: GGSWEncryptSk + + GLWEAutomorphismKeyEncryptSk + + GLWEAutomorphismKeyPreparedFactory + + GGSWAutomorphism + + GLWETensorKeyPreparedFactory + + GLWETensorKeyEncryptSk + + GLWESecretPreparedFactory + + VecZnxAutomorphismInplace + + GGSWNoise, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { let base2k: usize = 12; let k_out: usize = 54; @@ -282,7 +202,7 @@ where let dnum_in: usize = k_out.div_euclid(base2k * di); let dsize_in: usize = 1; - let ggsw_out_layout: GGSWCiphertextLayout = GGSWCiphertextLayout { + let ggsw_out_layout: GGSWLayout = GGSWLayout { n: n.into(), base2k: base2k.into(), k: k_out.into(), @@ -291,7 +211,7 @@ where rank: rank.into(), }; - let tensor_key_layout: GGLWETensorKeyLayout = GGLWETensorKeyLayout { + let tensor_key_layout: GLWETensorKeyLayout = GLWETensorKeyLayout { n: n.into(), base2k: base2k.into(), k: k_tsk.into(), @@ -300,7 +220,7 @@ where rank: rank.into(), }; - let auto_key_layout: GGLWETensorKeyLayout = GGLWETensorKeyLayout { + let auto_key_layout: GLWETensorKeyLayout = GLWETensorKeyLayout { n: n.into(), base2k: base2k.into(), k: k_ksk.into(), @@ -309,27 +229,29 @@ where rank: rank.into(), }; - let mut ct: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_out_layout); - let mut tensor_key: GGLWETensorKey> = GGLWETensorKey::alloc(&tensor_key_layout); - let mut auto_key: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&auto_key_layout); + let mut ct: GGSW> = GGSW::alloc_from_infos(&ggsw_out_layout); + let mut tensor_key: GLWETensorKey> = GLWETensorKey::alloc_from_infos(&tensor_key_layout); + let mut auto_key: GLWEAutomorphismKey> = GLWEAutomorphismKey::alloc_from_infos(&auto_key_layout); let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, &ct) - | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, &auto_key) - | GGLWETensorKey::encrypt_sk_scratch_space(module, &tensor_key) - | GGSWCiphertext::automorphism_inplace_scratch_space(module, &ct, &auto_key, &tensor_key), + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GGSW::encrypt_sk_tmp_bytes(module, &ct) + | GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key) + | GLWETensorKey::encrypt_sk_tmp_bytes(module, &tensor_key) + | GGSW::automorphism_tmp_bytes(module, &ct, &ct, &auto_key, &tensor_key), ); let var_xs: f64 = 0.5; - let mut sk: GLWESecret> = GLWESecret::alloc(&ct); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&ct); sk.fill_ternary_prob(var_xs, &mut source_xs); - let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); + + let mut sk_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc_from_infos(module, &sk); + sk_prepared.prepare(module, &sk); auto_key.encrypt_sk( module, @@ -358,11 +280,12 @@ where scratch.borrow(), ); - let mut auto_key_prepared: GGLWEAutomorphismKeyPrepared, B> = - GGLWEAutomorphismKeyPrepared::alloc(module, &auto_key_layout); + let mut auto_key_prepared: GLWEAutomorphismKeyPrepared, BE> = + GLWEAutomorphismKeyPrepared::alloc_from_infos(module, &auto_key_layout); auto_key_prepared.prepare(module, &auto_key, scratch.borrow()); - let mut tsk_prepared: GGLWETensorKeyPrepared, B> = GGLWETensorKeyPrepared::alloc(module, &tensor_key_layout); + let mut tsk_prepared: GLWETensorKeyPrepared, BE> = + GLWETensorKeyPrepared::alloc_from_infos(module, &tensor_key_layout); tsk_prepared.prepare(module, &tensor_key, scratch.borrow()); ct.automorphism_inplace(module, &auto_key_prepared, &tsk_prepared, scratch.borrow()); diff --git a/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs b/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs index 5828c48..58f737a 100644 --- a/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs @@ -1,69 +1,33 @@ use poulpy_hal::{ - api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, - VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, - }, - layouts::{Backend, Module, ScratchOwned}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, - }, + api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAutomorphismInplace, VecZnxFillUniform}, + layouts::{Backend, Module, Scratch, ScratchOwned}, source::Source, }; use crate::{ + GLWEAutomorphism, GLWEAutomorphismKeyEncryptSk, GLWEDecrypt, GLWEEncryptSk, GLWENoise, ScratchTakeCore, encryption::SIGMA, layouts::{ - GGLWEAutomorphismKey, GGLWEAutomorphismKeyLayout, GLWECiphertext, GLWECiphertextLayout, GLWEPlaintext, GLWESecret, - prepared::{GGLWEAutomorphismKeyPrepared, GLWESecretPrepared, Prepare, PrepareAlloc}, + GLWE, GLWEAutomorphismKey, GLWEAutomorphismKeyLayout, GLWEAutomorphismKeyPreparedFactory, GLWELayout, GLWEPlaintext, + GLWESecret, GLWESecretPreparedFactory, + prepared::{GLWEAutomorphismKeyPrepared, GLWESecretPrepared}, }, noise::log2_std_noise_gglwe_product, }; -pub fn test_glwe_automorphism(module: &Module) +pub fn test_glwe_automorphism(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes + Module: GLWEEncryptSk + + GLWESecretPreparedFactory + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxAutomorphism - + VecZnxSwitchRing - + VecZnxAddScalarInplace - + VecZnxAutomorphismInplace - + VmpPMatAlloc - + VmpPrepare, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, + + GLWEDecrypt + + GLWEAutomorphism + + GLWEAutomorphismKeyEncryptSk + + GLWEAutomorphismKeyPreparedFactory + + GLWENoise + + VecZnxAutomorphismInplace, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { let base2k: usize = 12; let k_in: usize = 60; @@ -77,21 +41,21 @@ where let n: usize = module.n(); let dnum: usize = k_in.div_ceil(base2k * dsize); - let ct_in_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let ct_in_infos: GLWELayout = GLWELayout { n: n.into(), base2k: base2k.into(), k: k_in.into(), rank: rank.into(), }; - let ct_out_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let ct_out_infos: GLWELayout = GLWELayout { n: n.into(), base2k: base2k.into(), k: k_out.into(), rank: rank.into(), }; - let autokey_infos: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + let autokey_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), base2k: base2k.into(), k: k_out.into(), @@ -100,10 +64,10 @@ where dsize: di.into(), }; - let mut autokey: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&autokey_infos); - let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&ct_in_infos); - let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&ct_out_infos); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&ct_out_infos); + let mut autokey: GLWEAutomorphismKey> = GLWEAutomorphismKey::alloc_from_infos(&autokey_infos); + let mut ct_in: GLWE> = GLWE::alloc_from_infos(&ct_in_infos); + let mut ct_out: GLWE> = GLWE::alloc_from_infos(&ct_out_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&ct_out_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -111,16 +75,18 @@ where module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, &autokey) - | GLWECiphertext::decrypt_scratch_space(module, &ct_out) - | GLWECiphertext::encrypt_sk_scratch_space(module, &ct_in) - | GLWECiphertext::automorphism_scratch_space(module, &ct_out, &ct_in, &autokey), + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &autokey) + | GLWE::decrypt_tmp_bytes(module, &ct_out) + | GLWE::encrypt_sk_tmp_bytes(module, &ct_in) + | GLWE::automorphism_tmp_bytes(module, &ct_out, &ct_in, &autokey), ); - let mut sk: GLWESecret> = GLWESecret::alloc(&ct_out); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&ct_out); sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); + + let mut sk_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc_from_infos(module, &sk); + sk_prepared.prepare(module, &sk); autokey.encrypt_sk( module, @@ -140,8 +106,8 @@ where scratch.borrow(), ); - let mut autokey_prepared: GGLWEAutomorphismKeyPrepared, B> = - GGLWEAutomorphismKeyPrepared::alloc(module, &autokey_infos); + let mut autokey_prepared: GLWEAutomorphismKeyPrepared, BE> = + GLWEAutomorphismKeyPrepared::alloc_from_infos(module, &autokey_infos); autokey_prepared.prepare(module, &autokey, scratch.borrow()); ct_out.automorphism(module, &ct_in, &autokey_prepared, scratch.borrow()); @@ -167,46 +133,19 @@ where } #[allow(clippy::too_many_arguments)] -pub fn test_glwe_automorphism_inplace(module: &Module) +pub fn test_glwe_automorphism_inplace(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes + Module: GLWEEncryptSk + + GLWESecretPreparedFactory + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxAutomorphism - + VecZnxSwitchRing - + VecZnxAddScalarInplace - + VecZnxAutomorphismInplace - + VmpPMatAlloc - + VmpPrepare, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, + + GLWEDecrypt + + GLWEAutomorphism + + GLWEAutomorphismKeyEncryptSk + + GLWEAutomorphismKeyPreparedFactory + + GLWENoise + + VecZnxAutomorphismInplace, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { let base2k: usize = 12; let k_out: usize = 60; @@ -219,14 +158,14 @@ where let n: usize = module.n(); let dnum: usize = k_out.div_ceil(base2k * dsize); - let ct_out_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let ct_out_infos: GLWELayout = GLWELayout { n: n.into(), base2k: base2k.into(), k: k_out.into(), rank: rank.into(), }; - let autokey_infos: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + let autokey_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), base2k: base2k.into(), k: k_ksk.into(), @@ -235,9 +174,9 @@ where dsize: di.into(), }; - let mut autokey: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&autokey_infos); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&ct_out_infos); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&ct_out_infos); + let mut autokey: GLWEAutomorphismKey> = GLWEAutomorphismKey::alloc_from_infos(&autokey_infos); + let mut ct: GLWE> = GLWE::alloc_from_infos(&ct_out_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&ct_out_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -245,16 +184,18 @@ where module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, &autokey) - | GLWECiphertext::decrypt_scratch_space(module, &ct) - | GLWECiphertext::encrypt_sk_scratch_space(module, &ct) - | GLWECiphertext::automorphism_inplace_scratch_space(module, &ct, &autokey), + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &autokey) + | GLWE::decrypt_tmp_bytes(module, &ct) + | GLWE::encrypt_sk_tmp_bytes(module, &ct) + | GLWE::automorphism_tmp_bytes(module, &ct, &ct, &autokey), ); - let mut sk: GLWESecret> = GLWESecret::alloc(&ct); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&ct); sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); + + let mut sk_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc_from_infos(module, &sk); + sk_prepared.prepare(module, &sk); autokey.encrypt_sk( module, @@ -274,8 +215,8 @@ where scratch.borrow(), ); - let mut autokey_prepared: GGLWEAutomorphismKeyPrepared, B> = - GGLWEAutomorphismKeyPrepared::alloc(module, &autokey); + let mut autokey_prepared: GLWEAutomorphismKeyPrepared, BE> = + GLWEAutomorphismKeyPrepared::alloc_from_infos(module, &autokey); autokey_prepared.prepare(module, &autokey, scratch.borrow()); ct.automorphism_inplace(module, &autokey_prepared, scratch.borrow()); diff --git a/poulpy-core/src/tests/test_suite/conversion.rs b/poulpy-core/src/tests/test_suite/conversion.rs index c2c81dc..c6e7d00 100644 --- a/poulpy-core/src/tests/test_suite/conversion.rs +++ b/poulpy-core/src/tests/test_suite/conversion.rs @@ -1,69 +1,30 @@ use poulpy_hal::{ - api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigAddInplace, - VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, - VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, ZnAddNormal, ZnFillUniform, ZnNormalizeInplace, - }, - layouts::{Backend, Module, ScratchOwned, ZnxView}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, - }, + api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow}, + layouts::{Backend, Module, Scratch, ScratchOwned, ZnxView}, source::Source, }; -use crate::layouts::{ - Base2K, Degree, Dnum, GLWECiphertext, GLWECiphertextLayout, GLWEPlaintext, GLWESecret, GLWEToLWEKey, GLWEToLWEKeyLayout, - LWECiphertext, LWECiphertextLayout, LWEPlaintext, LWESecret, LWEToGLWESwitchingKey, LWEToGLWESwitchingKeyLayout, Rank, - TorusPrecision, - prepared::{GLWESecretPrepared, GLWEToLWESwitchingKeyPrepared, LWEToGLWESwitchingKeyPrepared, PrepareAlloc}, +use crate::{ + GLWEDecrypt, GLWEEncryptSk, GLWEFromLWE, GLWEToLWESwitchingKeyEncryptSk, LWEDecrypt, LWEEncryptSk, + LWEToGLWESwitchingKeyEncryptSk, ScratchTakeCore, + layouts::{ + Base2K, Degree, Dnum, GLWE, GLWELayout, GLWEPlaintext, GLWESecret, GLWESecretPreparedFactory, GLWEToLWEKeyLayout, + GLWEToLWESwitchingKey, GLWEToLWESwitchingKeyPreparedFactory, LWE, LWELayout, LWEPlaintext, LWESecret, + LWEToGLWESwitchingKey, LWEToGLWESwitchingKeyLayout, LWEToGLWESwitchingKeyPreparedFactory, Rank, TorusPrecision, + prepared::{GLWESecretPrepared, GLWEToLWESwitchingKeyPrepared, LWEToGLWESwitchingKeyPrepared}, + }, }; -pub fn test_lwe_to_glwe(module: &Module) +pub fn test_lwe_to_glwe(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxNormalizeTmpBytes - + VecZnxAddScalarInplace - + VmpPMatAlloc - + VmpPrepare - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxBigNormalizeTmpBytes - + VecZnxSwitchRing - + VecZnxAutomorphismInplace - + ZnNormalizeInplace - + ZnFillUniform - + ZnAddNormal, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, + Module: GLWEFromLWE + + LWEToGLWESwitchingKeyEncryptSk + + GLWEDecrypt + + GLWESecretPreparedFactory + + LWEEncryptSk + + LWEToGLWESwitchingKeyPreparedFactory, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { let n_glwe: Degree = Degree(module.n() as u32); let n_lwe: Degree = Degree(22); @@ -83,104 +44,79 @@ where rank_out: rank, }; - let glwe_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_infos: GLWELayout = GLWELayout { n: n_glwe, base2k: Base2K(17), k: TorusPrecision(34), rank, }; - let lwe_infos: LWECiphertextLayout = LWECiphertextLayout { + let lwe_infos: LWELayout = LWELayout { n: n_lwe, base2k: Base2K(17), k: TorusPrecision(34), }; - let mut scratch: ScratchOwned = ScratchOwned::alloc( - LWEToGLWESwitchingKey::encrypt_sk_scratch_space(module, &lwe_to_glwe_infos) - | GLWECiphertext::from_lwe_scratch_space(module, &glwe_infos, &lwe_infos, &lwe_to_glwe_infos) - | GLWECiphertext::decrypt_scratch_space(module, &glwe_infos), + let mut scratch: ScratchOwned = ScratchOwned::alloc( + LWEToGLWESwitchingKey::encrypt_sk_tmp_bytes(module, &lwe_to_glwe_infos) + | GLWE::from_lwe_tmp_bytes(module, &glwe_infos, &lwe_infos, &lwe_to_glwe_infos) + | GLWE::decrypt_tmp_bytes(module, &glwe_infos), ); - let mut sk_glwe: GLWESecret> = GLWESecret::alloc(&glwe_infos); + let mut sk_glwe: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_infos); sk_glwe.fill_ternary_prob(0.5, &mut source_xs); - let sk_glwe_prepared: GLWESecretPrepared, B> = sk_glwe.prepare_alloc(module, scratch.borrow()); + let mut sk_glwe_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc_from_infos(module, &sk_glwe); + sk_glwe_prepared.prepare(module, &sk_glwe); let mut sk_lwe: LWESecret> = LWESecret::alloc(n_lwe); sk_lwe.fill_ternary_prob(0.5, &mut source_xs); let data: i64 = 17; - let mut lwe_pt: LWEPlaintext> = LWEPlaintext::alloc(&lwe_infos); + let mut lwe_pt: LWEPlaintext> = LWEPlaintext::alloc_from_infos(&lwe_infos); lwe_pt.encode_i64(data, k_lwe_pt); - let mut lwe_ct: LWECiphertext> = LWECiphertext::alloc(&lwe_infos); + let mut lwe_ct: LWE> = LWE::alloc_from_infos(&lwe_infos); lwe_ct.encrypt_sk(module, &lwe_pt, &sk_lwe, &mut source_xa, &mut source_xe); - let mut ksk: LWEToGLWESwitchingKey> = LWEToGLWESwitchingKey::alloc(&lwe_to_glwe_infos); + let mut ksk: LWEToGLWESwitchingKey> = LWEToGLWESwitchingKey::alloc_from_infos(&lwe_to_glwe_infos); ksk.encrypt_sk( module, &sk_lwe, - &sk_glwe, + &sk_glwe_prepared, &mut source_xa, &mut source_xe, scratch.borrow(), ); - let mut glwe_ct: GLWECiphertext> = GLWECiphertext::alloc(&glwe_infos); + let mut glwe_ct: GLWE> = GLWE::alloc_from_infos(&glwe_infos); - let ksk_prepared: LWEToGLWESwitchingKeyPrepared, B> = ksk.prepare_alloc(module, scratch.borrow()); + let mut ksk_prepared: LWEToGLWESwitchingKeyPrepared, BE> = + LWEToGLWESwitchingKeyPrepared::alloc_from_infos(module, &ksk); + ksk_prepared.prepare(module, &ksk, scratch.borrow()); glwe_ct.from_lwe(module, &lwe_ct, &ksk_prepared, scratch.borrow()); - let mut glwe_pt: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_infos); + let mut glwe_pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_infos); glwe_ct.decrypt(module, &mut glwe_pt, &sk_glwe_prepared, scratch.borrow()); assert_eq!(glwe_pt.data.at(0, 0)[0], lwe_pt.data.at(0, 0)[0]); } -pub fn test_glwe_to_lwe(module: &Module) +pub fn test_glwe_to_lwe(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxNormalizeTmpBytes - + VecZnxAddScalarInplace - + VmpPMatAlloc - + VmpPrepare - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxBigNormalizeTmpBytes - + VecZnxSwitchRing - + VecZnxAutomorphismInplace - + ZnNormalizeInplace, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, + Module: GLWEFromLWE + + GLWEToLWESwitchingKeyEncryptSk + + GLWEEncryptSk + + LWEDecrypt + + GLWEDecrypt + + GLWESecretPreparedFactory + + GLWEToLWESwitchingKeyEncryptSk + + GLWEToLWESwitchingKeyPreparedFactory, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { let n_glwe: Degree = Degree(module.n() as u32); let n_lwe: Degree = Degree(22); @@ -196,14 +132,14 @@ where rank_in: rank, }; - let glwe_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_infos: GLWELayout = GLWELayout { n: n_glwe, base2k: Base2K(17), k: TorusPrecision(34), rank, }; - let lwe_infos: LWECiphertextLayout = LWECiphertextLayout { + let lwe_infos: LWELayout = LWELayout { n: n_lwe, base2k: Base2K(17), k: TorusPrecision(34), @@ -213,25 +149,26 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWEToLWEKey::encrypt_sk_scratch_space(module, &glwe_to_lwe_infos) - | LWECiphertext::from_glwe_scratch_space(module, &lwe_infos, &glwe_infos, &glwe_to_lwe_infos) - | GLWECiphertext::decrypt_scratch_space(module, &glwe_infos), + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GLWEToLWESwitchingKey::encrypt_sk_tmp_bytes(module, &glwe_to_lwe_infos) + | LWE::from_glwe_tmp_bytes(module, &lwe_infos, &glwe_infos, &glwe_to_lwe_infos) + | GLWE::decrypt_tmp_bytes(module, &glwe_infos), ); - let mut sk_glwe: GLWESecret> = GLWESecret::alloc(&glwe_infos); + let mut sk_glwe: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_infos); sk_glwe.fill_ternary_prob(0.5, &mut source_xs); - let sk_glwe_prepared: GLWESecretPrepared, B> = sk_glwe.prepare_alloc(module, scratch.borrow()); + let mut sk_glwe_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc_from_infos(module, &sk_glwe); + sk_glwe_prepared.prepare(module, &sk_glwe); let mut sk_lwe: LWESecret> = LWESecret::alloc(n_lwe); sk_lwe.fill_ternary_prob(0.5, &mut source_xs); let data: i64 = 17; - let mut glwe_pt: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_infos); + let mut glwe_pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_infos); glwe_pt.encode_coeff_i64(data, k_lwe_pt, 0); - let mut glwe_ct = GLWECiphertext::alloc(&glwe_infos); + let mut glwe_ct: GLWE> = GLWE::alloc_from_infos(&glwe_infos); glwe_ct.encrypt_sk( module, &glwe_pt, @@ -241,7 +178,7 @@ where scratch.borrow(), ); - let mut ksk: GLWEToLWEKey> = GLWEToLWEKey::alloc(&glwe_to_lwe_infos); + let mut ksk: GLWEToLWESwitchingKey> = GLWEToLWESwitchingKey::alloc_from_infos(&glwe_to_lwe_infos); ksk.encrypt_sk( module, @@ -252,13 +189,15 @@ where scratch.borrow(), ); - let mut lwe_ct: LWECiphertext> = LWECiphertext::alloc(&lwe_infos); + let mut lwe_ct: LWE> = LWE::alloc_from_infos(&lwe_infos); - let ksk_prepared: GLWEToLWESwitchingKeyPrepared, B> = ksk.prepare_alloc(module, scratch.borrow()); + let mut ksk_prepared: GLWEToLWESwitchingKeyPrepared, BE> = + GLWEToLWESwitchingKeyPrepared::alloc_from_infos(module, &ksk); + ksk_prepared.prepare(module, &ksk, scratch.borrow()); lwe_ct.from_glwe(module, &glwe_ct, &ksk_prepared, scratch.borrow()); - let mut lwe_pt: LWEPlaintext> = LWEPlaintext::alloc(&lwe_infos); + let mut lwe_pt: LWEPlaintext> = LWEPlaintext::alloc_from_infos(&lwe_infos); lwe_ct.decrypt(module, &mut lwe_pt, &sk_lwe); assert_eq!(glwe_pt.data.at(0, 0)[0], lwe_pt.data.at(0, 0)[0]); diff --git a/poulpy-core/src/tests/test_suite/encryption/gglwe_atk.rs b/poulpy-core/src/tests/test_suite/encryption/gglwe_atk.rs index e9b7c93..9363539 100644 --- a/poulpy-core/src/tests/test_suite/encryption/gglwe_atk.rs +++ b/poulpy-core/src/tests/test_suite/encryption/gglwe_atk.rs @@ -1,71 +1,34 @@ use poulpy_hal::{ - api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, - VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, - VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSubScalarInplace, VecZnxSwitchRing, - VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, - }, - layouts::{Backend, Module, ScratchOwned}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, - }, + api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAutomorphism, VecZnxFillUniform}, + layouts::{Backend, GaloisElement, Module, Scratch, ScratchOwned}, source::Source, }; use crate::{ + GGLWEKeyswitch, GLWEAutomorphismKeyCompressedEncryptSk, GLWEAutomorphismKeyEncryptSk, GLWESwitchingKeyCompressedEncryptSk, + GLWESwitchingKeyEncryptSk, ScratchTakeCore, encryption::SIGMA, layouts::{ - GGLWEAutomorphismKey, GGLWEAutomorphismKeyLayout, GLWEInfos, GLWESecret, - compressed::{Decompress, GGLWEAutomorphismKeyCompressed}, - prepared::{GLWESecretPrepared, PrepareAlloc}, + AutomorphismKeyDecompress, GLWEAutomorphismKey, GLWEAutomorphismKeyLayout, GLWEInfos, GLWESecret, + GLWESecretPreparedFactory, GLWESwitchingKeyDecompress, compressed::GLWEAutomorphismKeyCompressed, + prepared::GLWESecretPrepared, }, + noise::GGLWENoise, }; -pub fn test_gglwe_automorphisk_key_encrypt_sk(module: &Module) +pub fn test_gglwe_automorphism_key_encrypt_sk(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes + Module: GLWEAutomorphismKeyEncryptSk + + GGLWEKeyswitch + + GLWESecretPreparedFactory + + GLWESwitchingKeyEncryptSk + + GLWESwitchingKeyCompressedEncryptSk + + GLWESwitchingKeyDecompress + + GGLWENoise + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxBigAddSmallInplace - + VecZnxAutomorphism - + VecZnxSwitchRing - + VecZnxAddScalarInplace - + VecZnxAutomorphismInplace - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxSubScalarInplace - + VecZnxCopy - + VmpPMatAlloc - + VmpPrepare, - B: Backend - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxDftImpl - + TakeVecZnxImpl - + TakeSvpPPolImpl - + TakeVecZnxBigImpl, + + VecZnxAutomorphism, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { let base2k: usize = 12; let k_ksk: usize = 60; @@ -75,7 +38,7 @@ where let n: usize = module.n(); let dnum: usize = (k_ksk - di * base2k) / (di * base2k); - let atk_infos: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + let atk_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), base2k: base2k.into(), k: k_ksk.into(), @@ -84,17 +47,17 @@ where rank: rank.into(), }; - let mut atk: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&atk_infos); + let mut atk: GLWEAutomorphismKey> = GLWEAutomorphismKey::alloc_from_infos(&atk_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWEAutomorphismKey::encrypt_sk_scratch_space( + let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWEAutomorphismKey::encrypt_sk_tmp_bytes( module, &atk_infos, )); - let mut sk: GLWESecret> = GLWESecret::alloc(&atk_infos); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&atk_infos); sk.fill_ternary_prob(0.5, &mut source_xs); let p = -5; @@ -118,57 +81,28 @@ where i, ); }); - let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); + let mut sk_out_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, sk_out.rank()); + sk_out_prepared.prepare(module, &sk_out); atk.key - .key .assert_noise(module, &sk_out_prepared, &sk.data, SIGMA); } } } -pub fn test_gglwe_automorphisk_key_compressed_encrypt_sk(module: &Module) +pub fn test_gglwe_automorphism_key_compressed_encrypt_sk(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxBigAddSmallInplace + Module: GLWEAutomorphismKeyCompressedEncryptSk + + GGLWEKeyswitch + + GLWESecretPreparedFactory + + GLWESwitchingKeyEncryptSk + + GLWESwitchingKeyCompressedEncryptSk + + AutomorphismKeyDecompress + VecZnxAutomorphism - + VecZnxSwitchRing - + VecZnxAddScalarInplace - + VecZnxAutomorphismInplace - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxSubScalarInplace - + VecZnxCopy - + VmpPMatAlloc - + VmpPrepare, - B: Backend - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxDftImpl - + TakeVecZnxImpl - + TakeSvpPPolImpl - + TakeVecZnxBigImpl, + + VecZnxFillUniform + + GGLWENoise, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { let base2k: usize = 12; let k_ksk: usize = 60; @@ -178,7 +112,7 @@ where let n: usize = module.n(); let dnum: usize = (k_ksk - di * base2k) / (di * base2k); - let atk_infos: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + let atk_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), base2k: base2k.into(), k: k_ksk.into(), @@ -187,16 +121,17 @@ where rank: rank.into(), }; - let mut atk_compressed: GGLWEAutomorphismKeyCompressed> = GGLWEAutomorphismKeyCompressed::alloc(&atk_infos); + let mut atk_compressed: GLWEAutomorphismKeyCompressed> = + GLWEAutomorphismKeyCompressed::alloc_from_infos(&atk_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWEAutomorphismKey::encrypt_sk_scratch_space( + let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWEAutomorphismKeyCompressed::encrypt_sk_tmp_bytes( module, &atk_infos, )); - let mut sk: GLWESecret> = GLWESecret::alloc(&atk_infos); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&atk_infos); sk.fill_ternary_prob(0.5, &mut source_xs); let p = -5; @@ -215,13 +150,13 @@ where i, ); }); - let sk_out_prepared = sk_out.prepare_alloc(module, scratch.borrow()); + let mut sk_out_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, sk_out.rank()); + sk_out_prepared.prepare(module, &sk_out); - let mut atk: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&atk_infos); + let mut atk: GLWEAutomorphismKey> = GLWEAutomorphismKey::alloc_from_infos(&atk_infos); atk.decompress(module, &atk_compressed); atk.key - .key .assert_noise(module, &sk_out_prepared, &sk.data, SIGMA); } } diff --git a/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs b/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs index 60bb7e2..2b64f02 100644 --- a/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs @@ -1,77 +1,44 @@ use poulpy_hal::{ - api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, - VecZnxSubScalarInplace, VecZnxSwitchRing, VmpPMatAlloc, VmpPrepare, - }, - layouts::{Backend, Module, ScratchOwned}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl, - }, + api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxFillUniform}, + layouts::{Backend, Module, Scratch, ScratchOwned}, source::Source, }; use crate::{ + GGLWEEncryptSk, GGLWEKeyswitch, GLWESwitchingKeyCompressedEncryptSk, GLWESwitchingKeyEncryptSk, ScratchTakeCore, + decryption::GLWEDecrypt, encryption::SIGMA, layouts::{ - GGLWECiphertextLayout, GGLWESwitchingKey, GLWESecret, - compressed::{Decompress, GGLWESwitchingKeyCompressed}, - prepared::{GLWESecretPrepared, PrepareAlloc}, + GGLWELayout, GLWESecret, GLWESecretPreparedFactory, GLWESwitchingKey, GLWESwitchingKeyCompressed, + GLWESwitchingKeyDecompress, + prepared::{GGLWEPreparedFactory, GLWESecretPrepared}, }, + noise::GGLWENoise, }; -pub fn test_gglwe_switching_key_encrypt_sk(module: &Module) +pub fn test_gglwe_switching_key_encrypt_sk(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes + Module: GGLWEEncryptSk + + GGLWEPreparedFactory + + GGLWEKeyswitch + + GLWEDecrypt + + GLWESecretPreparedFactory + + GLWESwitchingKeyEncryptSk + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAddSmallInplace - + VecZnxSwitchRing - + VecZnxAddScalarInplace - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxSubScalarInplace - + VecZnxCopy - + VmpPMatAlloc - + VmpPrepare, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl - + VecZnxDftAllocBytesImpl - + VecZnxBigAllocBytesImpl - + TakeSvpPPolImpl, + + GGLWENoise, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { + let n: usize = module.n(); let base2k: usize = 12; let k_ksk: usize = 54; let dsize: usize = k_ksk / base2k; for rank_in in 1_usize..3 { for rank_out in 1_usize..3 { for di in 1_usize..dsize + 1 { - let n: usize = module.n(); let dnum: usize = (k_ksk - di * base2k) / (di * base2k); - let gglwe_infos: GGLWECiphertextLayout = GGLWECiphertextLayout { + let gglwe_infos: GGLWELayout = GGLWELayout { n: n.into(), base2k: base2k.into(), k: k_ksk.into(), @@ -81,23 +48,22 @@ where rank_out: rank_out.into(), }; - let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_infos); + let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&gglwe_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKey::encrypt_sk_scratch_space( - module, - &gglwe_infos, - )); + let mut scratch: ScratchOwned = + ScratchOwned::alloc(GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &gglwe_infos)); - let mut sk_in: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_in.into()); + let mut sk_in: GLWESecret> = GLWESecret::alloc(n.into(), rank_in.into()); sk_in.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_out: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_out.into()); + let mut sk_out: GLWESecret> = GLWESecret::alloc(n.into(), rank_out.into()); sk_out.fill_ternary_prob(0.5, &mut source_xs); - let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); + let mut sk_out_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank_out.into()); + sk_out_prepared.prepare(module, &sk_out); ksk.encrypt_sk( module, @@ -115,55 +81,31 @@ where } } -pub fn test_gglwe_switching_key_compressed_encrypt_sk(module: &Module) +pub fn test_gglwe_switching_key_compressed_encrypt_sk(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAddSmallInplace - + VecZnxSwitchRing - + VecZnxAddScalarInplace - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxSubScalarInplace - + VecZnxCopy - + VmpPMatAlloc - + VmpPrepare, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl - + VecZnxDftAllocBytesImpl - + VecZnxBigAllocBytesImpl - + TakeSvpPPolImpl, + Module: GGLWEEncryptSk + + GGLWEPreparedFactory + + GGLWEKeyswitch + + GLWEDecrypt + + GLWESecretPreparedFactory + + GLWESwitchingKeyEncryptSk + + GLWESwitchingKeyCompressedEncryptSk + + GLWESwitchingKeyDecompress + + GGLWENoise + + VecZnxFillUniform, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { + let n: usize = module.n(); let base2k: usize = 12; let k_ksk: usize = 54; let dsize: usize = k_ksk / base2k; for rank_in in 1_usize..3 { for rank_out in 1_usize..3 { for di in 1_usize..dsize + 1 { - let n: usize = module.n(); let dnum: usize = (k_ksk - di * base2k) / (di * base2k); - let gglwe_infos: GGLWECiphertextLayout = GGLWECiphertextLayout { + let gglwe_infos: GGLWELayout = GGLWELayout { n: n.into(), base2k: base2k.into(), k: k_ksk.into(), @@ -173,22 +115,24 @@ where rank_out: rank_out.into(), }; - let mut ksk_compressed: GGLWESwitchingKeyCompressed> = GGLWESwitchingKeyCompressed::alloc(&gglwe_infos); + let mut ksk_compressed: GLWESwitchingKeyCompressed> = + GLWESwitchingKeyCompressed::alloc_from_infos(&gglwe_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKeyCompressed::encrypt_sk_scratch_space( + let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWESwitchingKeyCompressed::encrypt_sk_tmp_bytes( module, &gglwe_infos, )); - let mut sk_in: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_in.into()); + let mut sk_in: GLWESecret> = GLWESecret::alloc(n.into(), rank_in.into()); sk_in.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_out: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_out.into()); + let mut sk_out: GLWESecret> = GLWESecret::alloc(n.into(), rank_out.into()); sk_out.fill_ternary_prob(0.5, &mut source_xs); - let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); + let mut sk_out_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank_out.into()); + sk_out_prepared.prepare(module, &sk_out); let seed_xa = [1u8; 32]; @@ -201,7 +145,7 @@ where scratch.borrow(), ); - let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_infos); + let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&gglwe_infos); ksk.decompress(module, &ksk_compressed); ksk.key diff --git a/poulpy-core/src/tests/test_suite/encryption/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/encryption/ggsw_ct.rs index 9eacf38..54b5df0 100644 --- a/poulpy-core/src/tests/test_suite/encryption/ggsw_ct.rs +++ b/poulpy-core/src/tests/test_suite/encryption/ggsw_ct.rs @@ -1,69 +1,23 @@ use poulpy_hal::{ - api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAlloc, VecZnxDftAllocBytes, - VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VmpPMatAlloc, VmpPrepare, - }, - layouts::{Backend, Module, ScalarZnx, ScratchOwned}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl, - }, + api::{ScratchOwnedAlloc, ScratchOwnedBorrow}, + layouts::{Backend, Module, ScalarZnx, Scratch, ScratchOwned}, source::Source, }; use crate::{ + GGSWCompressedEncryptSk, GGSWEncryptSk, GGSWNoise, ScratchTakeCore, encryption::SIGMA, layouts::{ - GGSWCiphertext, GGSWCiphertextLayout, GLWESecret, - compressed::{Decompress, GGSWCiphertextCompressed}, - prepared::{GLWESecretPrepared, PrepareAlloc}, + GGSW, GGSWDecompress, GGSWLayout, GLWESecret, GLWESecretPreparedFactory, compressed::GGSWCompressed, + prepared::GLWESecretPrepared, }, }; -pub fn test_ggsw_encrypt_sk(module: &Module) +pub fn test_ggsw_encrypt_sk(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAddSmallInplace - + VecZnxAddScalarInplace - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxCopy - + VmpPMatAlloc - + VmpPrepare - + VecZnxBigAlloc - + VecZnxDftAlloc - + VecZnxBigNormalizeTmpBytes - + VecZnxIdftApplyTmpA, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl - + VecZnxDftAllocBytesImpl - + VecZnxBigAllocBytesImpl - + TakeSvpPPolImpl, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchTakeCore, + Module: GGSWEncryptSk + GLWESecretPreparedFactory + GGSWNoise, { let base2k: usize = 12; let k: usize = 54; @@ -73,7 +27,7 @@ where let n: usize = module.n(); let dnum: usize = (k - di * base2k) / (di * base2k); - let ggsw_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + let ggsw_infos: GGSWLayout = GGSWLayout { n: n.into(), base2k: base2k.into(), k: k.into(), @@ -82,7 +36,7 @@ where rank: rank.into(), }; - let mut ct: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_infos); + let mut ct: GGSW> = GGSW::alloc_from_infos(&ggsw_infos); let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); @@ -92,14 +46,13 @@ where pt_scalar.fill_ternary_hw(0, n, &mut source_xs); - let mut scratch: ScratchOwned = ScratchOwned::alloc(GGSWCiphertext::encrypt_sk_scratch_space( - module, - &ggsw_infos, - )); + let mut scratch: ScratchOwned = ScratchOwned::alloc(GGSW::encrypt_sk_tmp_bytes(module, &ggsw_infos)); - let mut sk: GLWESecret> = GLWESecret::alloc(&ggsw_infos); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&ggsw_infos); sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); + + let mut sk_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank.into()); + sk_prepared.prepare(module, &sk); ct.encrypt_sk( module, @@ -117,46 +70,11 @@ where } } -pub fn test_ggsw_compressed_encrypt_sk(module: &Module) +pub fn test_ggsw_compressed_encrypt_sk(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAddSmallInplace - + VecZnxAddScalarInplace - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxCopy - + VmpPMatAlloc - + VmpPrepare - + VecZnxBigAlloc - + VecZnxDftAlloc - + VecZnxBigNormalizeTmpBytes - + VecZnxIdftApplyTmpA, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl - + VecZnxDftAllocBytesImpl - + VecZnxBigAllocBytesImpl - + TakeSvpPPolImpl, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchTakeCore, + Module: GGSWCompressedEncryptSk + GLWESecretPreparedFactory + GGSWNoise + GGSWDecompress, { let base2k: usize = 12; let k: usize = 54; @@ -166,7 +84,7 @@ where let n: usize = module.n(); let dnum: usize = (k - di * base2k) / (di * base2k); - let ggsw_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + let ggsw_infos: GGSWLayout = GGSWLayout { n: n.into(), base2k: base2k.into(), k: k.into(), @@ -175,7 +93,7 @@ where rank: rank.into(), }; - let mut ct_compressed: GGSWCiphertextCompressed> = GGSWCiphertextCompressed::alloc(&ggsw_infos); + let mut ct_compressed: GGSWCompressed> = GGSWCompressed::alloc_from_infos(&ggsw_infos); let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); @@ -184,14 +102,13 @@ where pt_scalar.fill_ternary_hw(0, n, &mut source_xs); - let mut scratch: ScratchOwned = ScratchOwned::alloc(GGSWCiphertextCompressed::encrypt_sk_scratch_space( - module, - &ggsw_infos, - )); + let mut scratch: ScratchOwned = ScratchOwned::alloc(GGSWCompressed::encrypt_sk_tmp_bytes(module, &ggsw_infos)); - let mut sk: GLWESecret> = GLWESecret::alloc(&ggsw_infos); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&ggsw_infos); sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); + + let mut sk_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank.into()); + sk_prepared.prepare(module, &sk); let seed_xa: [u8; 32] = [1u8; 32]; @@ -206,7 +123,7 @@ where let noise_f = |_col_i: usize| -(k as f64) + SIGMA.log2() + 0.5; - let mut ct: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_infos); + let mut ct: GGSW> = GGSW::alloc_from_infos(&ggsw_infos); ct.decompress(module, &ct_compressed); ct.assert_noise(module, &sk_prepared, &pt_scalar, noise_f); diff --git a/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs b/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs index a1169f6..514d28d 100644 --- a/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs @@ -1,61 +1,26 @@ use poulpy_hal::{ - api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, - SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxBigAddInplace, VecZnxBigAddNormal, VecZnxBigAddSmallInplace, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxCopy, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, - VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, - VecZnxSubInplace, - }, - layouts::{Backend, Module, ScratchOwned}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, - }, + api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxFillUniform}, + layouts::{Backend, Module, Scratch, ScratchOwned}, source::Source, }; use crate::{ + GLWECompressedEncryptSk, GLWEEncryptPk, GLWEEncryptSk, GLWEPublicKeyGenerate, GLWESub, ScratchTakeCore, + decryption::GLWEDecrypt, encryption::SIGMA, layouts::{ - GLWECiphertext, GLWECiphertextLayout, GLWEPlaintext, GLWEPlaintextLayout, GLWEPublicKey, GLWESecret, LWEInfos, - compressed::{Decompress, GLWECiphertextCompressed}, - prepared::{GLWEPublicKeyPrepared, GLWESecretPrepared, PrepareAlloc}, + GLWE, GLWELayout, GLWEPlaintext, GLWEPlaintextLayout, GLWEPublicKey, GLWEPublicKeyPreparedFactory, GLWESecret, + GLWESecretPreparedFactory, LWEInfos, + compressed::GLWECompressed, + prepared::{GLWEPublicKeyPrepared, GLWESecretPrepared}, }, - operations::GLWEOperations, }; -pub fn test_glwe_encrypt_sk(module: &Module) +pub fn test_glwe_encrypt_sk(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigAllocBytes - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalizeTmpBytes - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + SvpApplyDftToDft - + VecZnxBigAddNormal - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, + Module: GLWEEncryptSk + GLWEDecrypt + GLWESecretPreparedFactory + VecZnxFillUniform + GLWESub, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { let base2k: usize = 8; let k_ct: usize = 54; @@ -64,7 +29,7 @@ where for rank in 1_usize..3 { let n: usize = module.n(); - let glwe_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_infos: GLWELayout = GLWELayout { n: n.into(), base2k: base2k.into(), k: k_ct.into(), @@ -77,22 +42,22 @@ where k: k_pt.into(), }; - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&glwe_infos); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&pt_infos); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&pt_infos); + let mut ct: GLWE> = GLWE::alloc_from_infos(&glwe_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&pt_infos); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&pt_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::encrypt_sk_scratch_space(module, &glwe_infos) - | GLWECiphertext::decrypt_scratch_space(module, &glwe_infos), - ); + let mut scratch: ScratchOwned = + ScratchOwned::alloc(GLWE::encrypt_sk_tmp_bytes(module, &glwe_infos) | GLWE::decrypt_tmp_bytes(module, &glwe_infos)); - let mut sk: GLWESecret> = GLWESecret::alloc(&glwe_infos); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_infos); sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); + + let mut sk_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank.into()); + sk_prepared.prepare(module, &sk); module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); @@ -107,7 +72,7 @@ where ct.decrypt(module, &mut pt_have, &sk_prepared, scratch.borrow()); - pt_want.sub_inplace_ab(module, &pt_have); + module.glwe_sub_inplace(&mut pt_want, &pt_have); let noise_have: f64 = pt_want.data.std(base2k, 0) * (ct.k().as_u32() as f64).exp2(); let noise_want: f64 = SIGMA; @@ -116,48 +81,21 @@ where } } -pub fn test_glwe_compressed_encrypt_sk(module: &Module) +pub fn test_glwe_compressed_encrypt_sk(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigAllocBytes - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalizeTmpBytes - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + SvpApplyDftToDft - + VecZnxBigAddNormal - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + VecZnxCopy, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, + Module: GLWECompressedEncryptSk + GLWEDecrypt + GLWESecretPreparedFactory + VecZnxFillUniform + GLWESub, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { let base2k: usize = 8; let k_ct: usize = 54; let k_pt: usize = 30; for rank in 1_usize..3 { + // println!("rank: {}", rank); let n: usize = module.n(); - let glwe_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_infos: GLWELayout = GLWELayout { n: n.into(), base2k: base2k.into(), k: k_ct.into(), @@ -170,23 +108,24 @@ where k: k_pt.into(), }; - let mut ct_compressed: GLWECiphertextCompressed> = GLWECiphertextCompressed::alloc(&glwe_infos); + let mut ct_compressed: GLWECompressed> = GLWECompressed::alloc_from_infos(&glwe_infos); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&pt_infos); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&pt_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&pt_infos); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&pt_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertextCompressed::encrypt_sk_scratch_space(module, &glwe_infos) - | GLWECiphertext::decrypt_scratch_space(module, &glwe_infos), + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GLWECompressed::encrypt_sk_tmp_bytes(module, &glwe_infos) | GLWE::decrypt_tmp_bytes(module, &glwe_infos), ); - let mut sk: GLWESecret> = GLWESecret::alloc(&glwe_infos); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_infos); sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); + + let mut sk_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank.into()); + sk_prepared.prepare(module, &sk); module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); @@ -201,12 +140,12 @@ where scratch.borrow(), ); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&glwe_infos); + let mut ct: GLWE> = GLWE::alloc_from_infos(&glwe_infos); ct.decompress(module, &ct_compressed); ct.decrypt(module, &mut pt_have, &sk_prepared, scratch.borrow()); - pt_want.sub_inplace_ab(module, &pt_have); + module.glwe_sub_inplace(&mut pt_want, &pt_have); let noise_have: f64 = pt_want.data.std(base2k, 0) * (ct.k().as_u32() as f64).exp2(); let noise_want: f64 = SIGMA; @@ -219,38 +158,11 @@ where } } -pub fn test_glwe_encrypt_zero_sk(module: &Module) +pub fn test_glwe_encrypt_zero_sk(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigAllocBytes - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalizeTmpBytes - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + SvpApplyDftToDft - + VecZnxBigAddNormal - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, + Module: GLWEEncryptSk + GLWEDecrypt + GLWESecretPreparedFactory + VecZnxFillUniform + GLWESub, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { let base2k: usize = 8; let k_ct: usize = 54; @@ -258,29 +170,29 @@ where for rank in 1_usize..3 { let n: usize = module.n(); - let glwe_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_infos: GLWELayout = GLWELayout { n: n.into(), base2k: base2k.into(), k: k_ct.into(), rank: rank.into(), }; - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_infos); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([1u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::decrypt_scratch_space(module, &glwe_infos) - | GLWECiphertext::encrypt_sk_scratch_space(module, &glwe_infos), - ); + let mut scratch: ScratchOwned = + ScratchOwned::alloc(GLWE::decrypt_tmp_bytes(module, &glwe_infos) | GLWE::encrypt_sk_tmp_bytes(module, &glwe_infos)); - let mut sk: GLWESecret> = GLWESecret::alloc(&glwe_infos); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_infos); sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&glwe_infos); + let mut sk_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank.into()); + sk_prepared.prepare(module, &sk); + + let mut ct: GLWE> = GLWE::alloc_from_infos(&glwe_infos); ct.encrypt_zero_sk( module, @@ -295,40 +207,17 @@ where } } -pub fn test_glwe_encrypt_pk(module: &Module) +pub fn test_glwe_encrypt_pk(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes + Module: GLWEEncryptPk + + GLWEPublicKeyPreparedFactory + + GLWEPublicKeyGenerate + + GLWEDecrypt + + GLWESecretPreparedFactory + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAddSmallInplace - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxCopy - + VecZnxDftAlloc - + SvpApplyDftToDft - + VecZnxBigAddNormal, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, + + GLWESub, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { let base2k: usize = 8; let k_ct: usize = 54; @@ -336,38 +225,38 @@ where for rank in 1_usize..3 { let n: usize = module.n(); - let glwe_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_infos: GLWELayout = GLWELayout { n: n.into(), base2k: base2k.into(), k: k_ct.into(), rank: rank.into(), }; - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&glwe_infos); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_infos); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_infos); + let mut ct: GLWE> = GLWE::alloc_from_infos(&glwe_infos); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut source_xu: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::encrypt_sk_scratch_space(module, &glwe_infos) - | GLWECiphertext::decrypt_scratch_space(module, &glwe_infos) - | GLWECiphertext::encrypt_pk_scratch_space(module, &glwe_infos), - ); + let mut scratch: ScratchOwned = + ScratchOwned::alloc(GLWE::decrypt_tmp_bytes(module, &glwe_infos) | GLWE::encrypt_pk_tmp_bytes(module, &glwe_infos)); - let mut sk: GLWESecret> = GLWESecret::alloc(&glwe_infos); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_infos); sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); - let mut pk: GLWEPublicKey> = GLWEPublicKey::alloc(&glwe_infos); - pk.generate_from_sk(module, &sk_prepared, &mut source_xa, &mut source_xe); + let mut sk_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank.into()); + sk_prepared.prepare(module, &sk); + + let mut pk: GLWEPublicKey> = GLWEPublicKey::alloc_from_infos(&glwe_infos); + pk.generate(module, &sk_prepared, &mut source_xa, &mut source_xe); module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); - let pk_prepared: GLWEPublicKeyPrepared, B> = pk.prepare_alloc(module, scratch.borrow()); + let mut pk_prepared: GLWEPublicKeyPrepared, BE> = GLWEPublicKeyPrepared::alloc_from_infos(module, &glwe_infos); + pk_prepared.prepare(module, &pk); ct.encrypt_pk( module, @@ -380,7 +269,7 @@ where ct.decrypt(module, &mut pt_have, &sk_prepared, scratch.borrow()); - pt_want.sub_inplace_ab(module, &pt_have); + module.glwe_sub_inplace(&mut pt_want, &pt_have); let noise_have: f64 = pt_want.data.std(base2k, 0).log2(); let noise_want: f64 = ((((rank as f64) + 1.0) * n as f64 * 0.5 * SIGMA * SIGMA).sqrt()).log2() - (k_ct as f64); diff --git a/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs b/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs index 8bb9730..940f917 100644 --- a/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs +++ b/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs @@ -1,68 +1,37 @@ use poulpy_hal::{ api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, - SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, - VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxCopy, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, - VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSubScalarInplace, VecZnxSwitchRing, - }, - layouts::{Backend, Module, ScratchOwned, VecZnxDft}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl, + ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, VecZnxBigAlloc, VecZnxBigNormalize, + VecZnxCopy, VecZnxDftAlloc, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyTmpA, VecZnxSubScalarInplace, + VecZnxSwitchRing, }, + layouts::{Backend, Module, Scratch, ScratchOwned, VecZnxBig, VecZnxDft}, source::Source, }; use crate::{ + GLWETensorKeyCompressedEncryptSk, GLWETensorKeyEncryptSk, ScratchTakeCore, + decryption::GLWEDecrypt, encryption::SIGMA, layouts::{ - Dsize, GGLWETensorKey, GGLWETensorKeyLayout, GLWEPlaintext, GLWESecret, - compressed::{Decompress, GGLWETensorKeyCompressed}, - prepared::{GLWESecretPrepared, PrepareAlloc}, + Dsize, GLWEPlaintext, GLWESecret, GLWESecretPreparedFactory, GLWETensorKey, GLWETensorKeyCompressed, GLWETensorKeyLayout, + prepared::GLWESecretPrepared, }, }; -pub fn test_gglwe_tensor_key_encrypt_sk(module: &Module) +pub fn test_gglwe_tensor_key_encrypt_sk(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAddSmallInplace - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxCopy - + VecZnxDftAlloc - + SvpApplyDftToDft - + VecZnxBigAlloc - + VecZnxIdftApplyTmpA - + VecZnxAddScalarInplace - + VecZnxSwitchRing + Module: GLWETensorKeyEncryptSk + + GLWESecretPreparedFactory + + GLWEDecrypt + + VecZnxDftAlloc + + VecZnxBigAlloc + + VecZnxDftApply + + SvpApplyDftToDft + + VecZnxIdftApplyTmpA + + VecZnxBigNormalize + VecZnxSubScalarInplace, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl - + VecZnxDftAllocBytesImpl - + VecZnxBigAllocBytesImpl - + TakeSvpPPolImpl, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { let base2k: usize = 8; let k: usize = 54; @@ -71,7 +40,7 @@ where let n: usize = module.n(); let dnum: usize = k / base2k; - let tensor_key_infos = GGLWETensorKeyLayout { + let tensor_key_infos = GLWETensorKeyLayout { n: n.into(), base2k: base2k.into(), k: k.into(), @@ -80,20 +49,21 @@ where rank: rank.into(), }; - let mut tensor_key: GGLWETensorKey> = GGLWETensorKey::alloc(&tensor_key_infos); + let mut tensor_key: GLWETensorKey> = GLWETensorKey::alloc_from_infos(&tensor_key_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWETensorKey::encrypt_sk_scratch_space( + let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWETensorKey::encrypt_sk_tmp_bytes( module, &tensor_key_infos, )); - let mut sk: GLWESecret> = GLWESecret::alloc(&tensor_key_infos); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&tensor_key_infos); sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); + let mut sk_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank.into()); + sk_prepared.prepare(module, &sk); tensor_key.encrypt_sk( module, @@ -103,12 +73,12 @@ where scratch.borrow(), ); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&tensor_key_infos); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&tensor_key_infos); - let mut sk_ij_dft = module.vec_znx_dft_alloc(1, 1); - let mut sk_ij_big = module.vec_znx_big_alloc(1, 1); - let mut sk_ij: GLWESecret> = GLWESecret::alloc_with(n.into(), 1_u32.into()); - let mut sk_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(rank, 1); + let mut sk_ij_dft: VecZnxDft, BE> = module.vec_znx_dft_alloc(1, 1); + let mut sk_ij_big: VecZnxBig, BE> = module.vec_znx_big_alloc(1, 1); + let mut sk_ij: GLWESecret> = GLWESecret::alloc(n.into(), 1_u32.into()); + let mut sk_dft: VecZnxDft, BE> = module.vec_znx_dft_alloc(rank, 1); for i in 0..rank { module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); @@ -128,10 +98,9 @@ where scratch.borrow(), ); for row_i in 0..dnum { - tensor_key - .at(i, j) - .at(row_i, 0) - .decrypt(module, &mut pt, &sk_prepared, scratch.borrow()); + let ct = tensor_key.at(i, j).at(row_i, 0); + + ct.decrypt(module, &mut pt, &sk_prepared, scratch.borrow()); module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk_ij.data, 0); @@ -143,46 +112,23 @@ where } } -pub fn test_gglwe_tensor_key_compressed_encrypt_sk(module: &Module) +pub fn test_gglwe_tensor_key_compressed_encrypt_sk(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes + Module: GLWETensorKeyEncryptSk + + GLWESecretPreparedFactory + + GLWETensorKeyCompressedEncryptSk + + GLWEDecrypt + + VecZnxDftAlloc + + VecZnxBigAlloc + + VecZnxDftApply + + SvpApplyDftToDft + + VecZnxIdftApplyTmpA + + VecZnxSubScalarInplace + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAddSmallInplace - + VecZnxBigAllocBytes - + VecZnxBigAddInplace + VecZnxCopy - + VecZnxDftAlloc - + SvpApplyDftToDft - + VecZnxBigAlloc - + VecZnxIdftApplyTmpA - + VecZnxAddScalarInplace - + VecZnxSwitchRing - + VecZnxSubScalarInplace, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl - + VecZnxDftAllocBytesImpl - + VecZnxBigAllocBytesImpl - + TakeSvpPPolImpl, + + VecZnxSwitchRing, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { let base2k = 8; let k = 54; @@ -190,7 +136,7 @@ where let n: usize = module.n(); let dnum: usize = k / base2k; - let tensor_key_infos: GGLWETensorKeyLayout = GGLWETensorKeyLayout { + let tensor_key_infos: GLWETensorKeyLayout = GLWETensorKeyLayout { n: n.into(), base2k: base2k.into(), k: k.into(), @@ -199,33 +145,35 @@ where rank: rank.into(), }; - let mut tensor_key_compressed: GGLWETensorKeyCompressed> = GGLWETensorKeyCompressed::alloc(&tensor_key_infos); + let mut tensor_key_compressed: GLWETensorKeyCompressed> = + GLWETensorKeyCompressed::alloc_from_infos(&tensor_key_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWETensorKeyCompressed::encrypt_sk_scratch_space( + let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWETensorKeyCompressed::encrypt_sk_tmp_bytes( module, &tensor_key_infos, )); - let mut sk: GLWESecret> = GLWESecret::alloc(&tensor_key_infos); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&tensor_key_infos); sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); + let mut sk_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank.into()); + sk_prepared.prepare(module, &sk); let seed_xa: [u8; 32] = [1u8; 32]; tensor_key_compressed.encrypt_sk(module, &sk, seed_xa, &mut source_xe, scratch.borrow()); - let mut tensor_key: GGLWETensorKey> = GGLWETensorKey::alloc(&tensor_key_infos); + let mut tensor_key: GLWETensorKey> = GLWETensorKey::alloc_from_infos(&tensor_key_infos); tensor_key.decompress(module, &tensor_key_compressed); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&tensor_key_infos); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&tensor_key_infos); - let mut sk_ij_dft = module.vec_znx_dft_alloc(1, 1); - let mut sk_ij_big = module.vec_znx_big_alloc(1, 1); - let mut sk_ij: GLWESecret> = GLWESecret::alloc_with(n.into(), 1_u32.into()); - let mut sk_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(rank, 1); + let mut sk_ij_dft: VecZnxDft, BE> = module.vec_znx_dft_alloc(1, 1); + let mut sk_ij_big: VecZnxBig, BE> = module.vec_znx_big_alloc(1, 1); + let mut sk_ij: GLWESecret> = GLWESecret::alloc(n.into(), 1_u32.into()); + let mut sk_dft: VecZnxDft, BE> = module.vec_znx_dft_alloc(rank, 1); for i in 0..rank { module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); diff --git a/poulpy-core/src/tests/test_suite/external_product/gglwe_ksk.rs b/poulpy-core/src/tests/test_suite/external_product/gglwe_ksk.rs index 2cfa618..07a0926 100644 --- a/poulpy-core/src/tests/test_suite/external_product/gglwe_ksk.rs +++ b/poulpy-core/src/tests/test_suite/external_product/gglwe_ksk.rs @@ -1,72 +1,31 @@ use poulpy_hal::{ - api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotateInplace, VecZnxSub, - VecZnxSubInplace, VecZnxSubScalarInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, - }, - layouts::{Backend, Module, ScalarZnx, ScalarZnxToMut, ScratchOwned, ZnxViewMut}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl, - }, + api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxRotateInplace}, + layouts::{Backend, Module, ScalarZnx, ScalarZnxToMut, Scratch, ScratchOwned, ZnxViewMut}, source::Source, }; use crate::{ + GGLWEExternalProduct, GGLWENoise, GGSWEncryptSk, GLWESwitchingKeyEncryptSk, ScratchTakeCore, encryption::SIGMA, layouts::{ - GGLWESwitchingKey, GGLWESwitchingKeyLayout, GGSWCiphertext, GGSWCiphertextLayout, GLWESecret, - prepared::{GGSWCiphertextPrepared, GLWESecretPrepared, PrepareAlloc}, + GGSW, GGSWLayout, GGSWPreparedFactory, GLWESecret, GLWESecretPreparedFactory, GLWESwitchingKey, GLWESwitchingKeyLayout, + prepared::{GGSWPrepared, GLWESecretPrepared}, }, noise::noise_ggsw_product, }; #[allow(clippy::too_many_arguments)] -pub fn test_gglwe_switching_key_external_product(module: &Module) +pub fn test_gglwe_switching_key_external_product(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxSwitchRing - + VecZnxAddScalarInplace - + VecZnxSubScalarInplace - + VecZnxCopy - + VmpPMatAlloc - + VecZnxRotateInplace - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VmpPrepare, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl - + VecZnxDftAllocBytesImpl - + VecZnxBigAllocBytesImpl - + TakeSvpPPolImpl, + Module: GGLWEExternalProduct + + GGSWEncryptSk + + GLWESwitchingKeyEncryptSk + + GLWESecretPreparedFactory + + VecZnxRotateInplace + + GGSWPreparedFactory + + GGLWENoise, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { let base2k: usize = 12; let k_in: usize = 60; @@ -81,7 +40,7 @@ where let dnum: usize = k_in.div_ceil(base2k * di); let dsize_in: usize = 1; - let gglwe_in_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + let gglwe_in_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), base2k: base2k.into(), k: k_in.into(), @@ -91,7 +50,7 @@ where rank_out: rank_out.into(), }; - let gglwe_out_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + let gglwe_out_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), base2k: base2k.into(), k: k_out.into(), @@ -101,7 +60,7 @@ where rank_out: rank_out.into(), }; - let ggsw_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + let ggsw_infos: GGSWLayout = GGSWLayout { n: n.into(), base2k: base2k.into(), k: k_ggsw.into(), @@ -110,9 +69,9 @@ where rank: rank_out.into(), }; - let mut ct_gglwe_in: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_in_infos); - let mut ct_gglwe_out: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_out_infos); - let mut ct_rgsw: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_infos); + let mut ct_gglwe_in: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&gglwe_in_infos); + let mut ct_gglwe_out: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&gglwe_out_infos); + let mut ct_rgsw: GGSW> = GGSW::alloc_from_infos(&ggsw_infos); let mut pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n, 1); @@ -120,15 +79,10 @@ where let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWESwitchingKey::encrypt_sk_scratch_space(module, &gglwe_in_infos) - | GGLWESwitchingKey::external_product_scratch_space( - module, - &gglwe_out_infos, - &gglwe_in_infos, - &ggsw_infos, - ) - | GGSWCiphertext::encrypt_sk_scratch_space(module, &ggsw_infos), + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &gglwe_in_infos) + | GLWESwitchingKey::external_product_tmp_bytes(module, &gglwe_out_infos, &gglwe_in_infos, &ggsw_infos) + | GGSW::encrypt_sk_tmp_bytes(module, &ggsw_infos), ); let r: usize = 1; @@ -137,12 +91,14 @@ where let var_xs: f64 = 0.5; - let mut sk_in: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_in.into()); + let mut sk_in: GLWESecret> = GLWESecret::alloc(n.into(), rank_in.into()); sk_in.fill_ternary_prob(var_xs, &mut source_xs); - let mut sk_out: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_out.into()); + let mut sk_out: GLWESecret> = GLWESecret::alloc(n.into(), rank_out.into()); sk_out.fill_ternary_prob(var_xs, &mut source_xs); - let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); + + let mut sk_out_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank_out.into()); + sk_out_prepared.prepare(module, &sk_out); // gglwe_{s1}(s0) = s0 -> s1 ct_gglwe_in.encrypt_sk( @@ -163,7 +119,8 @@ where scratch.borrow(), ); - let ct_rgsw_prepared: GGSWCiphertextPrepared, B> = ct_rgsw.prepare_alloc(module, scratch.borrow()); + let mut ct_rgsw_prepared: GGSWPrepared, BE> = GGSWPrepared::alloc_from_infos(module, &ct_rgsw); + ct_rgsw_prepared.prepare(module, &ct_rgsw, scratch.borrow()); // gglwe_(m) (x) RGSW_(X^k) = gglwe_(m * X^k) ct_gglwe_out.external_product(module, &ct_gglwe_in, &ct_rgsw_prepared, scratch.borrow()); @@ -207,48 +164,17 @@ where } #[allow(clippy::too_many_arguments)] -pub fn test_gglwe_switching_key_external_product_inplace(module: &Module) +pub fn test_gglwe_switching_key_external_product_inplace(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxSwitchRing - + VecZnxAddScalarInplace - + VecZnxSubScalarInplace - + VecZnxCopy - + VmpPMatAlloc - + VecZnxRotateInplace - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VmpPrepare, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl - + VecZnxDftAllocBytesImpl - + VecZnxBigAllocBytesImpl - + TakeSvpPPolImpl, + Module: GGLWEExternalProduct + + GGSWEncryptSk + + GLWESwitchingKeyEncryptSk + + GLWESecretPreparedFactory + + VecZnxRotateInplace + + GGSWPreparedFactory + + GGLWENoise, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { let base2k: usize = 12; let k_out: usize = 60; @@ -263,7 +189,7 @@ where let dsize_in: usize = 1; - let gglwe_out_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + let gglwe_out_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), base2k: base2k.into(), k: k_out.into(), @@ -273,7 +199,7 @@ where rank_out: rank_out.into(), }; - let ggsw_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + let ggsw_infos: GGSWLayout = GGSWLayout { n: n.into(), base2k: base2k.into(), k: k_ggsw.into(), @@ -282,8 +208,8 @@ where rank: rank_out.into(), }; - let mut ct_gglwe: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_out_infos); - let mut ct_rgsw: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_infos); + let mut ct_gglwe: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&gglwe_out_infos); + let mut ct_rgsw: GGSW> = GGSW::alloc_from_infos(&ggsw_infos); let mut pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n, 1); @@ -291,10 +217,10 @@ where let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWESwitchingKey::encrypt_sk_scratch_space(module, &gglwe_out_infos) - | GGLWESwitchingKey::external_product_inplace_scratch_space(module, &gglwe_out_infos, &ggsw_infos) - | GGSWCiphertext::encrypt_sk_scratch_space(module, &ggsw_infos), + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &gglwe_out_infos) + | GLWESwitchingKey::external_product_tmp_bytes(module, &gglwe_out_infos, &gglwe_out_infos, &ggsw_infos) + | GGSW::encrypt_sk_tmp_bytes(module, &ggsw_infos), ); let r: usize = 1; @@ -303,12 +229,14 @@ where let var_xs: f64 = 0.5; - let mut sk_in: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_in.into()); + let mut sk_in: GLWESecret> = GLWESecret::alloc(n.into(), rank_in.into()); sk_in.fill_ternary_prob(var_xs, &mut source_xs); - let mut sk_out: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_out.into()); + let mut sk_out: GLWESecret> = GLWESecret::alloc(n.into(), rank_out.into()); sk_out.fill_ternary_prob(var_xs, &mut source_xs); - let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); + + let mut sk_out_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank_out.into()); + sk_out_prepared.prepare(module, &sk_out); // gglwe_{s1}(s0) = s0 -> s1 ct_gglwe.encrypt_sk( @@ -329,7 +257,8 @@ where scratch.borrow(), ); - let ct_rgsw_prepared: GGSWCiphertextPrepared, B> = ct_rgsw.prepare_alloc(module, scratch.borrow()); + let mut ct_rgsw_prepared: GGSWPrepared, BE> = GGSWPrepared::alloc_from_infos(module, &ct_rgsw); + ct_rgsw_prepared.prepare(module, &ct_rgsw, scratch.borrow()); // gglwe_(m) (x) RGSW_(X^k) = gglwe_(m * X^k) ct_gglwe.external_product_inplace(module, &ct_rgsw_prepared, scratch.borrow()); diff --git a/poulpy-core/src/tests/test_suite/external_product/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/external_product/ggsw_ct.rs index 464eac2..3fe2da4 100644 --- a/poulpy-core/src/tests/test_suite/external_product/ggsw_ct.rs +++ b/poulpy-core/src/tests/test_suite/external_product/ggsw_ct.rs @@ -1,74 +1,30 @@ use poulpy_hal::{ - api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAlloc, VecZnxDftAllocBytes, - VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxRotateInplace, VecZnxSub, VecZnxSubInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, - }, - layouts::{Backend, Module, ScalarZnx, ScalarZnxToMut, ScratchOwned, ZnxViewMut}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl, - }, + api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxRotateInplace}, + layouts::{Backend, Module, ScalarZnx, ScalarZnxToMut, Scratch, ScratchOwned, ZnxViewMut}, source::Source, }; use crate::{ + GGSWEncryptSk, GGSWExternalProduct, GGSWNoise, ScratchTakeCore, encryption::SIGMA, layouts::{ - GGSWCiphertext, GGSWCiphertextLayout, GLWESecret, - prepared::{GGSWCiphertextPrepared, GLWESecretPrepared, PrepareAlloc}, + GGSW, GGSWLayout, GGSWPreparedFactory, GLWESecret, GLWESecretPreparedFactory, + prepared::{GGSWPrepared, GLWESecretPrepared}, }, noise::noise_ggsw_product, }; #[allow(clippy::too_many_arguments)] -pub fn test_ggsw_external_product(module: &Module) +pub fn test_ggsw_external_product(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxAddScalarInplace - + VecZnxCopy - + VmpPMatAlloc - + VecZnxRotateInplace - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VmpPrepare - + VecZnxBigAlloc - + VecZnxDftAlloc - + VecZnxBigNormalizeTmpBytes - + VecZnxIdftApplyTmpA, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl - + VecZnxDftAllocBytesImpl - + VecZnxBigAllocBytesImpl - + TakeSvpPPolImpl, + Module: GGSWEncryptSk + + GGSWExternalProduct + + GLWESecretPreparedFactory + + GGSWPreparedFactory + + VecZnxRotateInplace + + GGSWNoise, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { let base2k: usize = 12; let k_in: usize = 60; @@ -84,7 +40,7 @@ where let dnum_in: usize = k_in.div_euclid(base2k * di); let dsize_in: usize = 1; - let ggsw_in_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + let ggsw_in_infos: GGSWLayout = GGSWLayout { n: n.into(), base2k: base2k.into(), k: k_in.into(), @@ -93,7 +49,7 @@ where rank: rank.into(), }; - let ggsw_out_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + let ggsw_out_infos: GGSWLayout = GGSWLayout { n: n.into(), base2k: base2k.into(), k: k_out.into(), @@ -102,7 +58,7 @@ where rank: rank.into(), }; - let ggsw_apply_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + let ggsw_apply_infos: GGSWLayout = GGSWLayout { n: n.into(), base2k: base2k.into(), k: k_apply.into(), @@ -111,9 +67,9 @@ where rank: rank.into(), }; - let mut ggsw_in: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_in_infos); - let mut ggsw_out: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_out_infos); - let mut ggsw_apply: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_apply_infos); + let mut ggsw_in: GGSW> = GGSW::alloc_from_infos(&ggsw_in_infos); + let mut ggsw_out: GGSW> = GGSW::alloc_from_infos(&ggsw_out_infos); + let mut ggsw_apply: GGSW> = GGSW::alloc_from_infos(&ggsw_apply_infos); let mut pt_in: ScalarZnx> = ScalarZnx::alloc(n, 1); let mut pt_apply: ScalarZnx> = ScalarZnx::alloc(n, 1); @@ -127,15 +83,17 @@ where pt_apply.to_mut().raw_mut()[k] = 1; //X^{k} - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, &ggsw_apply_infos) - | GGSWCiphertext::encrypt_sk_scratch_space(module, &ggsw_in_infos) - | GGSWCiphertext::external_product_scratch_space(module, &ggsw_out_infos, &ggsw_in_infos, &ggsw_apply_infos), + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GGSW::encrypt_sk_tmp_bytes(module, &ggsw_apply_infos) + | GGSW::encrypt_sk_tmp_bytes(module, &ggsw_in_infos) + | GGSW::external_product_tmp_bytes(module, &ggsw_out_infos, &ggsw_in_infos, &ggsw_apply_infos), ); - let mut sk: GLWESecret> = GLWESecret::alloc_with(n.into(), rank.into()); + let mut sk: GLWESecret> = GLWESecret::alloc(n.into(), rank.into()); sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); + + let mut sk_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank.into()); + sk_prepared.prepare(module, &sk); ggsw_apply.encrypt_sk( module, @@ -155,7 +113,8 @@ where scratch.borrow(), ); - let ct_rhs_prepared: GGSWCiphertextPrepared, B> = ggsw_apply.prepare_alloc(module, scratch.borrow()); + let mut ct_rhs_prepared: GGSWPrepared, BE> = GGSWPrepared::alloc_from_infos(module, &ggsw_apply); + ct_rhs_prepared.prepare(module, &ggsw_apply, scratch.borrow()); ggsw_out.external_product(module, &ggsw_in, &ct_rhs_prepared, scratch.borrow()); @@ -190,50 +149,16 @@ where } #[allow(clippy::too_many_arguments)] -pub fn test_ggsw_external_product_inplace(module: &Module) +pub fn test_ggsw_external_product_inplace(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxAddScalarInplace - + VecZnxCopy - + VmpPMatAlloc - + VecZnxRotateInplace - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VmpPrepare - + VecZnxBigAlloc - + VecZnxDftAlloc - + VecZnxBigNormalizeTmpBytes - + VecZnxIdftApplyTmpA, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl - + VecZnxDftAllocBytesImpl - + VecZnxBigAllocBytesImpl - + TakeSvpPPolImpl, + Module: GGSWEncryptSk + + GGSWExternalProduct + + GLWESecretPreparedFactory + + GGSWPreparedFactory + + VecZnxRotateInplace + + GGSWNoise, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { let base2k: usize = 12; let k_out: usize = 60; @@ -247,7 +172,7 @@ where let dnum_in: usize = k_out.div_euclid(base2k * di); let dsize_in: usize = 1; - let ggsw_out_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + let ggsw_out_infos: GGSWLayout = GGSWLayout { n: n.into(), base2k: base2k.into(), k: k_out.into(), @@ -256,7 +181,7 @@ where rank: rank.into(), }; - let ggsw_apply_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + let ggsw_apply_infos: GGSWLayout = GGSWLayout { n: n.into(), base2k: base2k.into(), k: k_apply.into(), @@ -265,8 +190,8 @@ where rank: rank.into(), }; - let mut ggsw_out: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_out_infos); - let mut ggsw_apply: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_apply_infos); + let mut ggsw_out: GGSW> = GGSW::alloc_from_infos(&ggsw_out_infos); + let mut ggsw_apply: GGSW> = GGSW::alloc_from_infos(&ggsw_apply_infos); let mut pt_in: ScalarZnx> = ScalarZnx::alloc(n, 1); let mut pt_apply: ScalarZnx> = ScalarZnx::alloc(n, 1); @@ -281,15 +206,17 @@ where pt_apply.to_mut().raw_mut()[k] = 1; //X^{k} - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, &ggsw_apply_infos) - | GGSWCiphertext::encrypt_sk_scratch_space(module, &ggsw_out_infos) - | GGSWCiphertext::external_product_inplace_scratch_space(module, &ggsw_out_infos, &ggsw_apply_infos), + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GGSW::encrypt_sk_tmp_bytes(module, &ggsw_apply_infos) + | GGSW::encrypt_sk_tmp_bytes(module, &ggsw_out_infos) + | GGSW::external_product_tmp_bytes(module, &ggsw_out_infos, &ggsw_out_infos, &ggsw_apply_infos), ); - let mut sk: GLWESecret> = GLWESecret::alloc_with(n.into(), rank.into()); + let mut sk: GLWESecret> = GLWESecret::alloc(n.into(), rank.into()); sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); + + let mut sk_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank.into()); + sk_prepared.prepare(module, &sk); ggsw_apply.encrypt_sk( module, @@ -309,7 +236,8 @@ where scratch.borrow(), ); - let ct_rhs_prepared: GGSWCiphertextPrepared, B> = ggsw_apply.prepare_alloc(module, scratch.borrow()); + let mut ct_rhs_prepared: GGSWPrepared, BE> = GGSWPrepared::alloc_from_infos(module, &ggsw_apply); + ct_rhs_prepared.prepare(module, &ggsw_apply, scratch.borrow()); ggsw_out.external_product_inplace(module, &ct_rhs_prepared, scratch.borrow()); diff --git a/poulpy-core/src/tests/test_suite/external_product/glwe_ct.rs b/poulpy-core/src/tests/test_suite/external_product/glwe_ct.rs index 1a21a0c..0425d35 100644 --- a/poulpy-core/src/tests/test_suite/external_product/glwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/external_product/glwe_ct.rs @@ -1,66 +1,32 @@ use poulpy_hal::{ - api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, - VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotateInplace, VecZnxSub, VecZnxSubInplace, - VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, - }, - layouts::{Backend, Module, ScalarZnx, ScratchOwned, ZnxViewMut}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, - }, + api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxFillUniform, VecZnxRotateInplace}, + layouts::{Backend, Module, ScalarZnx, Scratch, ScratchOwned, ZnxViewMut}, source::Source, }; use crate::{ + GGSWEncryptSk, GLWEEncryptSk, GLWEExternalProduct, GLWENoise, ScratchTakeCore, encryption::SIGMA, layouts::{ - GGSWCiphertext, GGSWCiphertextLayout, GLWECiphertext, GLWECiphertextLayout, GLWEPlaintext, GLWESecret, - prepared::{GGSWCiphertextPrepared, GLWESecretPrepared, PrepareAlloc}, + GGSW, GGSWLayout, GGSWPreparedFactory, GLWE, GLWELayout, GLWEPlaintext, GLWESecret, GLWESecretPreparedFactory, + prepared::{GGSWPrepared, GLWESecretPrepared}, }, noise::noise_ggsw_product, }; #[allow(clippy::too_many_arguments)] -pub fn test_glwe_external_product(module: &Module) +pub fn test_glwe_external_product(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume + Module: GGSWEncryptSk + + GGSWPreparedFactory + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxNormalizeTmpBytes - + VecZnxAddScalarInplace - + VecZnxRotateInplace - + VmpPMatAlloc - + VmpPrepare - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, + + GLWEExternalProduct + + GLWEEncryptSk + + GLWENoise + + VecZnxRotateInplace + + GLWESecretPreparedFactory, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { let base2k: usize = 12; let k_in: usize = 45; @@ -73,21 +39,21 @@ where let n: usize = module.n(); let dnum: usize = k_in.div_ceil(base2k * dsize); - let glwe_in_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_in_infos: GLWELayout = GLWELayout { n: n.into(), base2k: base2k.into(), k: k_in.into(), rank: rank.into(), }; - let glwe_out_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_out_infos: GLWELayout = GLWELayout { n: n.into(), base2k: base2k.into(), k: k_out.into(), rank: rank.into(), }; - let ggsw_apply_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + let ggsw_apply_infos: GGSWLayout = GGSWLayout { n: n.into(), base2k: base2k.into(), k: k_ggsw.into(), @@ -96,11 +62,11 @@ where rank: rank.into(), }; - let mut ggsw_apply: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_apply_infos); - let mut glwe_in: GLWECiphertext> = GLWECiphertext::alloc(&glwe_in_infos); - let mut glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&glwe_out_infos); + let mut ggsw_apply: GGSW> = GGSW::alloc_from_infos(&ggsw_apply_infos); + let mut glwe_in: GLWE> = GLWE::alloc_from_infos(&glwe_in_infos); + let mut glwe_out: GLWE> = GLWE::alloc_from_infos(&glwe_out_infos); let mut pt_ggsw: ScalarZnx> = ScalarZnx::alloc(n, 1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_in_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_in_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -115,15 +81,17 @@ where pt_ggsw.raw_mut()[k] = 1; // X^{k} - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, &ggsw_apply_infos) - | GLWECiphertext::encrypt_sk_scratch_space(module, &glwe_in_infos) - | GLWECiphertext::external_product_scratch_space(module, &glwe_out_infos, &glwe_in_infos, &ggsw_apply_infos), + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GGSW::encrypt_sk_tmp_bytes(module, &ggsw_apply_infos) + | GLWE::encrypt_sk_tmp_bytes(module, &glwe_in_infos) + | GLWE::external_product_tmp_bytes(module, &glwe_out_infos, &glwe_in_infos, &ggsw_apply_infos), ); - let mut sk: GLWESecret> = GLWESecret::alloc_with(n.into(), rank.into()); + let mut sk: GLWESecret> = GLWESecret::alloc(n.into(), rank.into()); sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); + + let mut sk_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank.into()); + sk_prepared.prepare(module, &sk); ggsw_apply.encrypt_sk( module, @@ -143,7 +111,8 @@ where scratch.borrow(), ); - let ct_ggsw_prepared: GGSWCiphertextPrepared, B> = ggsw_apply.prepare_alloc(module, scratch.borrow()); + let mut ct_ggsw_prepared: GGSWPrepared, BE> = GGSWPrepared::alloc_from_infos(module, &ggsw_apply); + ct_ggsw_prepared.prepare(module, &ggsw_apply, scratch.borrow()); glwe_out.external_product(module, &glwe_in, &ct_ggsw_prepared, scratch.borrow()); @@ -176,43 +145,18 @@ where } #[allow(clippy::too_many_arguments)] -pub fn test_glwe_external_product_inplace(module: &Module) +pub fn test_glwe_external_product_inplace(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume + Module: GGSWEncryptSk + + GGSWPreparedFactory + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxNormalizeTmpBytes - + VecZnxAddScalarInplace - + VecZnxRotateInplace - + VmpPMatAlloc - + VmpPrepare - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, + + GLWEExternalProduct + + GLWEEncryptSk + + GLWENoise + + VecZnxRotateInplace + + GLWESecretPreparedFactory, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { let base2k: usize = 12; let k_out: usize = 60; @@ -224,14 +168,14 @@ where let n: usize = module.n(); let dnum: usize = k_out.div_ceil(base2k * dsize); - let glwe_out_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_out_infos: GLWELayout = GLWELayout { n: n.into(), base2k: base2k.into(), k: k_out.into(), rank: rank.into(), }; - let ggsw_apply_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + let ggsw_apply_infos: GGSWLayout = GGSWLayout { n: n.into(), base2k: base2k.into(), k: k_ggsw.into(), @@ -240,10 +184,10 @@ where rank: rank.into(), }; - let mut ggsw_apply: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_apply_infos); - let mut glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&glwe_out_infos); + let mut ggsw_apply: GGSW> = GGSW::alloc_from_infos(&ggsw_apply_infos); + let mut glwe_out: GLWE> = GLWE::alloc_from_infos(&glwe_out_infos); let mut pt_ggsw: ScalarZnx> = ScalarZnx::alloc(n, 1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_out_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -258,15 +202,17 @@ where pt_ggsw.raw_mut()[k] = 1; // X^{k} - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, &ggsw_apply_infos) - | GLWECiphertext::encrypt_sk_scratch_space(module, &glwe_out_infos) - | GLWECiphertext::external_product_inplace_scratch_space(module, &glwe_out_infos, &ggsw_apply_infos), + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GGSW::encrypt_sk_tmp_bytes(module, &ggsw_apply_infos) + | GLWE::encrypt_sk_tmp_bytes(module, &glwe_out_infos) + | GLWE::external_product_tmp_bytes(module, &glwe_out_infos, &glwe_out_infos, &ggsw_apply_infos), ); - let mut sk: GLWESecret> = GLWESecret::alloc_with(n.into(), rank.into()); + let mut sk: GLWESecret> = GLWESecret::alloc(n.into(), rank.into()); sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); + + let mut sk_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank.into()); + sk_prepared.prepare(module, &sk); ggsw_apply.encrypt_sk( module, @@ -286,7 +232,8 @@ where scratch.borrow(), ); - let ct_ggsw_prepared: GGSWCiphertextPrepared, B> = ggsw_apply.prepare_alloc(module, scratch.borrow()); + let mut ct_ggsw_prepared: GGSWPrepared, BE> = GGSWPrepared::alloc_from_infos(module, &ggsw_apply); + ct_ggsw_prepared.prepare(module, &ggsw_apply, scratch.borrow()); glwe_out.external_product_inplace(module, &ct_ggsw_prepared, scratch.borrow()); diff --git a/poulpy-core/src/tests/test_suite/keyswitch/gglwe_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/gglwe_ct.rs index f131556..548d1f0 100644 --- a/poulpy-core/src/tests/test_suite/keyswitch/gglwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/keyswitch/gglwe_ct.rs @@ -1,68 +1,28 @@ use poulpy_hal::{ - api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VecZnxDftApply, - VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, - VecZnxSubInplace, VecZnxSubScalarInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, - }, - layouts::{Backend, Module, ScratchOwned}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, - }, + api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow}, + layouts::{Backend, Module, Scratch, ScratchOwned}, source::Source, }; use crate::{ + GGLWEKeyswitch, GGLWENoise, GLWESwitchingKeyEncryptSk, ScratchTakeCore, encryption::SIGMA, layouts::{ - GGLWESwitchingKey, GGLWESwitchingKeyLayout, GLWESecret, - prepared::{GGLWESwitchingKeyPrepared, GLWESecretPrepared, PrepareAlloc}, + GLWESecret, GLWESecretPreparedFactory, GLWESwitchingKey, GLWESwitchingKeyLayout, GLWESwitchingKeyPreparedFactory, + prepared::{GLWESecretPrepared, GLWESwitchingKeyPrepared}, }, noise::log2_std_noise_gglwe_product, }; -pub fn test_gglwe_switching_key_keyswitch(module: &Module) +pub fn test_gglwe_switching_key_keyswitch(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxNormalizeTmpBytes - + VecZnxAddScalarInplace - + VmpPMatAlloc - + VmpPrepare - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxBigNormalizeTmpBytes - + VecZnxSwitchRing - + VecZnxSubScalarInplace, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, + Module: GLWESwitchingKeyEncryptSk + + GGLWEKeyswitch + + GLWESwitchingKeyPreparedFactory + + GLWESecretPreparedFactory + + GGLWENoise, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { let base2k: usize = 12; let k_in: usize = 60; @@ -80,7 +40,7 @@ where let dnum_apply: usize = k_in.div_ceil(base2k * di); let dsize_in: usize = 1; - let gglwe_s0s1_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + let gglwe_s0s1_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), base2k: base2k.into(), k: k_in.into(), @@ -90,7 +50,7 @@ where rank_out: rank_out_s0s1.into(), }; - let gglwe_s1s2_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + let gglwe_s1s2_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), base2k: base2k.into(), k: k_ksk.into(), @@ -100,7 +60,7 @@ where rank_out: rank_out_s1s2.into(), }; - let gglwe_s0s2_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + let gglwe_s0s2_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), base2k: base2k.into(), k: k_out.into(), @@ -110,35 +70,38 @@ where rank_out: rank_out_s1s2.into(), }; - let mut gglwe_s0s1: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_s0s1_infos); - let mut gglwe_s1s2: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_s1s2_infos); - let mut gglwe_s0s2: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_s0s2_infos); + let mut gglwe_s0s1: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&gglwe_s0s1_infos); + let mut gglwe_s1s2: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&gglwe_s1s2_infos); + let mut gglwe_s0s2: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&gglwe_s0s2_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch_enc: ScratchOwned = ScratchOwned::alloc( - GGLWESwitchingKey::encrypt_sk_scratch_space(module, &gglwe_s0s1_infos) - | GGLWESwitchingKey::encrypt_sk_scratch_space(module, &gglwe_s1s2_infos) - | GGLWESwitchingKey::encrypt_sk_scratch_space(module, &gglwe_s0s2_infos), + let mut scratch_enc: ScratchOwned = ScratchOwned::alloc( + GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &gglwe_s0s1_infos) + | GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &gglwe_s1s2_infos) + | GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &gglwe_s0s2_infos), ); - let mut scratch_apply: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKey::keyswitch_scratch_space( + let mut scratch_apply: ScratchOwned = ScratchOwned::alloc(GLWESwitchingKey::keyswitch_tmp_bytes( module, &gglwe_s0s1_infos, &gglwe_s0s2_infos, &gglwe_s1s2_infos, )); - let mut sk0: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_in_s0s1.into()); + let mut sk0: GLWESecret> = GLWESecret::alloc(n.into(), rank_in_s0s1.into()); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_out_s0s1.into()); + let mut sk1: GLWESecret> = GLWESecret::alloc(n.into(), rank_out_s0s1.into()); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk2: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_out_s1s2.into()); + let mut sk2: GLWESecret> = GLWESecret::alloc(n.into(), rank_out_s1s2.into()); sk2.fill_ternary_prob(0.5, &mut source_xs); - let sk2_prepared: GLWESecretPrepared, B> = sk2.prepare_alloc(module, scratch_apply.borrow()); + + let mut sk2_prepared: GLWESecretPrepared, BE> = + GLWESecretPrepared::alloc(module, rank_out_s1s2.into()); + sk2_prepared.prepare(module, &sk2); // gglwe_{s1}(s0) = s0 -> s1 gglwe_s0s1.encrypt_sk( @@ -160,8 +123,9 @@ where scratch_enc.borrow(), ); - let gglwe_s1s2_prepared: GGLWESwitchingKeyPrepared, B> = - gglwe_s1s2.prepare_alloc(module, scratch_apply.borrow()); + let mut gglwe_s1s2_prepared: GLWESwitchingKeyPrepared, BE> = + GLWESwitchingKeyPrepared::alloc_from_infos(module, &gglwe_s1s2); + gglwe_s1s2_prepared.prepare(module, &gglwe_s1s2, scratch_apply.borrow()); // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) gglwe_s0s2.keyswitch( @@ -194,45 +158,15 @@ where } #[allow(clippy::too_many_arguments)] -pub fn test_gglwe_switching_key_keyswitch_inplace(module: &Module) +pub fn test_gglwe_switching_key_keyswitch_inplace(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxNormalizeTmpBytes - + VecZnxAddScalarInplace - + VmpPMatAlloc - + VmpPrepare - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxBigNormalizeTmpBytes - + VecZnxSwitchRing - + VecZnxSubScalarInplace, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, + Module: GLWESwitchingKeyEncryptSk + + GGLWEKeyswitch + + GLWESecretPreparedFactory + + GGLWENoise + + GLWESwitchingKeyPreparedFactory, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { let base2k: usize = 12; let k_out: usize = 60; @@ -246,7 +180,7 @@ where let dnum: usize = k_out.div_ceil(base2k * di); let dsize_in: usize = 1; - let gglwe_s0s1_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + let gglwe_s0s1_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), base2k: base2k.into(), k: k_out.into(), @@ -256,7 +190,7 @@ where rank_out: rank_out.into(), }; - let gglwe_s1s2_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + let gglwe_s1s2_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), base2k: base2k.into(), k: k_ksk.into(), @@ -266,34 +200,37 @@ where rank_out: rank_out.into(), }; - let mut gglwe_s0s1: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_s0s1_infos); - let mut gglwe_s1s2: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_s1s2_infos); + let mut gglwe_s0s1: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&gglwe_s0s1_infos); + let mut gglwe_s1s2: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&gglwe_s1s2_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch_enc: ScratchOwned = ScratchOwned::alloc( - GGLWESwitchingKey::encrypt_sk_scratch_space(module, &gglwe_s0s1_infos) - | GGLWESwitchingKey::encrypt_sk_scratch_space(module, &gglwe_s1s2_infos), + let mut scratch_enc: ScratchOwned = ScratchOwned::alloc( + GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &gglwe_s0s1_infos) + | GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &gglwe_s1s2_infos), ); - let mut scratch_apply: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKey::keyswitch_inplace_scratch_space( + let mut scratch_apply: ScratchOwned = ScratchOwned::alloc(GLWESwitchingKey::keyswitch_tmp_bytes( module, &gglwe_s0s1_infos, + &gglwe_s0s1_infos, &gglwe_s1s2_infos, )); let var_xs: f64 = 0.5; - let mut sk0: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_in.into()); + let mut sk0: GLWESecret> = GLWESecret::alloc(n.into(), rank_in.into()); sk0.fill_ternary_prob(var_xs, &mut source_xs); - let mut sk1: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_out.into()); + let mut sk1: GLWESecret> = GLWESecret::alloc(n.into(), rank_out.into()); sk1.fill_ternary_prob(var_xs, &mut source_xs); - let mut sk2: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_out.into()); + let mut sk2: GLWESecret> = GLWESecret::alloc(n.into(), rank_out.into()); sk2.fill_ternary_prob(var_xs, &mut source_xs); - let sk2_prepared: GLWESecretPrepared, B> = sk2.prepare_alloc(module, scratch_apply.borrow()); + + let mut sk2_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank_out.into()); + sk2_prepared.prepare(module, &sk2); // gglwe_{s1}(s0) = s0 -> s1 gglwe_s0s1.encrypt_sk( @@ -315,13 +252,14 @@ where scratch_enc.borrow(), ); - let gglwe_s1s2_prepared: GGLWESwitchingKeyPrepared, B> = - gglwe_s1s2.prepare_alloc(module, scratch_apply.borrow()); + let mut gglwe_s1s2_prepared: GLWESwitchingKeyPrepared, BE> = + GLWESwitchingKeyPrepared::alloc_from_infos(module, &gglwe_s1s2); + gglwe_s1s2_prepared.prepare(module, &gglwe_s1s2, scratch_apply.borrow()); // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) gglwe_s0s1.keyswitch_inplace(module, &gglwe_s1s2_prepared, scratch_apply.borrow()); - let gglwe_s0s2: GGLWESwitchingKey> = gglwe_s0s1; + let gglwe_s0s2: GLWESwitchingKey> = gglwe_s0s1; let max_noise: f64 = log2_std_noise_gglwe_product( n as f64, diff --git a/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs index c7e7c82..b582d89 100644 --- a/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs +++ b/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs @@ -1,75 +1,33 @@ use poulpy_hal::{ - api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, - SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, - VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAddInplace, VecZnxDftAlloc, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, - VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, - VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, - }, - layouts::{Backend, Module, ScalarZnx, ScratchOwned}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, - }, + api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow}, + layouts::{Backend, Module, ScalarZnx, Scratch, ScratchOwned}, source::Source, }; use crate::{ + GGSWEncryptSk, GGSWKeyswitch, GGSWNoise, GLWESwitchingKeyEncryptSk, GLWETensorKeyEncryptSk, ScratchTakeCore, encryption::SIGMA, layouts::{ - GGLWESwitchingKey, GGLWESwitchingKeyLayout, GGLWETensorKey, GGLWETensorKeyLayout, GGSWCiphertext, GGSWCiphertextLayout, - GLWESecret, - prepared::{GGLWESwitchingKeyPrepared, GGLWETensorKeyPrepared, GLWESecretPrepared, PrepareAlloc}, + GGSW, GGSWLayout, GLWESecret, GLWESecretPreparedFactory, GLWESwitchingKey, GLWESwitchingKeyLayout, + GLWESwitchingKeyPreparedFactory, GLWETensorKey, GLWETensorKeyLayout, GLWETensorKeyPreparedFactory, + prepared::{GLWESecretPrepared, GLWESwitchingKeyPrepared, GLWETensorKeyPrepared}, }, noise::noise_ggsw_keyswitch, }; #[allow(clippy::too_many_arguments)] -pub fn test_ggsw_keyswitch(module: &Module) +pub fn test_ggsw_keyswitch(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxNormalizeTmpBytes - + VecZnxAddScalarInplace - + VmpPMatAlloc - + VmpPrepare - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxBigNormalizeTmpBytes - + VecZnxSwitchRing - + SvpApplyDftToDft - + VecZnxIdftApplyTmpA - + VecZnxDftCopy - + VecZnxDftAddInplace - + VecZnxBigAlloc - + VecZnxDftAlloc, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, + Module: GGSWEncryptSk + + GLWESwitchingKeyEncryptSk + + GLWETensorKeyEncryptSk + + GGSWKeyswitch + + GLWESecretPreparedFactory + + GLWETensorKeyPreparedFactory + + GLWESwitchingKeyPreparedFactory + + GGSWNoise, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { let base2k: usize = 12; let k_in: usize = 54; @@ -85,7 +43,7 @@ where let dsize_in: usize = 1; - let ggsw_in_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + let ggsw_in_infos: GGSWLayout = GGSWLayout { n: n.into(), base2k: base2k.into(), k: k_in.into(), @@ -94,7 +52,7 @@ where rank: rank.into(), }; - let ggsw_out_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + let ggsw_out_infos: GGSWLayout = GGSWLayout { n: n.into(), base2k: base2k.into(), k: k_out.into(), @@ -103,7 +61,7 @@ where rank: rank.into(), }; - let tsk_infos: GGLWETensorKeyLayout = GGLWETensorKeyLayout { + let tsk_infos: GLWETensorKeyLayout = GLWETensorKeyLayout { n: n.into(), base2k: base2k.into(), k: k_tsk.into(), @@ -112,7 +70,7 @@ where rank: rank.into(), }; - let ksk_apply_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + let ksk_apply_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), base2k: base2k.into(), k: k_ksk.into(), @@ -122,21 +80,21 @@ where rank_out: rank.into(), }; - let mut ggsw_in: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_in_infos); - let mut ggsw_out: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_out_infos); - let mut tsk: GGLWETensorKey> = GGLWETensorKey::alloc(&tsk_infos); - let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&ksk_apply_infos); + let mut ggsw_in: GGSW> = GGSW::alloc_from_infos(&ggsw_in_infos); + let mut ggsw_out: GGSW> = GGSW::alloc_from_infos(&ggsw_out_infos); + let mut tsk: GLWETensorKey> = GLWETensorKey::alloc_from_infos(&tsk_infos); + let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&ksk_apply_infos); let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, &ggsw_in_infos) - | GGLWESwitchingKey::encrypt_sk_scratch_space(module, &ksk_apply_infos) - | GGLWETensorKey::encrypt_sk_scratch_space(module, &tsk_infos) - | GGSWCiphertext::keyswitch_scratch_space( + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GGSW::encrypt_sk_tmp_bytes(module, &ggsw_in_infos) + | GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &ksk_apply_infos) + | GLWETensorKey::encrypt_sk_tmp_bytes(module, &tsk_infos) + | GGSW::keyswitch_tmp_bytes( module, &ggsw_out_infos, &ggsw_in_infos, @@ -147,13 +105,17 @@ where let var_xs: f64 = 0.5; - let mut sk_in: GLWESecret> = GLWESecret::alloc_with(n.into(), rank.into()); + let mut sk_in: GLWESecret> = GLWESecret::alloc(n.into(), rank.into()); sk_in.fill_ternary_prob(var_xs, &mut source_xs); - let sk_in_dft: GLWESecretPrepared, B> = sk_in.prepare_alloc(module, scratch.borrow()); - let mut sk_out: GLWESecret> = GLWESecret::alloc_with(n.into(), rank.into()); + let mut sk_in_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank.into()); + sk_in_prepared.prepare(module, &sk_in); + + let mut sk_out: GLWESecret> = GLWESecret::alloc(n.into(), rank.into()); sk_out.fill_ternary_prob(var_xs, &mut source_xs); - let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); + + let mut sk_out_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank.into()); + sk_out_prepared.prepare(module, &sk_out); ksk.encrypt_sk( module, @@ -176,14 +138,18 @@ where ggsw_in.encrypt_sk( module, &pt_scalar, - &sk_in_dft, + &sk_in_prepared, &mut source_xa, &mut source_xe, scratch.borrow(), ); - let ksk_prepared: GGLWESwitchingKeyPrepared, B> = ksk.prepare_alloc(module, scratch.borrow()); - let tsk_prepared: GGLWETensorKeyPrepared, B> = tsk.prepare_alloc(module, scratch.borrow()); + let mut ksk_prepared: GLWESwitchingKeyPrepared, BE> = + GLWESwitchingKeyPrepared::alloc_from_infos(module, &ksk); + ksk_prepared.prepare(module, &ksk, scratch.borrow()); + + let mut tsk_prepared: GLWETensorKeyPrepared, BE> = GLWETensorKeyPrepared::alloc_from_infos(module, &tsk); + tsk_prepared.prepare(module, &tsk, scratch.borrow()); ggsw_out.keyswitch( module, @@ -215,50 +181,18 @@ where } #[allow(clippy::too_many_arguments)] -pub fn test_ggsw_keyswitch_inplace(module: &Module) +pub fn test_ggsw_keyswitch_inplace(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxNormalizeTmpBytes - + VecZnxAddScalarInplace - + VmpPMatAlloc - + VmpPrepare - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxBigNormalizeTmpBytes - + VecZnxSwitchRing - + SvpApplyDftToDft - + VecZnxIdftApplyTmpA - + VecZnxDftCopy - + VecZnxDftAddInplace - + VecZnxBigAlloc - + VecZnxDftAlloc, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, + Module: GGSWEncryptSk + + GLWESwitchingKeyEncryptSk + + GLWETensorKeyEncryptSk + + GGSWKeyswitch + + GLWESecretPreparedFactory + + GLWETensorKeyPreparedFactory + + GLWESwitchingKeyPreparedFactory + + GGSWNoise, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { let base2k: usize = 12; let k_out: usize = 54; @@ -273,7 +207,7 @@ where let dsize_in: usize = 1; - let ggsw_out_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + let ggsw_out_infos: GGSWLayout = GGSWLayout { n: n.into(), base2k: base2k.into(), k: k_out.into(), @@ -282,7 +216,7 @@ where rank: rank.into(), }; - let tsk_infos: GGLWETensorKeyLayout = GGLWETensorKeyLayout { + let tsk_infos: GLWETensorKeyLayout = GLWETensorKeyLayout { n: n.into(), base2k: base2k.into(), k: k_tsk.into(), @@ -291,7 +225,7 @@ where rank: rank.into(), }; - let ksk_apply_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + let ksk_apply_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), base2k: base2k.into(), k: k_ksk.into(), @@ -301,31 +235,41 @@ where rank_out: rank.into(), }; - let mut ggsw_out: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_out_infos); - let mut tsk: GGLWETensorKey> = GGLWETensorKey::alloc(&tsk_infos); - let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&ksk_apply_infos); + let mut ggsw_out: GGSW> = GGSW::alloc_from_infos(&ggsw_out_infos); + let mut tsk: GLWETensorKey> = GLWETensorKey::alloc_from_infos(&tsk_infos); + let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&ksk_apply_infos); let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, &ggsw_out_infos) - | GGLWESwitchingKey::encrypt_sk_scratch_space(module, &ksk_apply_infos) - | GGLWETensorKey::encrypt_sk_scratch_space(module, &tsk_infos) - | GGSWCiphertext::keyswitch_inplace_scratch_space(module, &ggsw_out_infos, &ksk_apply_infos, &tsk_infos), + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GGSW::encrypt_sk_tmp_bytes(module, &ggsw_out_infos) + | GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &ksk_apply_infos) + | GLWETensorKey::encrypt_sk_tmp_bytes(module, &tsk_infos) + | GGSW::keyswitch_tmp_bytes( + module, + &ggsw_out_infos, + &ggsw_out_infos, + &ksk_apply_infos, + &tsk_infos, + ), ); let var_xs: f64 = 0.5; - let mut sk_in: GLWESecret> = GLWESecret::alloc_with(n.into(), rank.into()); + let mut sk_in: GLWESecret> = GLWESecret::alloc(n.into(), rank.into()); sk_in.fill_ternary_prob(var_xs, &mut source_xs); - let sk_in_dft: GLWESecretPrepared, B> = sk_in.prepare_alloc(module, scratch.borrow()); - let mut sk_out: GLWESecret> = GLWESecret::alloc_with(n.into(), rank.into()); + let mut sk_in_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank.into()); + sk_in_prepared.prepare(module, &sk_in); + + let mut sk_out: GLWESecret> = GLWESecret::alloc(n.into(), rank.into()); sk_out.fill_ternary_prob(var_xs, &mut source_xs); - let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); + + let mut sk_out_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank.into()); + sk_out_prepared.prepare(module, &sk_out); ksk.encrypt_sk( module, @@ -348,14 +292,18 @@ where ggsw_out.encrypt_sk( module, &pt_scalar, - &sk_in_dft, + &sk_in_prepared, &mut source_xa, &mut source_xe, scratch.borrow(), ); - let ksk_prepared: GGLWESwitchingKeyPrepared, B> = ksk.prepare_alloc(module, scratch.borrow()); - let tsk_prepared: GGLWETensorKeyPrepared, B> = tsk.prepare_alloc(module, scratch.borrow()); + let mut ksk_prepared: GLWESwitchingKeyPrepared, BE> = + GLWESwitchingKeyPrepared::alloc_from_infos(module, &ksk); + ksk_prepared.prepare(module, &ksk, scratch.borrow()); + + let mut tsk_prepared: GLWETensorKeyPrepared, BE> = GLWETensorKeyPrepared::alloc_from_infos(module, &tsk); + tsk_prepared.prepare(module, &tsk, scratch.borrow()); ggsw_out.keyswitch_inplace(module, &ksk_prepared, &tsk_prepared, scratch.borrow()); diff --git a/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs index d745f1d..90cc543 100644 --- a/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs @@ -1,68 +1,32 @@ use poulpy_hal::{ - api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VecZnxDftApply, - VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, - VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, - VmpPrepare, - }, - layouts::{Backend, Module, ScratchOwned}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, - }, + api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxFillUniform}, + layouts::{Backend, Module, Scratch, ScratchOwned}, source::Source, }; use crate::{ + GLWEEncryptSk, GLWEKeyswitch, GLWENoise, GLWESwitchingKeyEncryptSk, ScratchTakeCore, encryption::SIGMA, layouts::{ - GGLWESwitchingKey, GGLWESwitchingKeyLayout, GLWECiphertext, GLWECiphertextLayout, GLWEPlaintext, GLWESecret, - prepared::{GGLWESwitchingKeyPrepared, GLWESecretPrepared, PrepareAlloc}, + GLWE, GLWELayout, GLWEPlaintext, GLWESecret, GLWESecretPreparedFactory, GLWESwitchingKey, GLWESwitchingKeyLayout, + GLWESwitchingKeyPreparedFactory, + prepared::{GLWESecretPrepared, GLWESwitchingKeyPrepared}, }, noise::log2_std_noise_gglwe_product, }; #[allow(clippy::too_many_arguments)] -pub fn test_glwe_keyswitch(module: &Module) +pub fn test_glwe_keyswitch(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxNormalizeTmpBytes - + VecZnxAddScalarInplace - + VmpPMatAlloc - + VmpPrepare - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxBigNormalizeTmpBytes - + VecZnxSwitchRing, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, + Module: VecZnxFillUniform + + GLWESwitchingKeyEncryptSk + + GLWEEncryptSk + + GLWEKeyswitch + + GLWESecretPreparedFactory + + GLWESwitchingKeyPreparedFactory + + GLWENoise, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { let base2k: usize = 12; let k_in: usize = 45; @@ -77,21 +41,21 @@ where let n: usize = module.n(); let dnum: usize = k_in.div_ceil(base2k * dsize); - let glwe_in_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_in_infos: GLWELayout = GLWELayout { n: n.into(), base2k: base2k.into(), k: k_in.into(), rank: rank_in.into(), }; - let glwe_out_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_out_infos: GLWELayout = GLWELayout { n: n.into(), base2k: base2k.into(), k: k_out.into(), rank: rank_out.into(), }; - let key_apply: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + let ksk: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), base2k: base2k.into(), k: k_ksk.into(), @@ -101,10 +65,10 @@ where rank_out: rank_out.into(), }; - let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&key_apply); - let mut glwe_in: GLWECiphertext> = GLWECiphertext::alloc(&glwe_in_infos); - let mut glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&glwe_out_infos); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_in_infos); + let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&ksk); + let mut glwe_in: GLWE> = GLWE::alloc_from_infos(&glwe_in_infos); + let mut glwe_out: GLWE> = GLWE::alloc_from_infos(&glwe_out_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_in_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -112,19 +76,23 @@ where module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWESwitchingKey::encrypt_sk_scratch_space(module, &key_apply) - | GLWECiphertext::encrypt_sk_scratch_space(module, &glwe_in_infos) - | GLWECiphertext::keyswitch_scratch_space(module, &glwe_out_infos, &glwe_in_infos, &key_apply), + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &ksk) + | GLWE::encrypt_sk_tmp_bytes(module, &glwe_in_infos) + | GLWE::keyswitch_tmp_bytes(module, &glwe_out_infos, &glwe_in_infos, &ksk), ); - let mut sk_in: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_in.into()); + let mut sk_in: GLWESecret> = GLWESecret::alloc(n.into(), rank_in.into()); sk_in.fill_ternary_prob(0.5, &mut source_xs); - let sk_in_prepared: GLWESecretPrepared, B> = sk_in.prepare_alloc(module, scratch.borrow()); - let mut sk_out: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_out.into()); + let mut sk_in_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank_in.into()); + sk_in_prepared.prepare(module, &sk_in); + + let mut sk_out: GLWESecret> = GLWESecret::alloc(n.into(), rank_out.into()); sk_out.fill_ternary_prob(0.5, &mut source_xs); - let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); + + let mut sk_out_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank_out.into()); + sk_out_prepared.prepare(module, &sk_out); ksk.encrypt_sk( module, @@ -144,7 +112,9 @@ where scratch.borrow(), ); - let ksk_prepared: GGLWESwitchingKeyPrepared, B> = ksk.prepare_alloc(module, scratch.borrow()); + let mut ksk_prepared: GLWESwitchingKeyPrepared, BE> = + GLWESwitchingKeyPrepared::alloc_from_infos(module, &ksk); + ksk_prepared.prepare(module, &ksk, scratch.borrow()); glwe_out.keyswitch(module, &glwe_in, &ksk_prepared, scratch.borrow()); @@ -167,44 +137,17 @@ where } } -pub fn test_glwe_keyswitch_inplace(module: &Module) +pub fn test_glwe_keyswitch_inplace(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxNormalizeTmpBytes - + VecZnxAddScalarInplace - + VmpPMatAlloc - + VmpPrepare - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxBigNormalizeTmpBytes - + VecZnxSwitchRing, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, + Module: VecZnxFillUniform + + GLWESwitchingKeyEncryptSk + + GLWEEncryptSk + + GLWEKeyswitch + + GLWESecretPreparedFactory + + GLWESwitchingKeyPreparedFactory + + GLWENoise, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { let base2k: usize = 12; let k_out: usize = 45; @@ -217,14 +160,14 @@ where let n: usize = module.n(); let dnum: usize = k_out.div_ceil(base2k * dsize); - let glwe_out_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_out_infos: GLWELayout = GLWELayout { n: n.into(), base2k: base2k.into(), k: k_out.into(), rank: rank.into(), }; - let key_apply_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + let ksk_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), base2k: base2k.into(), k: k_ksk.into(), @@ -234,9 +177,9 @@ where rank_out: rank.into(), }; - let mut key_apply: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&key_apply_infos); - let mut glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&glwe_out_infos); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_out_infos); + let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&ksk_infos); + let mut glwe_out: GLWE> = GLWE::alloc_from_infos(&glwe_out_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -244,21 +187,25 @@ where module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWESwitchingKey::encrypt_sk_scratch_space(module, &key_apply_infos) - | GLWECiphertext::encrypt_sk_scratch_space(module, &glwe_out_infos) - | GLWECiphertext::keyswitch_inplace_scratch_space(module, &glwe_out_infos, &key_apply_infos), + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &ksk_infos) + | GLWE::encrypt_sk_tmp_bytes(module, &glwe_out_infos) + | GLWE::keyswitch_tmp_bytes(module, &glwe_out_infos, &glwe_out_infos, &ksk_infos), ); - let mut sk_in: GLWESecret> = GLWESecret::alloc_with(n.into(), rank.into()); + let mut sk_in: GLWESecret> = GLWESecret::alloc(n.into(), rank.into()); sk_in.fill_ternary_prob(0.5, &mut source_xs); - let sk_in_prepared: GLWESecretPrepared, B> = sk_in.prepare_alloc(module, scratch.borrow()); - let mut sk_out: GLWESecret> = GLWESecret::alloc_with(n.into(), rank.into()); + let mut sk_in_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank.into()); + sk_in_prepared.prepare(module, &sk_in); + + let mut sk_out: GLWESecret> = GLWESecret::alloc(n.into(), rank.into()); sk_out.fill_ternary_prob(0.5, &mut source_xs); - let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); - key_apply.encrypt_sk( + let mut sk_out_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank.into()); + sk_out_prepared.prepare(module, &sk_out); + + ksk.encrypt_sk( module, &sk_in, &sk_out, @@ -276,7 +223,9 @@ where scratch.borrow(), ); - let ksk_prepared: GGLWESwitchingKeyPrepared, B> = key_apply.prepare_alloc(module, scratch.borrow()); + let mut ksk_prepared: GLWESwitchingKeyPrepared, BE> = + GLWESwitchingKeyPrepared::alloc_from_infos(module, &ksk); + ksk_prepared.prepare(module, &ksk, scratch.borrow()); glwe_out.keyswitch_inplace(module, &ksk_prepared, scratch.borrow()); diff --git a/poulpy-core/src/tests/test_suite/keyswitch/lwe_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/lwe_ct.rs index 1c29c70..7617b09 100644 --- a/poulpy-core/src/tests/test_suite/keyswitch/lwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/keyswitch/lwe_ct.rs @@ -1,68 +1,23 @@ use poulpy_hal::{ - api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigAddInplace, - VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, ZnAddNormal, ZnFillUniform, ZnNormalizeInplace, - }, - layouts::{Backend, Module, ScratchOwned, ZnxView}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, - }, + api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow}, + layouts::{Backend, Module, Scratch, ScratchOwned, ZnxView}, source::Source, }; -use crate::layouts::{ - LWECiphertext, LWECiphertextLayout, LWEPlaintext, LWESecret, LWESwitchingKey, LWESwitchingKeyLayout, - prepared::{LWESwitchingKeyPrepared, PrepareAlloc}, +use crate::{ + LWEDecrypt, LWEEncryptSk, LWEKeySwitch, LWESwitchingKeyEncrypt, ScratchTakeCore, + layouts::{ + LWE, LWELayout, LWEPlaintext, LWESecret, LWESwitchingKey, LWESwitchingKeyLayout, LWESwitchingKeyPreparedFactory, + prepared::LWESwitchingKeyPrepared, + }, }; -pub fn test_lwe_keyswitch(module: &Module) +pub fn test_lwe_keyswitch(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxNormalizeTmpBytes - + VecZnxAddScalarInplace - + VmpPMatAlloc - + VmpPrepare - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxBigNormalizeTmpBytes - + VecZnxSwitchRing - + VecZnxAutomorphismInplace - + ZnNormalizeInplace - + ZnFillUniform - + ZnAddNormal - + VecZnxCopy, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, + Module: + LWEKeySwitch + LWESwitchingKeyEncrypt + LWEEncryptSk + LWESwitchingKeyPreparedFactory + LWEDecrypt, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { let n: usize = module.n(); let base2k: usize = 17; @@ -86,21 +41,21 @@ where dnum: dnum.into(), }; - let lwe_in_infos: LWECiphertextLayout = LWECiphertextLayout { + let lwe_in_infos: LWELayout = LWELayout { n: n_lwe_in.into(), base2k: base2k.into(), k: k_lwe_ct.into(), }; - let lwe_out_infos: LWECiphertextLayout = LWECiphertextLayout { + let lwe_out_infos: LWELayout = LWELayout { n: n_lwe_out.into(), k: k_lwe_ct.into(), base2k: base2k.into(), }; - let mut scratch: ScratchOwned = ScratchOwned::alloc( - LWESwitchingKey::encrypt_sk_scratch_space(module, &key_apply_infos) - | LWECiphertext::keyswitch_scratch_space(module, &lwe_out_infos, &lwe_in_infos, &key_apply_infos), + let mut scratch: ScratchOwned = ScratchOwned::alloc( + LWESwitchingKey::encrypt_sk_tmp_bytes(module, &key_apply_infos) + | LWE::keyswitch_tmp_bytes(module, &lwe_out_infos, &lwe_in_infos, &key_apply_infos), ); let mut sk_lwe_in: LWESecret> = LWESecret::alloc(n_lwe_in.into()); @@ -111,10 +66,10 @@ where let data: i64 = 17; - let mut lwe_pt_in: LWEPlaintext> = LWEPlaintext::alloc_with(base2k.into(), k_lwe_pt.into()); + let mut lwe_pt_in: LWEPlaintext> = LWEPlaintext::alloc(base2k.into(), k_lwe_pt.into()); lwe_pt_in.encode_i64(data, k_lwe_pt.into()); - let mut lwe_ct_in: LWECiphertext> = LWECiphertext::alloc(&lwe_in_infos); + let mut lwe_ct_in: LWE> = LWE::alloc_from_infos(&lwe_in_infos); lwe_ct_in.encrypt_sk( module, &lwe_pt_in, @@ -123,7 +78,7 @@ where &mut source_xe, ); - let mut ksk: LWESwitchingKey> = LWESwitchingKey::alloc(&key_apply_infos); + let mut ksk: LWESwitchingKey> = LWESwitchingKey::alloc_from_infos(&key_apply_infos); ksk.encrypt_sk( module, @@ -134,13 +89,14 @@ where scratch.borrow(), ); - let mut lwe_ct_out: LWECiphertext> = LWECiphertext::alloc(&lwe_out_infos); + let mut lwe_ct_out: LWE> = LWE::alloc_from_infos(&lwe_out_infos); - let ksk_prepared: LWESwitchingKeyPrepared, B> = ksk.prepare_alloc(module, scratch.borrow()); + let mut ksk_prepared: LWESwitchingKeyPrepared, BE> = LWESwitchingKeyPrepared::alloc_from_infos(module, &ksk); + ksk_prepared.prepare(module, &ksk, scratch.borrow()); lwe_ct_out.keyswitch(module, &lwe_ct_in, &ksk_prepared, scratch.borrow()); - let mut lwe_pt_out: LWEPlaintext> = LWEPlaintext::alloc(&lwe_out_infos); + let mut lwe_pt_out: LWEPlaintext> = LWEPlaintext::alloc_from_infos(&lwe_out_infos); lwe_ct_out.decrypt(module, &mut lwe_pt_out, &sk_lwe_out); assert_eq!(lwe_pt_in.data.at(0, 0)[0], lwe_pt_out.data.at(0, 0)[0]); diff --git a/poulpy-core/src/tests/test_suite/packing.rs b/poulpy-core/src/tests/test_suite/packing.rs index de7fc9d..029e059 100644 --- a/poulpy-core/src/tests/test_suite/packing.rs +++ b/poulpy-core/src/tests/test_suite/packing.rs @@ -1,78 +1,32 @@ use std::collections::HashMap; use poulpy_hal::{ - api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, - VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigAutomorphismInplace, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, - VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace, - VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, - }, - layouts::{Backend, Module, ScratchOwned}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, - }, + api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow}, + layouts::{Backend, Module, Scratch, ScratchOwned}, source::Source, }; use crate::{ - GLWEOperations, GLWEPacker, + GLWEAutomorphismKeyEncryptSk, GLWEDecrypt, GLWEEncryptSk, GLWEPacker, GLWEPacking, GLWERotate, GLWESub, ScratchTakeCore, layouts::{ - GGLWEAutomorphismKey, GGLWEAutomorphismKeyLayout, GLWECiphertext, GLWECiphertextLayout, GLWEPlaintext, GLWESecret, - prepared::{GGLWEAutomorphismKeyPrepared, GLWESecretPrepared, PrepareAlloc}, + GLWE, GLWEAutomorphismKey, GLWEAutomorphismKeyLayout, GLWEAutomorphismKeyPreparedFactory, GLWELayout, GLWEPlaintext, + GLWESecret, GLWESecretPreparedFactory, + prepared::{GLWEAutomorphismKeyPrepared, GLWESecretPrepared}, }, }; -pub fn test_glwe_packing(module: &Module) +pub fn test_glwe_packing(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxAutomorphism - + VecZnxBigAutomorphismInplace - + VecZnxBigSubSmallNegateInplace - + VecZnxNegateInplace - + VecZnxRshInplace - + VecZnxRotateInplace - + VecZnxBigNormalize - + VecZnxDftApply - + VecZnxRotate - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxNormalizeTmpBytes - + VecZnxAddScalarInplace - + VmpPMatAlloc - + VmpPrepare - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxBigNormalizeTmpBytes - + VecZnxSwitchRing - + VecZnxAutomorphismInplace - + VecZnxCopy, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, + Module: GLWEEncryptSk + + GLWEAutomorphismKeyEncryptSk + + GLWEAutomorphismKeyPreparedFactory + + GLWEPacking + + GLWESecretPreparedFactory + + GLWESub + + GLWEDecrypt + + GLWERotate, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -88,14 +42,14 @@ where let dnum: usize = k_ct.div_ceil(base2k * dsize); - let glwe_out_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_out_infos: GLWELayout = GLWELayout { n: n.into(), base2k: base2k.into(), k: k_ct.into(), rank: rank.into(), }; - let key_infos: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + let key_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), base2k: base2k.into(), k: k_ksk.into(), @@ -104,17 +58,19 @@ where dnum: dnum.into(), }; - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::encrypt_sk_scratch_space(module, &glwe_out_infos) - | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, &key_infos) - | GLWEPacker::scratch_space(module, &glwe_out_infos, &key_infos), + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GLWE::encrypt_sk_tmp_bytes(module, &glwe_out_infos) + | GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &key_infos) + | GLWEPacker::tmp_bytes(module, &glwe_out_infos, &key_infos), ); - let mut sk: GLWESecret> = GLWESecret::alloc(&glwe_out_infos); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_out_infos); sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_dft: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_out_infos); + let mut sk_dft: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc_from_infos(module, &sk); + sk_dft.prepare(module, &sk); + + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos); let mut data: Vec = vec![0i64; n]; data.iter_mut().enumerate().for_each(|(i, x)| { *x = i as i64; @@ -124,8 +80,8 @@ where let gal_els: Vec = GLWEPacker::galois_elements(module); - let mut auto_keys: HashMap, B>> = HashMap::new(); - let mut tmp: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&key_infos); + let mut auto_keys: HashMap, BE>> = HashMap::new(); + let mut tmp: GLWEAutomorphismKey> = GLWEAutomorphismKey::alloc_from_infos(&key_infos); gal_els.iter().for_each(|gal_el| { tmp.encrypt_sk( module, @@ -135,15 +91,17 @@ where &mut source_xe, scratch.borrow(), ); - let atk_prepared: GGLWEAutomorphismKeyPrepared, B> = tmp.prepare_alloc(module, scratch.borrow()); + let mut atk_prepared: GLWEAutomorphismKeyPrepared, BE> = + GLWEAutomorphismKeyPrepared::alloc_from_infos(module, &tmp); + atk_prepared.prepare(module, &tmp, scratch.borrow()); auto_keys.insert(*gal_el, atk_prepared); }); let log_batch: usize = 0; - let mut packer: GLWEPacker = GLWEPacker::new(&glwe_out_infos, log_batch); + let mut packer: GLWEPacker = GLWEPacker::alloc(&glwe_out_infos, log_batch); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&glwe_out_infos); + let mut ct: GLWE> = GLWE::alloc_from_infos(&glwe_out_infos); ct.encrypt_sk( module, @@ -166,24 +124,19 @@ where scratch.borrow(), ); - pt.rotate_inplace(module, -(1 << log_batch), scratch.borrow()); // X^-batch * pt + module.glwe_rotate_inplace(-(1 << log_batch), &mut pt, scratch.borrow()); // X^-batch * pt if reverse_bits_msb(i, log_n as u32).is_multiple_of(5) { packer.add(module, Some(&ct), &auto_keys, scratch.borrow()); } else { - packer.add( - module, - None::<&GLWECiphertext>>, - &auto_keys, - scratch.borrow(), - ) + packer.add(module, None::<&GLWE>>, &auto_keys, scratch.borrow()) } }); - let mut res: GLWECiphertext> = GLWECiphertext::alloc(&glwe_out_infos); + let mut res: GLWE> = GLWE::alloc_from_infos(&glwe_out_infos); packer.flush(module, &mut res); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_out_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos); let mut data: Vec = vec![0i64; n]; data.iter_mut().enumerate().for_each(|(i, x)| { if i.is_multiple_of(5) { @@ -195,7 +148,7 @@ where res.decrypt(module, &mut pt, &sk_dft, scratch.borrow()); - pt.sub_inplace_ab(module, &pt_want); + module.glwe_sub_inplace(&mut pt, &pt_want); let noise_have: f64 = pt.std().log2(); diff --git a/poulpy-core/src/tests/test_suite/trace.rs b/poulpy-core/src/tests/test_suite/trace.rs index bf348ca..ed2ed79 100644 --- a/poulpy-core/src/tests/test_suite/trace.rs +++ b/poulpy-core/src/tests/test_suite/trace.rs @@ -1,77 +1,36 @@ use std::collections::HashMap; use poulpy_hal::{ - api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxBigAddInplace, - VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigAutomorphismInplace, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, - VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, - VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, - VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, - }, - layouts::{Backend, Module, ScratchOwned, ZnxView, ZnxViewMut}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, - }, + api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxFillUniform, VecZnxNormalizeInplace, VecZnxSubInplace}, + layouts::{Backend, Module, Scratch, ScratchOwned, ZnxView, ZnxViewMut}, source::Source, }; use crate::{ + GLWEAutomorphismKeyEncryptSk, GLWEDecrypt, GLWEEncryptSk, ScratchTakeCore, encryption::SIGMA, + glwe_trace::GLWETrace, layouts::{ - GGLWEAutomorphismKey, GGLWEAutomorphismKeyLayout, GLWECiphertext, GLWECiphertextLayout, GLWEPlaintext, GLWESecret, - LWEInfos, - prepared::{GGLWEAutomorphismKeyPrepared, GLWESecretPrepared, PrepareAlloc}, + GLWE, GLWEAutomorphismKey, GLWEAutomorphismKeyLayout, GLWEAutomorphismKeyPreparedFactory, GLWELayout, GLWEPlaintext, + GLWESecret, GLWESecretPreparedFactory, LWEInfos, + prepared::{GLWEAutomorphismKeyPrepared, GLWESecretPrepared}, }, noise::var_noise_gglwe_product, }; -pub fn test_glwe_trace_inplace(module: &Module) +pub fn test_glwe_trace_inplace(module: &Module) where - Module: VecZnxDftAllocBytes - + VecZnxAutomorphism - + VecZnxBigAutomorphismInplace - + VecZnxBigSubSmallNegateInplace - + VecZnxRshInplace - + VecZnxRotateInplace - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume + Module: GLWETrace + + GLWEEncryptSk + + GLWEDecrypt + + GLWEAutomorphismKeyEncryptSk + + GLWEAutomorphismKeyPreparedFactory + VecZnxFillUniform + + GLWESecretPreparedFactory + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxNormalizeTmpBytes - + VecZnxAddScalarInplace - + VmpPMatAlloc - + VmpPrepare - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxBigNormalizeTmpBytes - + VecZnxSwitchRing - + VecZnxCopy, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, + + VecZnxNormalizeInplace, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, { let base2k: usize = 8; let k: usize = 54; @@ -83,14 +42,14 @@ where let dsize: usize = 1; let dnum: usize = k.div_ceil(base2k * dsize); - let glwe_out_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_out_infos: GLWELayout = GLWELayout { n: n.into(), base2k: base2k.into(), k: k.into(), rank: rank.into(), }; - let key_infos: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + let key_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), base2k: base2k.into(), k: k_autokey.into(), @@ -99,24 +58,26 @@ where dnum: dnum.into(), }; - let mut glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&glwe_out_infos); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_out_infos); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_out_infos); + let mut glwe_out: GLWE> = GLWE::alloc_from_infos(&glwe_out_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::encrypt_sk_scratch_space(module, &glwe_out_infos) - | GLWECiphertext::decrypt_scratch_space(module, &glwe_out_infos) - | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, &key_infos) - | GLWECiphertext::trace_inplace_scratch_space(module, &glwe_out_infos, &key_infos), + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GLWE::encrypt_sk_tmp_bytes(module, &glwe_out_infos) + | GLWE::decrypt_tmp_bytes(module, &glwe_out_infos) + | GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &key_infos) + | GLWE::trace_tmp_bytes(module, &glwe_out_infos, &glwe_out_infos, &key_infos), ); - let mut sk: GLWESecret> = GLWESecret::alloc(&glwe_out_infos); + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_out_infos); sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_dft: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); + + let mut sk_dft: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc_from_infos(module, &sk); + sk_dft.prepare(module, &sk); let mut data_want: Vec = vec![0i64; n]; @@ -135,9 +96,9 @@ where scratch.borrow(), ); - let mut auto_keys: HashMap, B>> = HashMap::new(); - let gal_els: Vec = GLWECiphertext::trace_galois_elements(module); - let mut tmp: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&key_infos); + let mut auto_keys: HashMap, BE>> = HashMap::new(); + let gal_els: Vec = GLWE::trace_galois_elements(module); + let mut tmp: GLWEAutomorphismKey> = GLWEAutomorphismKey::alloc_from_infos(&key_infos); gal_els.iter().for_each(|gal_el| { tmp.encrypt_sk( module, @@ -147,7 +108,9 @@ where &mut source_xe, scratch.borrow(), ); - let atk_prepared: GGLWEAutomorphismKeyPrepared, B> = tmp.prepare_alloc(module, scratch.borrow()); + let mut atk_prepared: GLWEAutomorphismKeyPrepared, BE> = + GLWEAutomorphismKeyPrepared::alloc_from_infos(module, &tmp); + atk_prepared.prepare(module, &tmp, scratch.borrow()); auto_keys.insert(*gal_el, atk_prepared); }); diff --git a/poulpy-hal/src/api/module.rs b/poulpy-hal/src/api/module.rs index 6e5faed..a18af44 100644 --- a/poulpy-hal/src/api/module.rs +++ b/poulpy-hal/src/api/module.rs @@ -4,3 +4,16 @@ use crate::layouts::Backend; pub trait ModuleNew { fn new(n: u64) -> Self; } + +pub trait ModuleN { + fn n(&self) -> usize; +} + +pub trait ModuleLogN +where + Self: ModuleN, +{ + fn log_n(&self) -> usize { + (u64::BITS - (self.n() as u64 - 1).leading_zeros()) as usize + } +} diff --git a/poulpy-hal/src/api/scratch.rs b/poulpy-hal/src/api/scratch.rs index 38901bf..9e3c484 100644 --- a/poulpy-hal/src/api/scratch.rs +++ b/poulpy-hal/src/api/scratch.rs @@ -1,4 +1,7 @@ -use crate::layouts::{Backend, MatZnx, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}; +use crate::{ + api::{ModuleN, SvpPPolBytesOf, VecZnxBigBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf}, + layouts::{Backend, MatZnx, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}, +}; /// Allocates a new [crate::layouts::ScratchOwned] of `size` aligned bytes. pub trait ScratchOwnedAlloc { @@ -25,70 +28,102 @@ pub trait TakeSlice { fn take_slice(&mut self, len: usize) -> (&mut [T], &mut Self); } -/// Take a slice of bytes from a [Scratch], wraps it into a [ScalarZnx] and returns it -/// as well as a new [Scratch] minus the taken array of bytes. -pub trait TakeScalarZnx { - fn take_scalar_znx(&mut self, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self); -} +impl ScratchTakeBasic for Scratch where Self: TakeSlice {} -/// Take a slice of bytes from a [Scratch], wraps it into a [SvpPPol] and returns it -/// as well as a new [Scratch] minus the taken array of bytes. -pub trait TakeSvpPPol { - fn take_svp_ppol(&mut self, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self); -} +pub trait ScratchTakeBasic +where + Self: TakeSlice, +{ + fn take_scalar_znx(&mut self, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) { + let (take_slice, rem_slice) = self.take_slice(ScalarZnx::bytes_of(n, cols)); + (ScalarZnx::from_data(take_slice, n, cols), rem_slice) + } -/// Take a slice of bytes from a [Scratch], wraps it into a [VecZnx] and returns it -/// as well as a new [Scratch] minus the taken array of bytes. -pub trait TakeVecZnx { - fn take_vec_znx(&mut self, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self); -} + fn take_svp_ppol(&mut self, module: &M, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self) + where + M: SvpPPolBytesOf + ModuleN, + { + let (take_slice, rem_slice) = self.take_slice(module.bytes_of_svp_ppol(cols)); + (SvpPPol::from_data(take_slice, module.n(), cols), rem_slice) + } -/// Take a slice of bytes from a [Scratch], slices it into a vector of [VecZnx] aand returns it -/// as well as a new [Scratch] minus the taken array of bytes. -pub trait TakeVecZnxSlice { - fn take_vec_znx_slice(&mut self, len: usize, n: usize, cols: usize, size: usize) -> (Vec>, &mut Self); -} + fn take_vec_znx(&mut self, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) { + let (take_slice, rem_slice) = self.take_slice(VecZnx::bytes_of(n, cols, size)); + (VecZnx::from_data(take_slice, n, cols, size), rem_slice) + } -/// Take a slice of bytes from a [Scratch], wraps it into a [VecZnxBig] and returns it -/// as well as a new [Scratch] minus the taken array of bytes. -pub trait TakeVecZnxBig { - fn take_vec_znx_big(&mut self, n: usize, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self); -} + fn take_vec_znx_big(&mut self, module: &M, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self) + where + M: VecZnxBigBytesOf + ModuleN, + { + let (take_slice, rem_slice) = self.take_slice(module.bytes_of_vec_znx_big(cols, size)); + ( + VecZnxBig::from_data(take_slice, module.n(), cols, size), + rem_slice, + ) + } -/// Take a slice of bytes from a [Scratch], wraps it into a [VecZnxDft] and returns it -/// as well as a new [Scratch] minus the taken array of bytes. -pub trait TakeVecZnxDft { - fn take_vec_znx_dft(&mut self, n: usize, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self); -} + fn take_vec_znx_dft(&mut self, module: &M, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self) + where + M: VecZnxDftBytesOf + ModuleN, + { + let (take_slice, rem_slice) = self.take_slice(module.bytes_of_vec_znx_dft(cols, size)); -/// Take a slice of bytes from a [Scratch], slices it into a vector of [VecZnxDft] and returns it -/// as well as a new [Scratch] minus the taken array of bytes. -pub trait TakeVecZnxDftSlice { - fn take_vec_znx_dft_slice( + ( + VecZnxDft::from_data(take_slice, module.n(), cols, size), + rem_slice, + ) + } + + fn take_vec_znx_dft_slice( &mut self, + module: &M, len: usize, - n: usize, cols: usize, size: usize, - ) -> (Vec>, &mut Self); -} + ) -> (Vec>, &mut Self) + where + M: VecZnxDftBytesOf + ModuleN, + { + let mut scratch: &mut Self = self; + let mut slice: Vec> = Vec::with_capacity(len); + for _ in 0..len { + let (znx, new_scratch) = scratch.take_vec_znx_dft(module, cols, size); + scratch = new_scratch; + slice.push(znx); + } + (slice, scratch) + } -/// Take a slice of bytes from a [Scratch], wraps it into a [VmpPMat] and returns it -/// as well as a new [Scratch] minus the taken array of bytes. -pub trait TakeVmpPMat { - fn take_vmp_pmat( + fn take_vec_znx_slice(&mut self, len: usize, n: usize, cols: usize, size: usize) -> (Vec>, &mut Self) { + let mut scratch: &mut Self = self; + let mut slice: Vec> = Vec::with_capacity(len); + for _ in 0..len { + let (znx, new_scratch) = scratch.take_vec_znx(n, cols, size); + scratch = new_scratch; + slice.push(znx); + } + (slice, scratch) + } + + fn take_vmp_pmat( &mut self, - n: usize, + module: &M, rows: usize, cols_in: usize, cols_out: usize, size: usize, - ) -> (VmpPMat<&mut [u8], B>, &mut Self); -} + ) -> (VmpPMat<&mut [u8], B>, &mut Self) + where + M: VmpPMatBytesOf + ModuleN, + { + let (take_slice, rem_slice) = self.take_slice(module.bytes_of_vmp_pmat(rows, cols_in, cols_out, size)); + ( + VmpPMat::from_data(take_slice, module.n(), rows, cols_in, cols_out, size), + rem_slice, + ) + } -/// Take a slice of bytes from a [Scratch], wraps it into a [MatZnx] and returns it -/// as well as a new [Scratch] minus the taken array of bytes. -pub trait TakeMatZnx { fn take_mat_znx( &mut self, n: usize, @@ -96,5 +131,11 @@ pub trait TakeMatZnx { cols_in: usize, cols_out: usize, size: usize, - ) -> (MatZnx<&mut [u8]>, &mut Self); + ) -> (MatZnx<&mut [u8]>, &mut Self) { + let (take_slice, rem_slice) = self.take_slice(MatZnx::bytes_of(n, rows, cols_in, cols_out, size)); + ( + MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size), + rem_slice, + ) + } } diff --git a/poulpy-hal/src/api/svp_ppol.rs b/poulpy-hal/src/api/svp_ppol.rs index 5a72367..6678584 100644 --- a/poulpy-hal/src/api/svp_ppol.rs +++ b/poulpy-hal/src/api/svp_ppol.rs @@ -8,8 +8,8 @@ pub trait SvpPPolAlloc { } /// Returns the size in bytes to allocate a [crate::layouts::SvpPPol]. -pub trait SvpPPolAllocBytes { - fn svp_ppol_alloc_bytes(&self, cols: usize) -> usize; +pub trait SvpPPolBytesOf { + fn bytes_of_svp_ppol(&self, cols: usize) -> usize; } /// Consume a vector of bytes into a [crate::layouts::MatZnx]. diff --git a/poulpy-hal/src/api/vec_znx_big.rs b/poulpy-hal/src/api/vec_znx_big.rs index 08159bb..8cb5105 100644 --- a/poulpy-hal/src/api/vec_znx_big.rs +++ b/poulpy-hal/src/api/vec_znx_big.rs @@ -16,8 +16,8 @@ pub trait VecZnxBigAlloc { } /// Returns the size in bytes to allocate a [crate::layouts::VecZnxBig]. -pub trait VecZnxBigAllocBytes { - fn vec_znx_big_alloc_bytes(&self, cols: usize, size: usize) -> usize; +pub trait VecZnxBigBytesOf { + fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize; } /// Consume a vector of bytes into a [crate::layouts::VecZnxBig]. diff --git a/poulpy-hal/src/api/vec_znx_dft.rs b/poulpy-hal/src/api/vec_znx_dft.rs index 58589c3..3a003a9 100644 --- a/poulpy-hal/src/api/vec_znx_dft.rs +++ b/poulpy-hal/src/api/vec_znx_dft.rs @@ -10,8 +10,8 @@ pub trait VecZnxDftFromBytes { fn vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxDftOwned; } -pub trait VecZnxDftAllocBytes { - fn vec_znx_dft_alloc_bytes(&self, cols: usize, size: usize) -> usize; +pub trait VecZnxDftBytesOf { + fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize; } pub trait VecZnxDftApply { diff --git a/poulpy-hal/src/api/vmp_pmat.rs b/poulpy-hal/src/api/vmp_pmat.rs index 3d0e248..de3433a 100644 --- a/poulpy-hal/src/api/vmp_pmat.rs +++ b/poulpy-hal/src/api/vmp_pmat.rs @@ -6,8 +6,8 @@ pub trait VmpPMatAlloc { fn vmp_pmat_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned; } -pub trait VmpPMatAllocBytes { - fn vmp_pmat_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; +pub trait VmpPMatBytesOf { + fn bytes_of_vmp_pmat(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; } pub trait VmpPMatFromBytes { diff --git a/poulpy-hal/src/delegates/module.rs b/poulpy-hal/src/delegates/module.rs index 0e3a455..fa01c2f 100644 --- a/poulpy-hal/src/delegates/module.rs +++ b/poulpy-hal/src/delegates/module.rs @@ -1,5 +1,5 @@ use crate::{ - api::ModuleNew, + api::{ModuleN, ModuleNew}, layouts::{Backend, Module}, oep::ModuleNewImpl, }; @@ -12,3 +12,12 @@ where B::new_impl(n) } } + +impl ModuleN for Module +where + B: Backend, +{ + fn n(&self) -> usize { + self.n() + } +} diff --git a/poulpy-hal/src/delegates/scratch.rs b/poulpy-hal/src/delegates/scratch.rs index ac022e3..b91afbb 100644 --- a/poulpy-hal/src/delegates/scratch.rs +++ b/poulpy-hal/src/delegates/scratch.rs @@ -1,14 +1,7 @@ use crate::{ - api::{ - ScratchAvailable, ScratchFromBytes, ScratchOwnedAlloc, ScratchOwnedBorrow, TakeMatZnx, TakeScalarZnx, TakeSlice, - TakeSvpPPol, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, TakeVecZnxSlice, TakeVmpPMat, - }, - layouts::{Backend, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}, - oep::{ - ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeMatZnxImpl, - TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxDftSliceImpl, - TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl, - }, + api::{ScratchAvailable, ScratchFromBytes, ScratchOwnedAlloc, ScratchOwnedBorrow, TakeSlice}, + layouts::{Backend, Scratch, ScratchOwned}, + oep::{ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeSliceImpl}, }; impl ScratchOwnedAlloc for ScratchOwned @@ -55,104 +48,3 @@ where B::take_slice_impl(self, len) } } - -impl TakeScalarZnx for Scratch -where - B: Backend + TakeScalarZnxImpl, -{ - fn take_scalar_znx(&mut self, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) { - B::take_scalar_znx_impl(self, n, cols) - } -} - -impl TakeSvpPPol for Scratch -where - B: Backend + TakeSvpPPolImpl, -{ - fn take_svp_ppol(&mut self, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self) { - B::take_svp_ppol_impl(self, n, cols) - } -} - -impl TakeVecZnx for Scratch -where - B: Backend + TakeVecZnxImpl, -{ - fn take_vec_znx(&mut self, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) { - B::take_vec_znx_impl(self, n, cols, size) - } -} - -impl TakeVecZnxSlice for Scratch -where - B: Backend + TakeVecZnxSliceImpl, -{ - fn take_vec_znx_slice(&mut self, len: usize, n: usize, cols: usize, size: usize) -> (Vec>, &mut Self) { - B::take_vec_znx_slice_impl(self, len, n, cols, size) - } -} - -impl TakeVecZnxBig for Scratch -where - B: Backend + TakeVecZnxBigImpl, -{ - fn take_vec_znx_big(&mut self, n: usize, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self) { - B::take_vec_znx_big_impl(self, n, cols, size) - } -} - -impl TakeVecZnxDft for Scratch -where - B: Backend + TakeVecZnxDftImpl, -{ - fn take_vec_znx_dft(&mut self, n: usize, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self) { - B::take_vec_znx_dft_impl(self, n, cols, size) - } -} - -impl TakeVecZnxDftSlice for Scratch -where - B: Backend + TakeVecZnxDftSliceImpl, -{ - fn take_vec_znx_dft_slice( - &mut self, - len: usize, - n: usize, - cols: usize, - size: usize, - ) -> (Vec>, &mut Self) { - B::take_vec_znx_dft_slice_impl(self, len, n, cols, size) - } -} - -impl TakeVmpPMat for Scratch -where - B: Backend + TakeVmpPMatImpl, -{ - fn take_vmp_pmat( - &mut self, - n: usize, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - ) -> (VmpPMat<&mut [u8], B>, &mut Self) { - B::take_vmp_pmat_impl(self, n, rows, cols_in, cols_out, size) - } -} - -impl TakeMatZnx for Scratch -where - B: Backend + TakeMatZnxImpl, -{ - fn take_mat_znx( - &mut self, - n: usize, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - ) -> (MatZnx<&mut [u8]>, &mut Self) { - B::take_mat_znx_impl(self, n, rows, cols_in, cols_out, size) - } -} diff --git a/poulpy-hal/src/delegates/svp_ppol.rs b/poulpy-hal/src/delegates/svp_ppol.rs index 54a99b2..de36fd0 100644 --- a/poulpy-hal/src/delegates/svp_ppol.rs +++ b/poulpy-hal/src/delegates/svp_ppol.rs @@ -1,6 +1,6 @@ use crate::{ api::{ - SvpApplyDft, SvpApplyDftToDft, SvpApplyDftToDftAdd, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, + SvpApplyDft, SvpApplyDftToDft, SvpApplyDftToDftAdd, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPPolFromBytes, SvpPrepare, }, layouts::{ @@ -30,12 +30,12 @@ where } } -impl SvpPPolAllocBytes for Module +impl SvpPPolBytesOf for Module where B: Backend + SvpPPolAllocBytesImpl, { - fn svp_ppol_alloc_bytes(&self, cols: usize) -> usize { - B::svp_ppol_alloc_bytes_impl(self.n(), cols) + fn bytes_of_svp_ppol(&self, cols: usize) -> usize { + B::svp_ppol_bytes_of_impl(self.n(), cols) } } diff --git a/poulpy-hal/src/delegates/vec_znx_big.rs b/poulpy-hal/src/delegates/vec_znx_big.rs index 1556a87..a1cc307 100644 --- a/poulpy-hal/src/delegates/vec_znx_big.rs +++ b/poulpy-hal/src/delegates/vec_znx_big.rs @@ -1,7 +1,7 @@ use crate::{ api::{ VecZnxBigAdd, VecZnxBigAddInplace, VecZnxBigAddNormal, VecZnxBigAddSmall, VecZnxBigAddSmallInplace, VecZnxBigAlloc, - VecZnxBigAllocBytes, VecZnxBigAutomorphism, VecZnxBigAutomorphismInplace, VecZnxBigAutomorphismInplaceTmpBytes, + VecZnxBigAutomorphism, VecZnxBigAutomorphismInplace, VecZnxBigAutomorphismInplaceTmpBytes, VecZnxBigBytesOf, VecZnxBigFromBytes, VecZnxBigFromSmall, VecZnxBigNegate, VecZnxBigNegateInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSub, VecZnxBigSubInplace, VecZnxBigSubNegateInplace, VecZnxBigSubSmallA, VecZnxBigSubSmallB, VecZnxBigSubSmallInplace, VecZnxBigSubSmallNegateInplace, @@ -49,12 +49,12 @@ where } } -impl VecZnxBigAllocBytes for Module +impl VecZnxBigBytesOf for Module where B: Backend + VecZnxBigAllocBytesImpl, { - fn vec_znx_big_alloc_bytes(&self, cols: usize, size: usize) -> usize { - B::vec_znx_big_alloc_bytes_impl(self.n(), cols, size) + fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize { + B::vec_znx_big_bytes_of_impl(self.n(), cols, size) } } diff --git a/poulpy-hal/src/delegates/vec_znx_dft.rs b/poulpy-hal/src/delegates/vec_znx_dft.rs index 3736e34..16a583f 100644 --- a/poulpy-hal/src/delegates/vec_znx_dft.rs +++ b/poulpy-hal/src/delegates/vec_znx_dft.rs @@ -1,8 +1,8 @@ use crate::{ api::{ - VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, - VecZnxDftFromBytes, VecZnxDftSub, VecZnxDftSubInplace, VecZnxDftSubNegateInplace, VecZnxDftZero, VecZnxIdftApply, - VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxIdftApplyTmpBytes, + VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftCopy, VecZnxDftFromBytes, + VecZnxDftSub, VecZnxDftSubInplace, VecZnxDftSubNegateInplace, VecZnxDftZero, VecZnxIdftApply, VecZnxIdftApplyConsume, + VecZnxIdftApplyTmpA, VecZnxIdftApplyTmpBytes, }, layouts::{ Backend, Data, Module, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef, @@ -24,12 +24,12 @@ where } } -impl VecZnxDftAllocBytes for Module +impl VecZnxDftBytesOf for Module where B: Backend + VecZnxDftAllocBytesImpl, { - fn vec_znx_dft_alloc_bytes(&self, cols: usize, size: usize) -> usize { - B::vec_znx_dft_alloc_bytes_impl(self.n(), cols, size) + fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize { + B::vec_znx_dft_bytes_of_impl(self.n(), cols, size) } } diff --git a/poulpy-hal/src/delegates/vmp_pmat.rs b/poulpy-hal/src/delegates/vmp_pmat.rs index a875a40..2c65508 100644 --- a/poulpy-hal/src/delegates/vmp_pmat.rs +++ b/poulpy-hal/src/delegates/vmp_pmat.rs @@ -1,7 +1,7 @@ use crate::{ api::{ VmpApplyDft, VmpApplyDftTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftAddTmpBytes, - VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPMatAllocBytes, VmpPMatFromBytes, VmpPrepare, VmpPrepareTmpBytes, + VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPMatBytesOf, VmpPMatFromBytes, VmpPrepare, VmpPrepareTmpBytes, }, layouts::{ Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, VmpPMatOwned, VmpPMatToMut, @@ -23,12 +23,12 @@ where } } -impl VmpPMatAllocBytes for Module +impl VmpPMatBytesOf for Module where B: Backend + VmpPMatAllocBytesImpl, { - fn vmp_pmat_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { - B::vmp_pmat_alloc_bytes_impl(self.n(), rows, cols_in, cols_out, size) + fn bytes_of_vmp_pmat(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + B::vmp_pmat_bytes_of_impl(self.n(), rows, cols_in, cols_out, size) } } diff --git a/poulpy-hal/src/layouts/mat_znx.rs b/poulpy-hal/src/layouts/mat_znx.rs index 01be1a1..5ccb860 100644 --- a/poulpy-hal/src/layouts/mat_znx.rs +++ b/poulpy-hal/src/layouts/mat_znx.rs @@ -114,12 +114,12 @@ impl MatZnx { } impl MatZnx> { - pub fn alloc_bytes(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { - rows * cols_in * VecZnx::>::alloc_bytes(n, cols_out, size) + pub fn bytes_of(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + rows * cols_in * VecZnx::>::bytes_of(n, cols_out, size) } pub fn alloc(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { - let data: Vec = alloc_aligned(Self::alloc_bytes(n, rows, cols_in, cols_out, size)); + let data: Vec = alloc_aligned(Self::bytes_of(n, rows, cols_in, cols_out, size)); Self { data, n, @@ -132,7 +132,7 @@ impl MatZnx> { pub fn from_bytes(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); - assert!(data.len() == Self::alloc_bytes(n, rows, cols_in, cols_out, size)); + assert!(data.len() == Self::bytes_of(n, rows, cols_in, cols_out, size)); Self { data, n, @@ -153,7 +153,7 @@ impl MatZnx { } let self_ref: MatZnx<&[u8]> = self.to_ref(); - let nb_bytes: usize = VecZnx::>::alloc_bytes(self.n, self.cols_out, self.size); + let nb_bytes: usize = VecZnx::>::bytes_of(self.n, self.cols_out, self.size); let start: usize = nb_bytes * self.cols() * row + col * nb_bytes; let end: usize = start + nb_bytes; @@ -181,7 +181,7 @@ impl MatZnx { let size: usize = self.size(); let self_ref: MatZnx<&mut [u8]> = self.to_mut(); - let nb_bytes: usize = VecZnx::>::alloc_bytes(n, cols_out, size); + let nb_bytes: usize = VecZnx::>::bytes_of(n, cols_out, size); let start: usize = nb_bytes * cols_in * row + col * nb_bytes; let end: usize = start + nb_bytes; diff --git a/poulpy-hal/src/layouts/module.rs b/poulpy-hal/src/layouts/module.rs index 61e312c..0556a6f 100644 --- a/poulpy-hal/src/layouts/module.rs +++ b/poulpy-hal/src/layouts/module.rs @@ -1,13 +1,20 @@ -use std::{fmt::Display, marker::PhantomData, ptr::NonNull}; +use std::{ + fmt::{Debug, Display}, + marker::PhantomData, + ptr::NonNull, +}; use rand_distr::num_traits::Zero; -use crate::GALOISGENERATOR; +use crate::{ + GALOISGENERATOR, + api::{ModuleLogN, ModuleN}, +}; #[allow(clippy::missing_safety_doc)] pub trait Backend: Sized { - type ScalarBig: Copy + Zero + Display; - type ScalarPrep: Copy + Zero + Display; + type ScalarBig: Copy + Zero + Display + Debug; + type ScalarPrep: Copy + Zero + Display + Debug; type Handle: 'static; fn layout_prep_word_count() -> usize; fn layout_big_word_count() -> usize; @@ -75,36 +82,56 @@ impl Module { pub fn log_n(&self) -> usize { (usize::BITS - (self.n() - 1).leading_zeros()) as _ } +} - #[inline] - pub fn cyclotomic_order(&self) -> u64 { +pub trait CyclotomicOrder +where + Self: ModuleN, +{ + fn cyclotomic_order(&self) -> i64 { (self.n() << 1) as _ } +} +impl ModuleLogN for Module where Self: ModuleN {} + +impl CyclotomicOrder for Module where Self: ModuleN {} + +#[inline(always)] +pub fn galois_element(generator: i64, cyclotomic_order: i64) -> i64 { + if generator == 0 { + return 1; + } + + let g_exp: u64 = mod_exp_u64(GALOISGENERATOR, generator.unsigned_abs() as usize) & (cyclotomic_order - 1) as u64; + g_exp as i64 * generator.signum() +} + +pub trait GaloisElement +where + Self: CyclotomicOrder, +{ // Returns GALOISGENERATOR^|generator| * sign(generator) - #[inline] - pub fn galois_element(&self, generator: i64) -> i64 { - if generator == 0 { - return 1; - } - ((mod_exp_u64(GALOISGENERATOR, generator.unsigned_abs() as usize) & (self.cyclotomic_order() - 1)) as i64) - * generator.signum() + fn galois_element(&self, generator: i64) -> i64 { + galois_element(generator, self.cyclotomic_order()) } // Returns gen^-1 - #[inline] - pub fn galois_element_inv(&self, gal_el: i64) -> i64 { + fn galois_element_inv(&self, gal_el: i64) -> i64 { if gal_el == 0 { panic!("cannot invert 0") } - ((mod_exp_u64( + + let g_exp: u64 = mod_exp_u64( gal_el.unsigned_abs(), (self.cyclotomic_order() - 1) as usize, - ) & (self.cyclotomic_order() - 1)) as i64) - * gal_el.signum() + ) & (self.cyclotomic_order() - 1) as u64; + g_exp as i64 * gal_el.signum() } } +impl GaloisElement for Module where Self: CyclotomicOrder {} + impl Drop for Module { fn drop(&mut self) { unsafe { B::destroy(self.ptr) } diff --git a/poulpy-hal/src/layouts/scalar_znx.rs b/poulpy-hal/src/layouts/scalar_znx.rs index 296071b..bdf2159 100644 --- a/poulpy-hal/src/layouts/scalar_znx.rs +++ b/poulpy-hal/src/layouts/scalar_znx.rs @@ -132,18 +132,18 @@ impl ScalarZnx { } impl ScalarZnx> { - pub fn alloc_bytes(n: usize, cols: usize) -> usize { + pub fn bytes_of(n: usize, cols: usize) -> usize { n * cols * size_of::() } pub fn alloc(n: usize, cols: usize) -> Self { - let data: Vec = alloc_aligned::(Self::alloc_bytes(n, cols)); + let data: Vec = alloc_aligned::(Self::bytes_of(n, cols)); Self { data, n, cols } } pub fn from_bytes(n: usize, cols: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); - assert!(data.len() == Self::alloc_bytes(n, cols)); + assert!(data.len() == Self::bytes_of(n, cols)); Self { data, n, cols } } } diff --git a/poulpy-hal/src/layouts/svp_ppol.rs b/poulpy-hal/src/layouts/svp_ppol.rs index 50523fc..234ebca 100644 --- a/poulpy-hal/src/layouts/svp_ppol.rs +++ b/poulpy-hal/src/layouts/svp_ppol.rs @@ -77,7 +77,7 @@ where B: SvpPPolAllocBytesImpl, { pub fn alloc(n: usize, cols: usize) -> Self { - let data: Vec = alloc_aligned::(B::svp_ppol_alloc_bytes_impl(n, cols)); + let data: Vec = alloc_aligned::(B::svp_ppol_bytes_of_impl(n, cols)); Self { data: data.into(), n, @@ -88,7 +88,7 @@ where pub fn from_bytes(n: usize, cols: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); - assert!(data.len() == B::svp_ppol_alloc_bytes_impl(n, cols)); + assert!(data.len() == B::svp_ppol_bytes_of_impl(n, cols)); Self { data: data.into(), n, diff --git a/poulpy-hal/src/layouts/vec_znx.rs b/poulpy-hal/src/layouts/vec_znx.rs index d40ef4b..c084934 100644 --- a/poulpy-hal/src/layouts/vec_znx.rs +++ b/poulpy-hal/src/layouts/vec_znx.rs @@ -110,7 +110,7 @@ impl ZnxView for VecZnx { } impl VecZnx> { - pub fn rsh_scratch_space(n: usize) -> usize { + pub fn rsh_tmp_bytes(n: usize) -> usize { n * std::mem::size_of::() } } @@ -125,12 +125,12 @@ impl ZnxZero for VecZnx { } impl VecZnx> { - pub fn alloc_bytes(n: usize, cols: usize, size: usize) -> usize { + pub fn bytes_of(n: usize, cols: usize, size: usize) -> usize { n * cols * size * size_of::() } pub fn alloc(n: usize, cols: usize, size: usize) -> Self { - let data: Vec = alloc_aligned::(Self::alloc_bytes(n, cols, size)); + let data: Vec = alloc_aligned::(Self::bytes_of(n, cols, size)); Self { data, n, @@ -142,7 +142,7 @@ impl VecZnx> { pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); - assert!(data.len() == Self::alloc_bytes(n, cols, size)); + assert!(data.len() == Self::bytes_of(n, cols, size)); Self { data, n, diff --git a/poulpy-hal/src/layouts/vec_znx_big.rs b/poulpy-hal/src/layouts/vec_znx_big.rs index c50cf66..73a3e0f 100644 --- a/poulpy-hal/src/layouts/vec_znx_big.rs +++ b/poulpy-hal/src/layouts/vec_znx_big.rs @@ -96,7 +96,7 @@ where B: VecZnxBigAllocBytesImpl, { pub fn alloc(n: usize, cols: usize, size: usize) -> Self { - let data = alloc_aligned::(B::vec_znx_big_alloc_bytes_impl(n, cols, size)); + let data = alloc_aligned::(B::vec_znx_big_bytes_of_impl(n, cols, size)); Self { data: data.into(), n, @@ -109,7 +109,7 @@ where pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); - assert!(data.len() == B::vec_znx_big_alloc_bytes_impl(n, cols, size)); + assert!(data.len() == B::vec_znx_big_bytes_of_impl(n, cols, size)); Self { data: data.into(), n, diff --git a/poulpy-hal/src/layouts/vec_znx_dft.rs b/poulpy-hal/src/layouts/vec_znx_dft.rs index 3dc92d5..19d28e1 100644 --- a/poulpy-hal/src/layouts/vec_znx_dft.rs +++ b/poulpy-hal/src/layouts/vec_znx_dft.rs @@ -116,7 +116,7 @@ where B: VecZnxDftAllocBytesImpl, { pub fn alloc(n: usize, cols: usize, size: usize) -> Self { - let data: Vec = alloc_aligned::(B::vec_znx_dft_alloc_bytes_impl(n, cols, size)); + let data: Vec = alloc_aligned::(B::vec_znx_dft_bytes_of_impl(n, cols, size)); Self { data: data.into(), n, @@ -129,7 +129,7 @@ where pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); - assert!(data.len() == B::vec_znx_dft_alloc_bytes_impl(n, cols, size)); + assert!(data.len() == B::vec_znx_dft_bytes_of_impl(n, cols, size)); Self { data: data.into(), n, diff --git a/poulpy-hal/src/layouts/vmp_pmat.rs b/poulpy-hal/src/layouts/vmp_pmat.rs index ce83458..bd469ec 100644 --- a/poulpy-hal/src/layouts/vmp_pmat.rs +++ b/poulpy-hal/src/layouts/vmp_pmat.rs @@ -88,9 +88,7 @@ where B: VmpPMatAllocBytesImpl, { pub fn alloc(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { - let data: Vec = alloc_aligned(B::vmp_pmat_alloc_bytes_impl( - n, rows, cols_in, cols_out, size, - )); + let data: Vec = alloc_aligned(B::vmp_pmat_bytes_of_impl(n, rows, cols_in, cols_out, size)); Self { data: data.into(), n, @@ -104,7 +102,7 @@ where pub fn from_bytes(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); - assert!(data.len() == B::vmp_pmat_alloc_bytes_impl(n, rows, cols_in, cols_out, size)); + assert!(data.len() == B::vmp_pmat_bytes_of_impl(n, rows, cols_in, cols_out, size)); Self { data: data.into(), n, diff --git a/poulpy-hal/src/layouts/zn.rs b/poulpy-hal/src/layouts/zn.rs index 40f5622..00f8067 100644 --- a/poulpy-hal/src/layouts/zn.rs +++ b/poulpy-hal/src/layouts/zn.rs @@ -98,7 +98,7 @@ impl ZnxView for Zn { } impl Zn> { - pub fn rsh_scratch_space(n: usize) -> usize { + pub fn rsh_tmp_bytes(n: usize) -> usize { n * std::mem::size_of::() } } @@ -113,12 +113,12 @@ impl ZnxZero for Zn { } impl Zn> { - pub fn alloc_bytes(n: usize, cols: usize, size: usize) -> usize { + pub fn bytes_of(n: usize, cols: usize, size: usize) -> usize { n * cols * size * size_of::() } pub fn alloc(n: usize, cols: usize, size: usize) -> Self { - let data: Vec = alloc_aligned::(Self::alloc_bytes(n, cols, size)); + let data: Vec = alloc_aligned::(Self::bytes_of(n, cols, size)); Self { data, n, @@ -130,7 +130,7 @@ impl Zn> { pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); - assert!(data.len() == Self::alloc_bytes(n, cols, size)); + assert!(data.len() == Self::bytes_of(n, cols, size)); Self { data, n, diff --git a/poulpy-hal/src/layouts/znx_base.rs b/poulpy-hal/src/layouts/znx_base.rs index 7ca75b2..45e30d4 100644 --- a/poulpy-hal/src/layouts/znx_base.rs +++ b/poulpy-hal/src/layouts/znx_base.rs @@ -1,3 +1,5 @@ +use std::fmt::{Debug, Display}; + use crate::{ layouts::{Backend, Data, DataMut, DataRef}, source::Source, @@ -48,7 +50,7 @@ pub trait DataViewMut: DataView { } pub trait ZnxView: ZnxInfos + DataView { - type Scalar: Copy + Zero; + type Scalar: Copy + Zero + Display + Debug; /// Returns a non-mutable pointer to the underlying coefficients array. fn as_ptr(&self) -> *const Self::Scalar { diff --git a/poulpy-hal/src/oep/scratch.rs b/poulpy-hal/src/oep/scratch.rs index 51a9c56..c973b04 100644 --- a/poulpy-hal/src/oep/scratch.rs +++ b/poulpy-hal/src/oep/scratch.rs @@ -1,4 +1,4 @@ -use crate::layouts::{Backend, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}; +use crate::layouts::{Backend, Scratch, ScratchOwned}; /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. @@ -39,111 +39,3 @@ pub unsafe trait ScratchAvailableImpl { pub unsafe trait TakeSliceImpl { fn take_slice_impl(scratch: &mut Scratch, len: usize) -> (&mut [T], &mut Scratch); } - -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. -/// * See [crate::api::TakeScalarZnx] for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait TakeScalarZnxImpl { - fn take_scalar_znx_impl(scratch: &mut Scratch, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Scratch); -} - -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. -/// * See [crate::api::TakeSvpPPol] for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait TakeSvpPPolImpl { - fn take_svp_ppol_impl(scratch: &mut Scratch, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Scratch); -} - -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. -/// * See [crate::api::TakeVecZnx] for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait TakeVecZnxImpl { - fn take_vec_znx_impl(scratch: &mut Scratch, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Scratch); -} - -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. -/// * See [crate::api::TakeVecZnxSlice] for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait TakeVecZnxSliceImpl { - fn take_vec_znx_slice_impl( - scratch: &mut Scratch, - len: usize, - n: usize, - cols: usize, - size: usize, - ) -> (Vec>, &mut Scratch); -} - -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. -/// * See [crate::api::TakeVecZnxBig] for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait TakeVecZnxBigImpl { - fn take_vec_znx_big_impl( - scratch: &mut Scratch, - n: usize, - cols: usize, - size: usize, - ) -> (VecZnxBig<&mut [u8], B>, &mut Scratch); -} - -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. -/// * See [crate::api::TakeVecZnxDft] for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait TakeVecZnxDftImpl { - fn take_vec_znx_dft_impl( - scratch: &mut Scratch, - n: usize, - cols: usize, - size: usize, - ) -> (VecZnxDft<&mut [u8], B>, &mut Scratch); -} - -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. -/// * See [crate::api::TakeVecZnxDftSlice] for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait TakeVecZnxDftSliceImpl { - fn take_vec_znx_dft_slice_impl( - scratch: &mut Scratch, - len: usize, - n: usize, - cols: usize, - size: usize, - ) -> (Vec>, &mut Scratch); -} - -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. -/// * See [crate::api::TakeVmpPMat] for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait TakeVmpPMatImpl { - fn take_vmp_pmat_impl( - scratch: &mut Scratch, - n: usize, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - ) -> (VmpPMat<&mut [u8], B>, &mut Scratch); -} - -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. -/// * See [crate::api::TakeMatZnx] for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait TakeMatZnxImpl { - fn take_mat_znx_impl( - scratch: &mut Scratch, - n: usize, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - ) -> (MatZnx<&mut [u8]>, &mut Scratch); -} diff --git a/poulpy-hal/src/oep/svp_ppol.rs b/poulpy-hal/src/oep/svp_ppol.rs index 6550b6f..42c50ea 100644 --- a/poulpy-hal/src/oep/svp_ppol.rs +++ b/poulpy-hal/src/oep/svp_ppol.rs @@ -23,7 +23,7 @@ pub unsafe trait SvpPPolAllocImpl { /// * See [crate::api::SvpPPolAllocBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait SvpPPolAllocBytesImpl { - fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize; + fn svp_ppol_bytes_of_impl(n: usize, cols: usize) -> usize; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) diff --git a/poulpy-hal/src/oep/vec_znx_big.rs b/poulpy-hal/src/oep/vec_znx_big.rs index b2bd4c8..4c12e6a 100644 --- a/poulpy-hal/src/oep/vec_znx_big.rs +++ b/poulpy-hal/src/oep/vec_znx_big.rs @@ -35,7 +35,7 @@ pub unsafe trait VecZnxBigFromBytesImpl { /// * See [crate::api::VecZnxBigAllocBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxBigAllocBytesImpl { - fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize; + fn vec_znx_big_bytes_of_impl(n: usize, cols: usize, size: usize) -> usize; } #[allow(clippy::too_many_arguments)] diff --git a/poulpy-hal/src/oep/vec_znx_dft.rs b/poulpy-hal/src/oep/vec_znx_dft.rs index e5a2bcb..0f9288b 100644 --- a/poulpy-hal/src/oep/vec_znx_dft.rs +++ b/poulpy-hal/src/oep/vec_znx_dft.rs @@ -42,7 +42,7 @@ pub unsafe trait VecZnxDftApplyImpl { /// * See [crate::api::VecZnxDftAllocBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxDftAllocBytesImpl { - fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize; + fn vec_znx_dft_bytes_of_impl(n: usize, cols: usize, size: usize) -> usize; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) diff --git a/poulpy-hal/src/oep/vmp_pmat.rs b/poulpy-hal/src/oep/vmp_pmat.rs index e399f00..bdca416 100644 --- a/poulpy-hal/src/oep/vmp_pmat.rs +++ b/poulpy-hal/src/oep/vmp_pmat.rs @@ -15,7 +15,7 @@ pub unsafe trait VmpPMatAllocImpl { /// * See [crate::api::VmpPMatAllocBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VmpPMatAllocBytesImpl { - fn vmp_pmat_alloc_bytes_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; + fn vmp_pmat_bytes_of_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) diff --git a/poulpy-hal/src/reference/fft64/vmp.rs b/poulpy-hal/src/reference/fft64/vmp.rs index f6fb73c..07e1a8d 100644 --- a/poulpy-hal/src/reference/fft64/vmp.rs +++ b/poulpy-hal/src/reference/fft64/vmp.rs @@ -140,7 +140,7 @@ where assert!(a.cols() <= cols); } - let (data, tmp_bytes) = tmp_bytes.split_at_mut(BE::vec_znx_dft_alloc_bytes_impl(n, cols, size)); + let (data, tmp_bytes) = tmp_bytes.split_at_mut(BE::vec_znx_dft_bytes_of_impl(n, cols, size)); let mut a_dft: VecZnxDft<&mut [u8], BE> = VecZnxDft::from_data(cast_mut(data), n, cols, size); diff --git a/poulpy-schemes/benches/circuit_bootstrapping.rs b/poulpy-schemes/benches/circuit_bootstrapping.rs index 47c848e..0d90062 100644 --- a/poulpy-schemes/benches/circuit_bootstrapping.rs +++ b/poulpy-schemes/benches/circuit_bootstrapping.rs @@ -2,107 +2,46 @@ use std::hint::black_box; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; use poulpy_backend::{FFT64Avx, FFT64Ref, FFT64Spqlios}; -use poulpy_core::layouts::{ - Dsize, GGLWEAutomorphismKeyLayout, GGLWETensorKeyLayout, GGSWCiphertext, GGSWCiphertextLayout, GLWESecret, LWECiphertext, - LWECiphertextLayout, LWESecret, prepared::PrepareAlloc, +use poulpy_core::{ + GGSWNoise, GLWEDecrypt, GLWEEncryptSk, GLWEExternalProduct, LWEEncryptSk, ScratchTakeCore, + layouts::{ + Dsize, GGSW, GGSWLayout, GGSWPreparedFactory, GLWEAutomorphismKeyLayout, GLWESecret, GLWESecretPreparedFactory, + GLWETensorKeyLayout, LWE, LWELayout, LWESecret, + }, }; use poulpy_hal::{ - api::{ - ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, - SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, - VecZnxAutomorphismInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, - VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy, - VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, VecZnxFillUniform, - VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxRshInplace, VecZnxSub, - VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, - VmpPrepare, ZnAddNormal, ZnFillUniform, ZnNormalizeInplace, - }, - layouts::{Backend, Module, ScratchOwned}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeMatZnxImpl, TakeScalarZnxImpl, TakeSliceImpl, - TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, - }, + api::{ModuleN, ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxRotateInplace}, + layouts::{Backend, Module, Scratch, ScratchOwned}, source::Source, }; use poulpy_schemes::tfhe::{ blind_rotation::{ - BlincRotationExecute, BlindRotationAlgo, BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyEncryptSk, - BlindRotationKeyInfos, BlindRotationKeyLayout, BlindRotationKeyPrepared, CGGI, + BlindRotationAlgo, BlindRotationKey, BlindRotationKeyFactory, BlindRotationKeyInfos, BlindRotationKeyLayout, CGGI, }, circuit_bootstrapping::{ CircuitBootstrappingKey, CircuitBootstrappingKeyEncryptSk, CircuitBootstrappingKeyLayout, - CircuitBootstrappingKeyPrepared, CirtuitBootstrappingExecute, + CircuitBootstrappingKeyPrepared, CircuitBootstrappingKeyPreparedFactory, CirtuitBootstrappingExecute, }, }; -pub fn benc_circuit_bootstrapping(c: &mut Criterion, label: &str) +pub fn benc_circuit_bootstrapping(c: &mut Criterion, label: &str) where - Module: ModuleNew - + VecZnxFillUniform - + VecZnxAddNormal - + VecZnxNormalizeInplace - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalize - + VecZnxSub - + VecZnxAddScalarInplace - + VecZnxAutomorphism - + VecZnxSwitchRing - + VecZnxBigAllocBytes - + VecZnxIdftApplyTmpA - + SvpApplyDftToDft - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxBigAlloc - + VecZnxDftAlloc - + VecZnxBigNormalizeTmpBytes - + VmpPMatAlloc - + VmpPrepare - + SvpPrepare - + SvpPPolAlloc - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + SvpPPolAllocBytes - + VecZnxRotateInplace - + VecZnxBigAutomorphismInplace - + VecZnxRshInplace - + VecZnxDftCopy - + VecZnxNegateInplace - + VecZnxCopy - + VecZnxAutomorphismInplace - + VecZnxBigSubSmallNegateInplace - + VecZnxRotateInplaceTmpBytes - + VecZnxBigAllocBytes - + VecZnxDftAddInplace - + VecZnxRotate - + ZnFillUniform - + ZnAddNormal - + ZnNormalizeInplace, - ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, - B: Backend - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + TakeVecZnxDftImpl - + ScratchAvailableImpl - + TakeVecZnxImpl - + TakeScalarZnxImpl - + TakeSvpPPolImpl - + TakeVecZnxBigImpl - + TakeVecZnxDftSliceImpl - + TakeMatZnxImpl - + TakeVecZnxSliceImpl - + TakeSliceImpl, - BlindRotationKey, BRA>: PrepareAlloc, BRA, B>>, - BlindRotationKeyPrepared, BRA, B>: BlincRotationExecute, - BlindRotationKey, BRA>: BlindRotationKeyAlloc + BlindRotationKeyEncryptSk, + Module: ModuleNew + + ModuleN + + GLWESecretPreparedFactory + + GLWEExternalProduct + + GLWEDecrypt + + LWEEncryptSk + + CircuitBootstrappingKeyEncryptSk + + CircuitBootstrappingKeyPreparedFactory + + CirtuitBootstrappingExecute + + GGSWPreparedFactory + + GGSWNoise + + GLWEEncryptSk + + VecZnxRotateInplace, + BlindRotationKey, BRA>: BlindRotationKeyFactory, // TODO find a way to remove this bound or move it to CBT KEY + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchTakeCore, { let group_name: String = format!("circuit_bootstrapping::{label}"); @@ -113,86 +52,38 @@ where extension_factor: usize, k_pt: usize, block_size: usize, - lwe_infos: LWECiphertextLayout, - ggsw_infos: GGSWCiphertextLayout, + lwe_infos: LWELayout, + ggsw_infos: GGSWLayout, cbt_infos: CircuitBootstrappingKeyLayout, } - fn runner(params: &Params) -> impl FnMut() + fn runner(params: &Params) -> impl FnMut() where - Module: ModuleNew - + VecZnxFillUniform - + VecZnxAddNormal - + VecZnxNormalizeInplace - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalize - + VecZnxSub - + VecZnxAddScalarInplace - + VecZnxAutomorphism - + VecZnxSwitchRing - + VecZnxBigAllocBytes - + VecZnxIdftApplyTmpA - + SvpApplyDftToDft - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxBigAlloc - + VecZnxDftAlloc - + VecZnxBigNormalizeTmpBytes - + VmpPMatAlloc - + VmpPrepare - + SvpPrepare - + SvpPPolAlloc - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + SvpPPolAllocBytes - + VecZnxRotateInplace - + VecZnxBigAutomorphismInplace - + VecZnxRshInplace - + VecZnxDftCopy - + VecZnxNegateInplace - + VecZnxCopy - + VecZnxAutomorphismInplace - + VecZnxBigSubSmallNegateInplace - + VecZnxRotateInplaceTmpBytes - + VecZnxBigAllocBytes - + VecZnxDftAddInplace - + VecZnxRotate - + ZnFillUniform - + ZnAddNormal - + ZnNormalizeInplace, - B: Backend - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + TakeVecZnxDftImpl - + ScratchAvailableImpl - + TakeVecZnxImpl - + TakeScalarZnxImpl - + TakeSvpPPolImpl - + TakeVecZnxBigImpl - + TakeVecZnxDftSliceImpl - + TakeMatZnxImpl - + TakeVecZnxSliceImpl - + TakeSliceImpl, - BlindRotationKey, BRA>: PrepareAlloc, BRA, B>>, - BlindRotationKeyPrepared, BRA, B>: BlincRotationExecute, - BlindRotationKey, BRA>: BlindRotationKeyAlloc + BlindRotationKeyEncryptSk, + Module: ModuleNew + + ModuleN + + GLWESecretPreparedFactory + + GLWEExternalProduct + + GLWEDecrypt + + LWEEncryptSk + + CircuitBootstrappingKeyEncryptSk + + CircuitBootstrappingKeyPreparedFactory + + CirtuitBootstrappingExecute + + GGSWPreparedFactory + + GGSWNoise + + GLWEEncryptSk + + VecZnxRotateInplace, + BlindRotationKey, BRA>: BlindRotationKeyFactory, /* TODO find a way to remove this bound or move it to CBT KEY */ + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchTakeCore, { // Scratch space (4MB) - let mut scratch: ScratchOwned = ScratchOwned::alloc(1 << 22); + let mut scratch: ScratchOwned = ScratchOwned::alloc(1 << 22); let n_glwe: poulpy_core::layouts::Degree = params.cbt_infos.layout_brk.n_glwe(); let n_lwe: poulpy_core::layouts::Degree = params.cbt_infos.layout_brk.n_lwe(); let rank: poulpy_core::layouts::Rank = params.cbt_infos.layout_brk.rank; - let module: Module = Module::::new(n_glwe.as_u32() as u64); + let module: Module = Module::::new(n_glwe.as_u32() as u64); let mut source_xs: Source = Source::new([1u8; 32]); let mut source_xa: Source = Source::new([1u8; 32]); @@ -202,25 +93,26 @@ where sk_lwe.fill_binary_block(params.block_size, &mut source_xs); sk_lwe.fill_zero(); - let mut sk_glwe: GLWESecret> = GLWESecret::alloc_with(n_glwe, rank); + let mut sk_glwe: GLWESecret> = GLWESecret::alloc(n_glwe, rank); sk_glwe.fill_ternary_prob(0.5, &mut source_xs); - let ct_lwe: LWECiphertext> = LWECiphertext::alloc(¶ms.lwe_infos); + let ct_lwe: LWE> = LWE::alloc_from_infos(¶ms.lwe_infos); // Circuit bootstrapping evaluation key - let cbt_key: CircuitBootstrappingKey, BRA> = CircuitBootstrappingKey::encrypt_sk( + let mut cbt_key: CircuitBootstrappingKey, BRA> = CircuitBootstrappingKey::alloc_from_infos(¶ms.cbt_infos); + cbt_key.encrypt_sk( &module, &sk_lwe, &sk_glwe, - ¶ms.cbt_infos, &mut source_xa, &mut source_xe, scratch.borrow(), ); - let mut res: GGSWCiphertext> = GGSWCiphertext::alloc(¶ms.ggsw_infos); - let cbt_prepared: CircuitBootstrappingKeyPrepared, BRA, B> = cbt_key.prepare_alloc(&module, scratch.borrow()); - + let mut res: GGSW> = GGSW::alloc_from_infos(¶ms.ggsw_infos); + let mut cbt_prepared: CircuitBootstrappingKeyPrepared, BRA, BE> = + CircuitBootstrappingKeyPrepared::alloc_from_infos(&module, ¶ms.cbt_infos); + cbt_prepared.prepare(&module, &cbt_key, scratch.borrow()); move || { cbt_prepared.execute_to_constant( &module, @@ -238,13 +130,13 @@ where name: String::from("1-bit"), extension_factor: 1, k_pt: 1, - lwe_infos: LWECiphertextLayout { + lwe_infos: LWELayout { n: 574_u32.into(), k: 13_u32.into(), base2k: 13_u32.into(), }, block_size: 7, - ggsw_infos: GGSWCiphertextLayout { + ggsw_infos: GGSWLayout { n: 1024_u32.into(), base2k: 13_u32.into(), k: 26_u32.into(), @@ -261,7 +153,7 @@ where dnum: 3_u32.into(), rank: 2_u32.into(), }, - layout_atk: GGLWEAutomorphismKeyLayout { + layout_atk: GLWEAutomorphismKeyLayout { n: 1024_u32.into(), base2k: 13_u32.into(), k: 52_u32.into(), @@ -269,7 +161,7 @@ where dsize: Dsize(1), rank: 2_u32.into(), }, - layout_tsk: GGLWETensorKeyLayout { + layout_tsk: GLWETensorKeyLayout { n: 1024_u32.into(), base2k: 13_u32.into(), k: 52_u32.into(), @@ -280,7 +172,7 @@ where }, }] { let id: BenchmarkId = BenchmarkId::from_parameter(params.name.clone()); - let mut runner = runner::(¶ms); + let mut runner = runner::(¶ms); group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); } diff --git a/poulpy-schemes/examples/circuit_bootstrapping.rs b/poulpy-schemes/examples/circuit_bootstrapping.rs index 4fec699..a7c56bd 100644 --- a/poulpy-schemes/examples/circuit_bootstrapping.rs +++ b/poulpy-schemes/examples/circuit_bootstrapping.rs @@ -1,9 +1,9 @@ use poulpy_core::{ - GLWEOperations, + GLWENormalize, layouts::{ - GGLWEAutomorphismKeyLayout, GGLWETensorKeyLayout, GGSWCiphertext, GGSWCiphertextLayout, GLWECiphertext, - GLWECiphertextLayout, GLWEPlaintext, GLWESecret, LWECiphertext, LWECiphertextLayout, LWEInfos, LWEPlaintext, LWESecret, - prepared::{GGSWCiphertextPrepared, GLWESecretPrepared, PrepareAlloc}, + GGSW, GGSWLayout, GLWE, GLWEAutomorphismKeyLayout, GLWELayout, GLWEPlaintext, GLWESecret, GLWETensorKeyLayout, LWE, + LWEInfos, LWELayout, LWEPlaintext, LWESecret, + prepared::{GGSWPrepared, GLWESecretPrepared}, }, }; use std::time::Instant; @@ -22,10 +22,7 @@ use poulpy_hal::{ use poulpy_schemes::tfhe::{ blind_rotation::{BlindRotationKeyLayout, CGGI}, - circuit_bootstrapping::{ - CircuitBootstrappingKey, CircuitBootstrappingKeyEncryptSk, CircuitBootstrappingKeyLayout, - CircuitBootstrappingKeyPrepared, CirtuitBootstrappingExecute, - }, + circuit_bootstrapping::{CircuitBootstrappingKey, CircuitBootstrappingKeyLayout, CircuitBootstrappingKeyPrepared}, }; fn main() { @@ -89,7 +86,7 @@ fn main() { dnum: rows_brk.into(), rank: rank.into(), }, - layout_atk: GGLWEAutomorphismKeyLayout { + layout_atk: GLWEAutomorphismKeyLayout { n: n_glwe.into(), base2k: base2k.into(), k: k_trace.into(), @@ -97,7 +94,7 @@ fn main() { dsize: 1_u32.into(), rank: rank.into(), }, - layout_tsk: GGLWETensorKeyLayout { + layout_tsk: GLWETensorKeyLayout { n: n_glwe.into(), base2k: base2k.into(), k: k_tsk.into(), @@ -107,7 +104,7 @@ fn main() { }, }; - let ggsw_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + let ggsw_infos: GGSWLayout = GGSWLayout { n: n_glwe.into(), base2k: base2k.into(), k: k_ggsw_res.into(), @@ -116,7 +113,7 @@ fn main() { rank: rank.into(), }; - let lwe_infos = LWECiphertextLayout { + let lwe_infos = LWELayout { n: n_lwe.into(), k: k_lwe_ct.into(), base2k: base2k.into(), @@ -140,18 +137,19 @@ fn main() { sk_lwe.fill_zero(); // GLWE secret - let mut sk_glwe: GLWESecret> = GLWESecret::alloc_with(n_glwe.into(), rank.into()); + let mut sk_glwe: GLWESecret> = GLWESecret::alloc(n_glwe.into(), rank.into()); sk_glwe.fill_ternary_prob(0.5, &mut source_xs); // sk_glwe.fill_zero(); // GLWE secret prepared (opaque backend dependant write only struct) - let sk_glwe_prepared: GLWESecretPrepared, BackendImpl> = sk_glwe.prepare_alloc(&module, scratch.borrow()); + let mut sk_glwe_prepared: GLWESecretPrepared, BackendImpl> = GLWESecretPrepared::alloc(&module, rank.into()); + sk_glwe_prepared.prepare(&module, &sk_glwe); // Plaintext value to circuit bootstrap let data: i64 = 1 % (1 << k_lwe_pt); // LWE plaintext - let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc_with(base2k.into(), k_lwe_pt.into()); + let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc(base2k.into(), k_lwe_pt.into()); // LWE plaintext(data * 2^{- (k_lwe_pt - 1)}) pt_lwe.encode_i64(data, (k_lwe_pt + 1).into()); // +1 for padding bit @@ -167,7 +165,7 @@ fn main() { println!("pt_lwe: {pt_lwe}"); // LWE ciphertext - let mut ct_lwe: LWECiphertext> = LWECiphertext::alloc(&lwe_infos); + let mut ct_lwe: LWE> = LWE::alloc_from_infos(&lwe_infos); // Encrypt LWE Plaintext ct_lwe.encrypt_sk(&module, &pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe); @@ -175,23 +173,26 @@ fn main() { let now: Instant = Instant::now(); // Circuit bootstrapping evaluation key - let cbt_key: CircuitBootstrappingKey, CGGI> = CircuitBootstrappingKey::encrypt_sk( + let mut cbt_key: CircuitBootstrappingKey, CGGI> = CircuitBootstrappingKey::alloc_from_infos(&cbt_infos); + + cbt_key.encrypt_sk( &module, &sk_lwe, &sk_glwe, - &cbt_infos, &mut source_xa, &mut source_xe, scratch.borrow(), ); + println!("CBT-KGEN: {} ms", now.elapsed().as_millis()); // Output GGSW - let mut res: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_infos); + let mut res: GGSW> = GGSW::alloc_from_infos(&ggsw_infos); // Circuit bootstrapping key prepared (opaque backend dependant write only struct) - let cbt_prepared: CircuitBootstrappingKeyPrepared, CGGI, BackendImpl> = - cbt_key.prepare_alloc(&module, scratch.borrow()); + let mut cbt_prepared: CircuitBootstrappingKeyPrepared, CGGI, BackendImpl> = + CircuitBootstrappingKeyPrepared::alloc_from_infos(&module, &cbt_infos); + cbt_prepared.prepare(&module, &cbt_key, scratch.borrow()); // Apply circuit bootstrapping: LWE(data * 2^{- (k_lwe_pt + 2)}) -> GGSW(data) let now: Instant = Instant::now(); @@ -214,7 +215,7 @@ fn main() { // Tests RLWE(1) * GGSW(data) - let glwe_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_infos: GLWELayout = GLWELayout { n: n_glwe.into(), base2k: base2k.into(), k: (k_ggsw_res - base2k).into(), @@ -222,11 +223,11 @@ fn main() { }; // GLWE ciphertext modulus - let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&glwe_infos); + let mut ct_glwe: GLWE> = GLWE::alloc_from_infos(&glwe_infos); // Some GLWE plaintext with signed data let k_glwe_pt: usize = 3; - let mut pt_glwe: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_infos); + let mut pt_glwe: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_infos); let mut data_vec: Vec = vec![0i64; n_glwe]; data_vec .iter_mut() @@ -234,7 +235,7 @@ fn main() { .for_each(|(x, y)| *y = (x % (1 << (k_glwe_pt - 1))) as i64 - (1 << (k_glwe_pt - 2))); pt_glwe.encode_vec_i64(&data_vec, (k_lwe_pt + 2).into()); - pt_glwe.normalize_inplace(&module, scratch.borrow()); + module.glwe_normalize_inplace(&mut pt_glwe, scratch.borrow()); println!("{}", pt_glwe); @@ -249,13 +250,14 @@ fn main() { ); // Prepare GGSW output of circuit bootstrapping (opaque backend dependant write only struct) - let res_prepared: GGSWCiphertextPrepared, BackendImpl> = res.prepare_alloc(&module, scratch.borrow()); + let mut res_prepared: GGSWPrepared, BackendImpl> = GGSWPrepared::alloc_from_infos(&module, &res); + res_prepared.prepare(&module, &res, scratch.borrow()); // Apply GLWE x GGSW ct_glwe.external_product_inplace(&module, &res_prepared, scratch.borrow()); // Decrypt - let mut pt_res: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_infos); + let mut pt_res: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_infos); ct_glwe.decrypt(&module, &mut pt_res, &sk_glwe_prepared, scratch.borrow()); println!("pt_res: {:?}", &pt_res.data.at(0, 0)[..64]); diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs index 938c45e..db4a1de 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs @@ -1,78 +1,46 @@ use itertools::Itertools; -use poulpy_core::layouts::prepared::GGSWCiphertextPreparedToRef; +use poulpy_core::layouts::prepared::GGSWPreparedToRef; use poulpy_hal::layouts::{Backend, DataMut, DataRef, Module, Scratch}; use crate::tfhe::bdd_arithmetic::{ - BitCircuitInfo, Circuit, CircuitExecute, FheUintBlocks, FheUintBlocksPrep, UnsignedInteger, circuits, + ExecuteBDDCircuit, FheUintBlocks, FheUintBlocksPrepared, GetBitCircuitInfo, UnsignedInteger, circuits, }; -/// Operations Z x Z -> Z -pub(crate) struct Circuits2WTo1W(pub &'static Circuit); +impl ExecuteBDDCircuit2WTo1W for Module where Self: Sized + ExecuteBDDCircuit {} -pub trait EvalBDD2WTo1W { - fn eval_bdd_2w_to_1w( - &self, - module: &Module, - out: &mut FheUintBlocks, - a: &FheUintBlocksPrep, - b: &FheUintBlocksPrep, - scratch: &mut Scratch, - ) where - R: DataMut, - A: DataRef, - B: DataRef; -} - -impl EvalBDD2WTo1W - for Circuits2WTo1W +pub trait ExecuteBDDCircuit2WTo1W where - Circuit: CircuitExecute, + Self: Sized + ExecuteBDDCircuit, { - fn eval_bdd_2w_to_1w( + /// Operations Z x Z -> Z + fn execute_bdd_circuit_2w_to_1w( &self, - module: &Module, out: &mut FheUintBlocks, - a: &FheUintBlocksPrep, - b: &FheUintBlocksPrep, + circuit: &C, + a: &FheUintBlocksPrepared, + b: &FheUintBlocksPrepared, scratch: &mut Scratch, ) where + C: GetBitCircuitInfo, R: DataMut, A: DataRef, B: DataRef, - { - eval_bdd_2w_to_1w(module, self.0, out, a, b, scratch); - } -} - -pub fn eval_bdd_2w_to_1w, BE: Backend>( - module: &Module, - circuit: &C, - out: &mut FheUintBlocks, - a: &FheUintBlocksPrep, - b: &FheUintBlocksPrep, - scratch: &mut Scratch, -) { - #[cfg(debug_assertions)] { assert_eq!(out.blocks.len(), T::WORD_SIZE); assert_eq!(b.blocks.len(), T::WORD_SIZE); assert_eq!(b.blocks.len(), T::WORD_SIZE); + + // Collects inputs into a single array + let inputs: Vec<&dyn GGSWPreparedToRef> = a + .blocks + .iter() + .map(|x| x as &dyn GGSWPreparedToRef) + .chain(b.blocks.iter().map(|x| x as &dyn GGSWPreparedToRef)) + .collect_vec(); + + // Evaluates out[i] = circuit[i](a, b) + self.execute_bdd_circuit(&mut out.blocks, &inputs, circuit, scratch); } - - // Collects inputs into a single array - let inputs: Vec<&dyn GGSWCiphertextPreparedToRef> = a - .blocks - .iter() - .map(|x| x as &dyn GGSWCiphertextPreparedToRef) - .chain( - b.blocks - .iter() - .map(|x| x as &dyn GGSWCiphertextPreparedToRef), - ) - .collect_vec(); - - // Evaluates out[i] = circuit[i](a, b) - circuit.execute(module, &mut out.blocks, &inputs, scratch); } #[macro_export] @@ -80,13 +48,14 @@ macro_rules! define_bdd_2w_to_1w_trait { ($(#[$meta:meta])* $vis:vis $trait_name:ident, $method_name:ident) => { $(#[$meta])* $vis trait $trait_name { - fn $method_name( + fn $method_name( &mut self, - module: &Module, - a: &FheUintBlocksPrep, - b: &FheUintBlocksPrep, + module: &M, + a: &FheUintBlocksPrepared, + b: &FheUintBlocksPrepared, scratch: &mut Scratch, ) where + M: ExecuteBDDCircuit2WTo1W, A: DataRef, B: DataRef; } @@ -96,23 +65,19 @@ macro_rules! define_bdd_2w_to_1w_trait { #[macro_export] macro_rules! impl_bdd_2w_to_1w_trait { ($trait_name:ident, $method_name:ident, $ty:ty, $n:literal, $circuit_ty:ty, $output_circuits:path) => { - impl $trait_name<$ty, BE> for FheUintBlocks - where - Circuits2WTo1W<$circuit_ty, $n>: EvalBDD2WTo1W, - { - fn $method_name( + impl $trait_name<$ty, BE> for FheUintBlocks { + fn $method_name( &mut self, - module: &Module, - a: &FheUintBlocksPrep, - b: &FheUintBlocksPrep, + module: &M, + a: &FheUintBlocksPrepared, + b: &FheUintBlocksPrepared, scratch: &mut Scratch, ) where + M: ExecuteBDDCircuit2WTo1W<$ty, BE>, A: DataRef, B: DataRef, { - const OP: Circuits2WTo1W<$circuit_ty, $n> = Circuits2WTo1W::<$circuit_ty, $n>(&$output_circuits); - - OP.eval_bdd_2w_to_1w(module, self, a, b, scratch); + module.execute_bdd_circuit_2w_to_1w(self, &$output_circuits, a, b, scratch) } } }; diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block.rs index 145ce6b..82d985e 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block.rs @@ -1,30 +1,23 @@ use std::marker::PhantomData; -use poulpy_core::layouts::{Base2K, GLWECiphertext, GLWEInfos, GLWEPlaintextLayout, LWEInfos, Rank, TorusPrecision}; - -use poulpy_core::{TakeGLWEPt, layouts::prepared::GLWESecretPrepared}; -use poulpy_hal::api::VecZnxBigAllocBytes; -#[cfg(test)] -use poulpy_hal::api::{ - ScratchAvailable, TakeVecZnx, VecZnxAddInplace, VecZnxAddNormal, VecZnxFillUniform, VecZnxNormalize, VecZnxSub, +use poulpy_core::{ + GLWEDecrypt, GLWENoise, + layouts::{Base2K, GLWE, GLWEInfos, GLWEPlaintextLayout, GLWESecretPreparedToRef, LWEInfos, Rank, TorusPrecision}, }; + +#[cfg(test)] +use poulpy_core::GLWEEncryptSk; +use poulpy_core::ScratchTakeCore; +use poulpy_hal::layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}; +#[cfg(test)] #[cfg(test)] use poulpy_hal::source::Source; -use poulpy_hal::{ - api::{ - TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, - VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalizeTmpBytes, - }, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}, -}; - -use poulpy_hal::api::{SvpApplyDftToDftInplace, VecZnxNormalizeInplace, VecZnxSubInplace}; use crate::tfhe::bdd_arithmetic::{FromBits, ToBits, UnsignedInteger}; /// An FHE ciphertext encrypting the bits of an [UnsignedInteger]. pub struct FheUintBlocks { - pub(crate) blocks: Vec>, + pub(crate) blocks: Vec>, pub(crate) _base: u8, pub(crate) _phantom: PhantomData, } @@ -62,7 +55,7 @@ impl FheUintBlocks, T> { pub(crate) fn alloc_with(module: &Module, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self { Self { blocks: (0..T::WORD_SIZE) - .map(|_| GLWECiphertext::alloc_with(module.n().into(), base2k, k, rank)) + .map(|_| GLWE::alloc(module.n().into(), base2k, k, rank)) .collect(), _base: 1, _phantom: PhantomData, @@ -77,26 +70,14 @@ impl FheUintBlocks { &mut self, module: &Module, value: T, - sk: &GLWESecretPrepared, + sk: &S, source_xa: &mut Source, source_xe: &mut Source, scratch: &mut Scratch, ) where - S: DataRef, - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeGLWEPt, + S: GLWESecretPreparedToRef + GLWEInfos, + Module: GLWEEncryptSk, + Scratch: ScratchTakeCore, { use poulpy_core::layouts::GLWEPlaintextLayout; @@ -113,30 +94,21 @@ impl FheUintBlocks { k: 1_usize.into(), }; - let (mut pt, scratch_1) = scratch.take_glwe_pt(&pt_infos); + let (mut pt, scratch_1) = scratch.take_glwe_plaintext(&pt_infos); for i in 0..T::WORD_SIZE { - pt.encode_coeff_i64(value.bit(i) as i64, TorusPrecision(1), 0); - self.blocks[i].encrypt_sk(&module, &pt, sk, source_xa, source_xe, scratch_1); + pt.encode_coeff_i64(value.bit(i) as i64, TorusPrecision(2), 0); + self.blocks[i].encrypt_sk(module, &pt, sk, source_xa, source_xe, scratch_1); } } } impl FheUintBlocks { - pub fn decrypt( - &self, - module: &Module, - sk: &GLWESecretPrepared, - scratch: &mut Scratch, - ) -> T + pub fn decrypt(&self, module: &Module, sk: &S, scratch: &mut Scratch) -> T where - Module: VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize, - Scratch: TakeVecZnxDft + TakeVecZnxBig + TakeGLWEPt, + Module: GLWEDecrypt, + S: GLWESecretPreparedToRef + GLWEInfos, + Scratch: ScratchTakeCore, { #[cfg(debug_assertions)] { @@ -151,7 +123,7 @@ impl FheUintBlocks { k: self.k(), }; - let (mut pt, scratch_1) = scratch.take_glwe_pt(&pt_infos); + let (mut pt, scratch_1) = scratch.take_glwe_plaintext(&pt_infos); let mut bits: Vec = vec![0u8; T::WORD_SIZE]; @@ -167,26 +139,11 @@ impl FheUintBlocks { T::from_bits(&bits) } - pub fn noise( - &self, - module: &Module, - sk: &GLWESecretPrepared, - want: T, - scratch: &mut Scratch, - ) -> Vec + pub fn noise(&self, module: &Module, sk: &S, want: T, scratch: &mut Scratch) -> Vec where - Module: VecZnxDftAllocBytes - + VecZnxBigAllocBytes - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalizeTmpBytes - + VecZnxSubInplace - + VecZnxNormalizeInplace, - Scratch: TakeGLWEPt + TakeVecZnxDft + TakeVecZnxBig, + Module: GLWENoise, + S: GLWESecretPreparedToRef + GLWEInfos, + Scratch: ScratchTakeCore, { #[cfg(debug_assertions)] { @@ -201,7 +158,7 @@ impl FheUintBlocks { k: 1_usize.into(), }; - let (mut pt_want, scratch_1) = scratch.take_glwe_pt(&pt_infos); + let (mut pt_want, scratch_1) = scratch.take_glwe_plaintext(&pt_infos); let mut noise: Vec = vec![0f64; T::WORD_SIZE]; diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block_debug.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block_debug.rs new file mode 100644 index 0000000..f291c62 --- /dev/null +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block_debug.rs @@ -0,0 +1,147 @@ +use std::marker::PhantomData; + +use crate::tfhe::bdd_arithmetic::{BDDKeyPrepared, FheUintBlockDebugPrepare, ToBits}; +use crate::tfhe::{ + bdd_arithmetic::{FheUintBlocks, UnsignedInteger}, + blind_rotation::BlindRotationAlgo, + circuit_bootstrapping::CirtuitBootstrappingExecute, +}; +use poulpy_core::GGSWNoise; +#[cfg(test)] +use poulpy_core::layouts::{Base2K, Dnum, Dsize, Rank, TorusPrecision}; +use poulpy_core::layouts::{GGSW, GLWESecretPreparedToRef}; +use poulpy_core::{ + LWEFromGLWE, ScratchTakeCore, + layouts::{GGSWInfos, GGSWPreparedFactory, GLWEInfos, LWE, LWEInfos}, +}; +#[cfg(test)] +use poulpy_hal::api::ModuleN; +use poulpy_hal::layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}; + +pub(crate) struct FheUintBlocksPreparedDebug { + pub(crate) blocks: Vec>, + pub(crate) _base: u8, + pub(crate) _phantom: PhantomData, +} + +#[cfg(test)] +impl FheUintBlocksPreparedDebug, T> { + #[allow(dead_code)] + pub(crate) fn alloc(module: &M, infos: &A) -> Self + where + M: ModuleN, + A: GGSWInfos, + { + Self::alloc_with( + module, + infos.base2k(), + infos.k(), + infos.dnum(), + infos.dsize(), + infos.rank(), + ) + } + + #[allow(dead_code)] + pub(crate) fn alloc_with(module: &M, base2k: Base2K, k: TorusPrecision, dnum: Dnum, dsize: Dsize, rank: Rank) -> Self + where + M: ModuleN, + { + Self { + blocks: (0..T::WORD_SIZE) + .map(|_| GGSW::alloc(module.n().into(), base2k, k, rank, dnum, dsize)) + .collect(), + _base: 1, + _phantom: PhantomData, + } + } +} + +impl LWEInfos for FheUintBlocksPreparedDebug { + fn base2k(&self) -> poulpy_core::layouts::Base2K { + self.blocks[0].base2k() + } + + fn k(&self) -> poulpy_core::layouts::TorusPrecision { + self.blocks[0].k() + } + + fn n(&self) -> poulpy_core::layouts::Degree { + self.blocks[0].n() + } +} + +impl GLWEInfos for FheUintBlocksPreparedDebug { + fn rank(&self) -> poulpy_core::layouts::Rank { + self.blocks[0].rank() + } +} + +impl GGSWInfos for FheUintBlocksPreparedDebug { + fn dsize(&self) -> poulpy_core::layouts::Dsize { + self.blocks[0].dsize() + } + + fn dnum(&self) -> poulpy_core::layouts::Dnum { + self.blocks[0].dnum() + } +} + +impl FheUintBlocksPreparedDebug { + pub(crate) fn noise(&self, module: &M, sk: &S, want: T) + where + S: GLWESecretPreparedToRef, + M: GGSWNoise, + { + for (i, ggsw) in self.blocks.iter().enumerate() { + use poulpy_hal::layouts::{ScalarZnx, ZnxViewMut}; + let mut pt_want = ScalarZnx::alloc(self.n().into(), 1); + pt_want.at_mut(0, 0)[0] = want.bit(i) as i64; + ggsw.print_noise(module, sk, &pt_want); + } + } +} + +impl FheUintBlockDebugPrepare for Module +where + Self: LWEFromGLWE + CirtuitBootstrappingExecute + GGSWPreparedFactory, + Scratch: ScratchTakeCore, +{ + fn fhe_uint_block_debug_prepare( + &self, + res: &mut FheUintBlocksPreparedDebug, + bits: &FheUintBlocks, + key: &BDDKeyPrepared, + scratch: &mut Scratch, + ) where + DM: DataMut, + DR0: DataRef, + DR1: DataRef, + { + assert_eq!(res.blocks.len(), bits.blocks.len()); + + let mut lwe: LWE> = LWE::alloc_from_infos(&bits.blocks[0]); //TODO: add TakeLWE + for (dst, src) in res.blocks.iter_mut().zip(bits.blocks.iter()) { + lwe.from_glwe(self, src, &key.ks, scratch); + key.cbt.execute_to_constant(self, dst, &lwe, 1, 1, scratch); + } + } +} + +impl FheUintBlocksPreparedDebug { + pub fn prepare( + &mut self, + module: &M, + other: &FheUintBlocks, + key: &BDDKeyPrepared, + scratch: &mut Scratch, + ) where + BRA: BlindRotationAlgo, + O: DataRef, + K: DataRef, + M: FheUintBlockDebugPrepare, + Scratch: ScratchTakeCore, + { + module.fhe_uint_block_debug_prepare(self, other, key, scratch); + } +} diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block_prepared.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block_prepared.rs index 3cd50a6..3753f99 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block_prepared.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block_prepared.rs @@ -1,93 +1,60 @@ use std::marker::PhantomData; use poulpy_core::layouts::{ - Base2K, Dnum, Dsize, GGSWInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision, prepared::GGSWCiphertextPrepared, + Base2K, Dnum, Dsize, GGSWInfos, GGSWPreparedFactory, GLWEInfos, LWEInfos, Rank, TorusPrecision, prepared::GGSWPrepared, }; #[cfg(test)] -use poulpy_core::{ - TakeGGSW, - layouts::{GGSWCiphertext, prepared::GLWESecretPrepared}, -}; -use poulpy_hal::{ - api::VmpPMatAlloc, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}, -}; +use poulpy_core::{GGSWEncryptSk, ScratchTakeCore, layouts::GLWESecretPreparedToRef}; +use poulpy_hal::layouts::{Backend, Data, DataRef, Module}; #[cfg(test)] use poulpy_hal::{ - api::{ - ScratchAvailable, SvpApplyDftToDftInplace, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, - VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, - VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, - VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, - VecZnxSubInplace, VmpPrepare, - }, - oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl}, + api::ModuleN, + layouts::{DataMut, Scratch}, source::Source, }; -use crate::tfhe::bdd_arithmetic::{FheUintBlocks, FheUintPrepare, ToBits, UnsignedInteger}; - #[cfg(test)] -pub(crate) struct FheUintBlocksPrepDebug { - pub(crate) blocks: Vec>, +use crate::tfhe::bdd_arithmetic::ToBits; +use crate::tfhe::bdd_arithmetic::UnsignedInteger; + +/// A prepared FHE ciphertext encrypting the bits of an [UnsignedInteger]. +pub struct FheUintBlocksPrepared { + pub(crate) blocks: Vec>, pub(crate) _base: u8, pub(crate) _phantom: PhantomData, } -#[cfg(test)] -impl FheUintBlocksPrepDebug, T> { - #[allow(dead_code)] - pub(crate) fn alloc(module: &Module, infos: &A) -> Self - where - A: GGSWInfos, - { - Self::alloc_with( - module, - infos.base2k(), - infos.k(), - infos.dnum(), - infos.dsize(), - infos.rank(), - ) - } +impl FheUintBlocksPreparedFactory for Module where + Self: Sized + GGSWPreparedFactory +{ +} - #[allow(dead_code)] - pub(crate) fn alloc_with( - module: &Module, +pub trait FheUintBlocksPreparedFactory +where + Self: Sized + GGSWPreparedFactory, +{ + fn alloc_fhe_uint_blocks_prepared( + &self, base2k: Base2K, k: TorusPrecision, dnum: Dnum, dsize: Dsize, rank: Rank, - ) -> Self { - Self { + ) -> FheUintBlocksPrepared, T, BE> { + FheUintBlocksPrepared { blocks: (0..T::WORD_SIZE) - .map(|_| GGSWCiphertext::alloc_with(module.n().into(), base2k, k, rank, dnum, dsize)) + .map(|_| GGSWPrepared::alloc(self, base2k, k, dnum, dsize, rank)) .collect(), _base: 1, _phantom: PhantomData, } } -} -/// A prepared FHE ciphertext encrypting the bits of an [UnsignedInteger]. -pub struct FheUintBlocksPrep { - pub(crate) blocks: Vec>, - pub(crate) _base: u8, - pub(crate) _phantom: PhantomData, -} - -impl FheUintBlocksPrep, BE, T> -where - Module: VmpPMatAlloc, -{ - #[allow(dead_code)] - pub(crate) fn alloc(module: &Module, infos: &A) -> Self + fn alloc_fhe_uint_blocks_prepared_from_infos(&self, infos: &A) -> FheUintBlocksPrepared, T, BE> where A: GGSWInfos, { - Self::alloc_with( - module, + self.alloc_fhe_uint_blocks_prepared( infos.base2k(), infos.k(), infos.dnum(), @@ -95,130 +62,90 @@ where infos.rank(), ) } +} + +impl FheUintBlocksPrepared, T, BE> { + #[allow(dead_code)] + pub(crate) fn alloc(module: &M, infos: &A) -> Self + where + A: GGSWInfos, + M: FheUintBlocksPreparedFactory, + { + module.alloc_fhe_uint_blocks_prepared_from_infos(infos) + } #[allow(dead_code)] - pub(crate) fn alloc_with(module: &Module, base2k: Base2K, k: TorusPrecision, dnum: Dnum, dsize: Dsize, rank: Rank) -> Self + pub(crate) fn alloc_with(module: &M, base2k: Base2K, k: TorusPrecision, dnum: Dnum, dsize: Dsize, rank: Rank) -> Self where - Module: VmpPMatAlloc, + M: FheUintBlocksPreparedFactory, { - Self { - blocks: (0..T::WORD_SIZE) - .map(|_| GGSWCiphertextPrepared::alloc_with(module, base2k, k, dnum, dsize, rank)) - .collect(), - _base: 1, - _phantom: PhantomData, - } + module.alloc_fhe_uint_blocks_prepared(base2k, k, dnum, dsize, rank) } } -impl FheUintBlocksPrep { - #[allow(dead_code)] - #[cfg(test)] - pub(crate) fn encrypt_sk( - &mut self, - module: &Module, +#[cfg(test)] +impl FheUintBlocksPreparedEncryptSk for Module where + Self: Sized + ModuleN + GGSWEncryptSk + GGSWPreparedFactory +{ +} + +#[cfg(test)] +pub trait FheUintBlocksPreparedEncryptSk +where + Self: Sized + ModuleN + GGSWEncryptSk + GGSWPreparedFactory, +{ + fn fhe_uint_blocks_prepared_encrypt_sk( + &self, + res: &mut FheUintBlocksPrepared, value: T, - sk: &GLWESecretPrepared, + sk: &S, source_xa: &mut Source, source_xe: &mut Source, scratch: &mut Scratch, ) where - S: DataRef, - Module: VecZnxAddScalarInplace - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + VmpPrepare, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeGGSW + TakeScalarZnx, + DM: DataMut, + S: GLWESecretPreparedToRef + GLWEInfos, + Scratch: ScratchTakeCore, { - #[cfg(debug_assertions)] - { - assert!(module.n().is_multiple_of(T::WORD_SIZE)); - assert_eq!(self.n(), module.n() as u32); - assert_eq!(sk.n(), module.n() as u32); - } + use poulpy_hal::api::ScratchTakeBasic; - let (mut tmp_ggsw, scratch_1) = scratch.take_ggsw(self); - let (mut pt, scratch_2) = scratch_1.take_scalar_znx(module.n(), 1); + assert!(self.n().is_multiple_of(T::WORD_SIZE)); + assert_eq!(res.n(), self.n() as u32); + assert_eq!(sk.n(), self.n() as u32); + + let (mut tmp_ggsw, scratch_1) = scratch.take_ggsw(res); + let (mut pt, scratch_2) = scratch_1.take_scalar_znx(self.n(), 1); for i in 0..T::WORD_SIZE { - use poulpy_core::layouts::prepared::Prepare; use poulpy_hal::layouts::ZnxViewMut; pt.at_mut(0, 0)[0] = value.bit(i) as i64; - tmp_ggsw.encrypt_sk(&module, &pt, sk, source_xa, source_xe, scratch_2); - self.blocks[i].prepare(module, &tmp_ggsw, scratch_2); + tmp_ggsw.encrypt_sk(self, &pt, sk, source_xa, source_xe, scratch_2); + res.blocks[i].prepare(self, &tmp_ggsw, scratch_2); } } - - /// Prepares [FheUintBits] to [FheUintBitsPrep]. - pub fn prepare(&mut self, module: &Module, bits: &FheUintBlocks, key: &KEY, scratch: &mut Scratch) - where - BIT: DataRef, - KEY: FheUintPrepare, FheUintBlocks>, - { - key.prepare(module, self, bits, scratch); - } } #[cfg(test)] -impl FheUintBlocksPrepDebug { - pub(crate) fn prepare( +impl FheUintBlocksPrepared { + pub(crate) fn encrypt_sk( &mut self, - module: &Module, - bits: &FheUintBlocks, - key: &KEY, + module: &M, + value: T, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, scratch: &mut Scratch, ) where - BIT: DataRef, - KEY: FheUintPrepare, FheUintBlocks>, + S: GLWESecretPreparedToRef + GLWEInfos, + M: FheUintBlocksPreparedEncryptSk, + Scratch: ScratchTakeCore, { - key.prepare(module, self, bits, scratch); + module.fhe_uint_blocks_prepared_encrypt_sk(self, value, sk, source_xa, source_xe, scratch); } } -#[cfg(test)] -impl FheUintBlocksPrepDebug { - #[allow(dead_code)] - pub(crate) fn noise(&self, module: &Module, sk: &GLWESecretPrepared, want: T) - where - Module: VecZnxDftAllocBytes - + VecZnxBigAllocBytes - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalizeTmpBytes - + VecZnxBigAlloc - + VecZnxDftAlloc - + VecZnxBigNormalizeTmpBytes - + VecZnxIdftApplyTmpA - + VecZnxAddScalarInplace - + VecZnxSubInplace, - BE: Backend + TakeVecZnxDftImpl + TakeVecZnxBigImpl + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, - { - for (i, ggsw) in self.blocks.iter().enumerate() { - use poulpy_hal::layouts::{ScalarZnx, ZnxViewMut}; - let mut pt_want = ScalarZnx::alloc(self.n().into(), 1); - pt_want.at_mut(0, 0)[0] = want.bit(i) as i64; - ggsw.print_noise(module, sk, &pt_want); - } - } -} - -impl LWEInfos for FheUintBlocksPrep { +impl LWEInfos for FheUintBlocksPrepared { fn base2k(&self) -> poulpy_core::layouts::Base2K { self.blocks[0].base2k() } @@ -232,46 +159,13 @@ impl LWEInfos for FheUintBlocksPrep< } } -impl GLWEInfos for FheUintBlocksPrep { +impl GLWEInfos for FheUintBlocksPrepared { fn rank(&self) -> poulpy_core::layouts::Rank { self.blocks[0].rank() } } -impl GGSWInfos for FheUintBlocksPrep { - fn dsize(&self) -> poulpy_core::layouts::Dsize { - self.blocks[0].dsize() - } - - fn dnum(&self) -> poulpy_core::layouts::Dnum { - self.blocks[0].dnum() - } -} - -#[cfg(test)] -impl LWEInfos for FheUintBlocksPrepDebug { - fn base2k(&self) -> poulpy_core::layouts::Base2K { - self.blocks[0].base2k() - } - - fn k(&self) -> poulpy_core::layouts::TorusPrecision { - self.blocks[0].k() - } - - fn n(&self) -> poulpy_core::layouts::Degree { - self.blocks[0].n() - } -} - -#[cfg(test)] -impl GLWEInfos for FheUintBlocksPrepDebug { - fn rank(&self) -> poulpy_core::layouts::Rank { - self.blocks[0].rank() - } -} - -#[cfg(test)] -impl GGSWInfos for FheUintBlocksPrepDebug { +impl GGSWInfos for FheUintBlocksPrepared { fn dsize(&self) -> poulpy_core::layouts::Dsize { self.blocks[0].dsize() } diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/mod.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/mod.rs index 8b51045..b8054d7 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/mod.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/mod.rs @@ -2,6 +2,11 @@ mod block; mod block_prepared; mod word; +#[cfg(test)] +mod block_debug; +#[cfg(test)] +pub(crate) use block_debug::*; + pub use block::*; pub use block_prepared::*; pub use word::*; diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/word.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/word.rs index cc754bc..1d3e218 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/word.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/word.rs @@ -1,22 +1,14 @@ use itertools::Itertools; use poulpy_core::{ - GLWEOperations, TakeGLWECtSlice, TakeGLWEPt, glwe_packing, + GLWECopy, GLWEDecrypt, GLWEEncryptSk, GLWEPacking, ScratchTakeCore, layouts::{ - GLWECiphertext, GLWEInfos, GLWEPlaintextLayout, LWEInfos, TorusPrecision, - prepared::{GGLWEAutomorphismKeyPrepared, GLWESecretPrepared}, + GLWE, GLWEInfos, GLWEPlaintextLayout, GLWESecretPreparedToRef, LWEInfos, TorusPrecision, + prepared::GLWEAutomorphismKeyPrepared, }, }; use poulpy_hal::{ - api::{ - ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, - VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, - VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, - VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, - VecZnxRshInplace, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, - }, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}, + api::ModuleN, + layouts::{Backend, Data, DataMut, DataRef, Scratch}, source::Source, }; use std::{collections::HashMap, marker::PhantomData}; @@ -24,57 +16,33 @@ use std::{collections::HashMap, marker::PhantomData}; use crate::tfhe::bdd_arithmetic::{FromBits, ToBits, UnsignedInteger}; /// A FHE ciphertext encrypting a [UnsignedInteger]. -pub struct FheUintWord(pub(crate) GLWECiphertext, pub(crate) PhantomData); +pub struct FheUintWord(pub(crate) GLWE, pub(crate) PhantomData); impl FheUintWord { #[allow(dead_code)] - fn post_process( + fn post_process( &mut self, - module: &Module, - mut tmp_res: Vec>, - auto_keys: &HashMap>, + module: &M, + mut tmp_res: Vec>, + auto_keys: &HashMap>, scratch: &mut Scratch, ) where ATK: DataRef, - Module: VecZnxSub - + VecZnxCopy - + VecZnxNegateInplace - + VecZnxDftAllocBytes - + VecZnxAddInplace - + VmpApplyDftToDftTmpBytes - + VecZnxNormalizeTmpBytes - + VecZnxDftApply - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxIdftApplyConsume - + VecZnxBigNormalize - + VecZnxNormalize - + VecZnxRotateInplace - + VecZnxNormalizeInplace - + VecZnxSwitchRing - + VecZnxBigAutomorphismInplace - + VecZnxRshInplace - + VecZnxDftCopy - + VecZnxIdftApplyTmpA - + VecZnxSubInplace - + VecZnxBigNormalizeTmpBytes - + VecZnxBigAddSmallInplace - + VecZnxAutomorphismInplace - + VecZnxBigSubSmallNegateInplace - + VecZnxRotate, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeGLWECtSlice, + M: GLWEPacking + GLWECopy, + Scratch: ScratchTakeCore, { // Repacks the GLWE ciphertexts bits let gap: usize = module.n() / T::WORD_SIZE; let log_gap: usize = (usize::BITS - (gap - 1).leading_zeros()) as usize; - let mut cts: HashMap> = HashMap::new(); + let mut cts: HashMap> = HashMap::new(); for (i, ct) in tmp_res.iter_mut().enumerate().take(T::WORD_SIZE) { cts.insert(i * gap, ct); } - glwe_packing(module, &mut cts, log_gap, auto_keys, scratch); + + module.glwe_pack(&mut cts, log_gap, auto_keys, scratch); // And copies the repacked ciphertext on the receiver. - self.0.copy(module, cts.remove(&0).unwrap()) + module.glwe_copy(&mut self.0, cts.remove(&0).unwrap()); } } @@ -99,30 +67,18 @@ impl GLWEInfos for FheUintWord { } impl FheUintWord { - pub fn encrypt_sk( + pub fn encrypt_sk( &mut self, - module: &Module, + module: &M, data: T, - sk: &GLWESecretPrepared, + sk: &S, source_xa: &mut Source, source_xe: &mut Source, scratch: &mut Scratch, ) where - Module: VecZnxAddScalarInplace - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeGLWEPt, + S: GLWESecretPreparedToRef + GLWEInfos, + M: ModuleN + GLWEEncryptSk, + Scratch: ScratchTakeCore, { #[cfg(debug_assertions)] { @@ -145,7 +101,7 @@ impl FheUintWord { k: 1_usize.into(), }; - let (mut pt, scratch_1) = scratch.take_glwe_pt(&pt_infos); + let (mut pt, scratch_1) = scratch.take_glwe_plaintext(&pt_infos); pt.encode_vec_i64(&data_bits, TorusPrecision(1)); self.0 @@ -154,20 +110,11 @@ impl FheUintWord { } impl FheUintWord { - pub fn decrypt( - &self, - module: &Module, - sk: &GLWESecretPrepared, - scratch: &mut Scratch, - ) -> T + pub fn decrypt(&self, module: &M, sk: &S, scratch: &mut Scratch) -> T where - Module: VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize, - Scratch: TakeVecZnxDft + TakeVecZnxBig + TakeGLWEPt, + S: GLWESecretPreparedToRef + GLWEInfos, + M: GLWEDecrypt, + Scratch: ScratchTakeCore, { #[cfg(debug_assertions)] { @@ -184,7 +131,7 @@ impl FheUintWord { k: 1_usize.into(), }; - let (mut pt, scratch_1) = scratch.take_glwe_pt(&pt_infos); + let (mut pt, scratch_1) = scratch.take_glwe_plaintext(&pt_infos); self.0.decrypt(module, &mut pt, sk, scratch_1); diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs index 98e1083..2d5d086 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs @@ -1,15 +1,12 @@ use itertools::Itertools; use poulpy_core::{ - GLWEExternalProductInplace, GLWEOperations, TakeGLWECtSlice, + GLWEAdd, GLWECopy, GLWEExternalProduct, GLWESub, ScratchTakeCore, layouts::{ - GLWECiphertext, GLWECiphertextToMut, LWEInfos, - prepared::{GGSWCiphertextPrepared, GGSWCiphertextPreparedToRef}, + GLWE, LWEInfos, + prepared::{GGSWPrepared, GGSWPreparedToRef}, }, }; -use poulpy_hal::{ - api::{VecZnxAddInplace, VecZnxCopy, VecZnxNegateInplace, VecZnxSub}, - layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero}, -}; +use poulpy_hal::layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero}; use crate::tfhe::bdd_arithmetic::UnsignedInteger; @@ -31,45 +28,43 @@ pub(crate) struct BitCircuit { pub struct Circuit(pub [C; N]); -pub trait CircuitExecute -where - Self: GetBitCircuitInfo, -{ - fn execute( +pub trait ExecuteBDDCircuit { + fn execute_bdd_circuit( &self, - module: &Module, - out: &mut [GLWECiphertext], - inputs: &[&dyn GGSWCiphertextPreparedToRef], + out: &mut [GLWE], + inputs: &[&dyn GGSWPreparedToRef], + circuit: &C, scratch: &mut Scratch, ) where + C: GetBitCircuitInfo, O: DataMut; } -impl CircuitExecute for Circuit +impl ExecuteBDDCircuit for Module where - Self: GetBitCircuitInfo, - Module: Cmux + VecZnxCopy, - Scratch: TakeGLWECtSlice, + Self: Cmux + GLWECopy, + Scratch: ScratchTakeCore, { - fn execute( + fn execute_bdd_circuit( &self, - module: &Module, - out: &mut [GLWECiphertext], - inputs: &[&dyn GGSWCiphertextPreparedToRef], + out: &mut [GLWE], + inputs: &[&dyn GGSWPreparedToRef], + circuit: &C, scratch: &mut Scratch, ) where + C: GetBitCircuitInfo, O: DataMut, { #[cfg(debug_assertions)] { - assert_eq!(inputs.len(), self.input_size()); - assert!(out.len() >= self.output_size()); + assert_eq!(inputs.len(), circuit.input_size()); + assert!(out.len() >= circuit.output_size()); } - for (i, out_i) in out.iter_mut().enumerate().take(self.output_size()) { - let (nodes, levels, max_inter_state) = self.get_circuit(i); + for (i, out_i) in out.iter_mut().enumerate().take(circuit.output_size()) { + let (nodes, levels, max_inter_state) = circuit.get_circuit(i); - let (mut level, scratch_1) = scratch.take_glwe_ct_slice(max_inter_state * 2, out_i); + let (mut level, scratch_1) = scratch.take_glwe_slice(max_inter_state * 2, out_i); level.iter_mut().for_each(|ct| ct.data_mut().zero()); @@ -89,9 +84,9 @@ where for (j, node) in nodes_lvl.iter().enumerate() { if node.low_index == node.high_index { - next_level[j].copy(module, prev_level[node.low_index]); + self.glwe_copy(next_level[j], prev_level[node.low_index]); } else { - module.cmux( + self.cmux( next_level[j], prev_level[node.high_index], prev_level[node.low_index], @@ -107,7 +102,7 @@ where // handle last output // there's always only 1 node at last level let node: &Node = nodes.last().unwrap(); - module.cmux( + self.cmux( out_i, prev_level[node.high_index], prev_level[node.low_index], @@ -116,7 +111,7 @@ where ); } - for out_i in out.iter_mut().skip(self.output_size()) { + for out_i in out.iter_mut().skip(circuit.output_size()) { out_i.data_mut().zero(); } } @@ -159,14 +154,8 @@ impl Node { } pub trait Cmux { - fn cmux( - &self, - out: &mut GLWECiphertext, - t: &GLWECiphertext, - f: &GLWECiphertext, - s: &GGSWCiphertextPrepared, - scratch: &mut Scratch, - ) where + fn cmux(&self, out: &mut GLWE, t: &GLWE, f: &GLWE, s: &GGSWPrepared, scratch: &mut Scratch) + where O: DataMut, T: DataRef, F: DataRef, @@ -175,24 +164,19 @@ pub trait Cmux { impl Cmux for Module where - Module: GLWEExternalProductInplace + VecZnxSub + VecZnxCopy + VecZnxNegateInplace + VecZnxAddInplace, + Module: GLWEExternalProduct + GLWESub + GLWEAdd, + Scratch: ScratchTakeCore, { - fn cmux( - &self, - out: &mut GLWECiphertext, - t: &GLWECiphertext, - f: &GLWECiphertext, - s: &GGSWCiphertextPrepared, - scratch: &mut Scratch, - ) where + fn cmux(&self, out: &mut GLWE, t: &GLWE, f: &GLWE, s: &GGSWPrepared, scratch: &mut Scratch) + where O: DataMut, T: DataRef, F: DataRef, S: DataRef, { // let mut out: GLWECiphertext<&mut [u8]> = out.to_mut(); - out.sub(self, t, f); - out.external_product_inplace(self, s, scratch); - out.to_mut().add_inplace(self, f); + self.glwe_sub(out, t, f); + self.glwe_external_product_inplace(out, s, scratch); + self.glwe_add_inplace(out, f); } } diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs index febabed..d1c20c6 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs @@ -1,29 +1,21 @@ #[cfg(test)] -use crate::tfhe::bdd_arithmetic::FheUintBlocksPrepDebug; +use crate::tfhe::bdd_arithmetic::FheUintBlocksPreparedDebug; use crate::tfhe::{ - bdd_arithmetic::{FheUintBlocks, FheUintBlocksPrep, UnsignedInteger}, - blind_rotation::{BlindRotationAlgo, BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyEncryptSk}, + bdd_arithmetic::{FheUintBlocks, FheUintBlocksPrepared, UnsignedInteger}, + blind_rotation::{BlindRotationAlgo, BlindRotationKey, BlindRotationKeyFactory}, circuit_bootstrapping::{ CircuitBootstrappingKey, CircuitBootstrappingKeyEncryptSk, CircuitBootstrappingKeyLayout, - CircuitBootstrappingKeyPrepared, CirtuitBootstrappingExecute, + CircuitBootstrappingKeyPrepared, CircuitBootstrappingKeyPreparedFactory, CirtuitBootstrappingExecute, }, }; use poulpy_core::{ - TakeGGSW, TakeGLWECt, + GLWEToLWESwitchingKeyEncryptSk, GetDistribution, LWEFromGLWE, ScratchTakeCore, layouts::{ - GLWESecret, GLWEToLWEKey, GLWEToLWEKeyLayout, LWECiphertext, LWESecret, - prepared::{GLWEToLWESwitchingKeyPrepared, Prepare, PrepareAlloc}, + GGSWInfos, GGSWPreparedFactory, GLWEInfos, GLWESecretToRef, GLWEToLWEKeyLayout, GLWEToLWESwitchingKey, + GLWEToLWESwitchingKeyPreparedFactory, LWE, LWEInfos, LWESecretToRef, prepared::GLWEToLWESwitchingKeyPrepared, }, }; use poulpy_hal::{ - api::{ - ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, - TakeSvpPPol, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, - VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, - VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, - VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPrepare, - }, layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}, source::Source, }; @@ -49,193 +41,256 @@ impl BDDKeyInfos for BDDKeyLayout { } } -pub struct BDDKey +pub struct BDDKey where - CBT: Data, - LWE: Data, + D: Data, BRA: BlindRotationAlgo, { - cbt: CircuitBootstrappingKey, - ks: GLWEToLWEKey, + cbt: CircuitBootstrappingKey, + ks: GLWEToLWESwitchingKey, } -impl BDDKey, Vec, BRA> { - pub fn encrypt_sk( - module: &Module, - sk_lwe: &LWESecret, - sk_glwe: &GLWESecret, - infos: &A, +impl BDDKey, BRA> +where + BlindRotationKey, BRA>: BlindRotationKeyFactory, +{ + pub fn alloc_from_infos(infos: &A) -> Self + where + A: BDDKeyInfos, + { + Self { + cbt: CircuitBootstrappingKey::alloc_from_infos(&infos.cbt_infos()), + ks: GLWEToLWESwitchingKey::alloc_from_infos(&infos.ks_infos()), + } + } +} + +pub trait BDDKeyEncryptSk { + fn bdd_key_encrypt_sk( + &self, + res: &mut BDDKey, + sk_lwe: &S0, + sk_glwe: &S1, source_xa: &mut Source, source_xe: &mut Source, scratch: &mut Scratch, - ) -> Self + ) where + D: DataMut, + S0: LWESecretToRef + GetDistribution + LWEInfos, + S1: GLWESecretToRef + GetDistribution + GLWEInfos; +} + +impl BDDKeyEncryptSk for Module +where + Self: CircuitBootstrappingKeyEncryptSk + GLWEToLWESwitchingKeyEncryptSk, + Scratch: ScratchTakeCore, +{ + fn bdd_key_encrypt_sk( + &self, + res: &mut BDDKey, + sk_lwe: &S0, + sk_glwe: &S1, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + D: DataMut, + S0: LWESecretToRef + GetDistribution + LWEInfos, + S1: GLWESecretToRef + GetDistribution + GLWEInfos, + { + res.ks + .encrypt_sk(self, sk_lwe, sk_glwe, source_xa, source_xe, scratch); + res.cbt + .encrypt_sk(self, sk_lwe, sk_glwe, source_xa, source_xe, scratch); + } +} + +impl BDDKey { + pub fn encrypt_sk( + &mut self, + module: &M, + sk_lwe: &S0, + sk_glwe: &S1, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + S0: LWESecretToRef + GetDistribution + LWEInfos, + S1: GLWESecretToRef + GetDistribution + GLWEInfos, + M: BDDKeyEncryptSk, + Scratch: ScratchTakeCore, + { + module.bdd_key_encrypt_sk(self, sk_lwe, sk_glwe, source_xa, source_xe, scratch); + } +} + +pub struct BDDKeyPrepared +where + D: Data, + BRA: BlindRotationAlgo, + BE: Backend, +{ + pub(crate) cbt: CircuitBootstrappingKeyPrepared, + pub(crate) ks: GLWEToLWESwitchingKeyPrepared, +} + +pub trait BDDKeyPreparedFactory +where + Self: Sized + CircuitBootstrappingKeyPreparedFactory + GLWEToLWESwitchingKeyPreparedFactory, +{ + fn alloc_bdd_key_from_infos(&self, infos: &A) -> BDDKeyPrepared, BRA, BE> where A: BDDKeyInfos, - DLwe: DataRef, - DGlwe: DataRef, - BlindRotationKey, BRA>: BlindRotationKeyAlloc + BlindRotationKeyEncryptSk, - Module: SvpApplyDftToDft - + VecZnxIdftApplyTmpA - + VecZnxAddScalarInplace - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + VecZnxSwitchRing - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxAutomorphism - + VecZnxAutomorphismInplace, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeSvpPPol + TakeVecZnxBig, { - let mut ks: GLWEToLWEKey> = GLWEToLWEKey::alloc(&infos.ks_infos()); - ks.encrypt_sk(module, sk_lwe, sk_glwe, source_xa, source_xe, scratch); - - Self { - cbt: CircuitBootstrappingKey::encrypt_sk( - module, - sk_lwe, - sk_glwe, - &infos.cbt_infos(), - source_xa, - source_xe, - scratch, - ), - ks, - } - } -} - -pub struct BDDKeyPrepared -where - CBT: Data, - LWE: Data, - BRA: BlindRotationAlgo, - BE: Backend, -{ - cbt: CircuitBootstrappingKeyPrepared, - ks: GLWEToLWESwitchingKeyPrepared, -} - -impl PrepareAlloc> - for BDDKey -where - CircuitBootstrappingKey: PrepareAlloc>, - GLWEToLWEKey: PrepareAlloc>, -{ - fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> BDDKeyPrepared { BDDKeyPrepared { - cbt: self.cbt.prepare_alloc(module, scratch), - ks: self.ks.prepare_alloc(module, scratch), + cbt: CircuitBootstrappingKeyPrepared::alloc_from_infos(self, &infos.cbt_infos()), + ks: GLWEToLWESwitchingKeyPrepared::alloc_from_infos(self, &infos.ks_infos()), + } + } + + fn prepare_bdd_key_tmp_bytes(&self, infos: &A) -> usize + where + A: BDDKeyInfos, + { + self.circuit_bootstrapping_key_prepare_tmp_bytes(&infos.cbt_infos()) + .max(self.prepare_glwe_to_lwe_switching_key_tmp_bytes(&infos.ks_infos())) + } + + fn prepare_bdd_key(&self, res: &mut BDDKeyPrepared, other: &BDDKey, scratch: &mut Scratch) + where + DM: DataMut, + DR: DataRef, + Scratch: ScratchTakeCore, + { + res.cbt.prepare(self, &other.cbt, scratch); + res.ks.prepare(self, &other.ks, scratch); + } +} +impl BDDKeyPreparedFactory for Module where + Self: Sized + CircuitBootstrappingKeyPreparedFactory + GLWEToLWESwitchingKeyPreparedFactory +{ +} + +impl BDDKeyPrepared, BRA, BE> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + M: BDDKeyPreparedFactory, + A: BDDKeyInfos, + { + module.alloc_bdd_key_from_infos(infos) + } +} + +impl BDDKeyPrepared { + pub fn prepare(&mut self, module: &M, other: &BDDKey, scratch: &mut Scratch) + where + DR: DataRef, + M: BDDKeyPreparedFactory, + Scratch: ScratchTakeCore, + { + module.prepare_bdd_key(self, other, scratch); + } +} + +pub trait FheUintBlocksPrepare { + fn fhe_uint_blocks_prepare_tmp_bytes( + &self, + block_size: usize, + extension_factor: usize, + res_infos: &R, + infos: &A, + ) -> usize + where + R: GGSWInfos, + A: BDDKeyInfos; + fn fhe_uint_blocks_prepare( + &self, + res: &mut FheUintBlocksPrepared, + bits: &FheUintBlocks, + key: &BDDKeyPrepared, + scratch: &mut Scratch, + ) where + DM: DataMut, + DR0: DataRef, + DR1: DataRef; +} + +impl FheUintBlocksPrepare for Module +where + Self: LWEFromGLWE + CirtuitBootstrappingExecute + GGSWPreparedFactory, + Scratch: ScratchTakeCore, +{ + fn fhe_uint_blocks_prepare_tmp_bytes( + &self, + block_size: usize, + extension_factor: usize, + res_infos: &R, + bdd_infos: &A, + ) -> usize + where + R: GGSWInfos, + A: BDDKeyInfos, + { + self.circuit_bootstrapping_execute_tmp_bytes( + block_size, + extension_factor, + res_infos, + &bdd_infos.cbt_infos(), + ) + } + + fn fhe_uint_blocks_prepare( + &self, + res: &mut FheUintBlocksPrepared, + bits: &FheUintBlocks, + key: &BDDKeyPrepared, + scratch: &mut Scratch, + ) where + DM: DataMut, + DR0: DataRef, + DR1: DataRef, + { + assert_eq!(res.blocks.len(), bits.blocks.len()); + + let mut lwe: LWE> = LWE::alloc_from_infos(&bits.blocks[0]); //TODO: add TakeLWE + let (mut tmp_ggsw, scratch_1) = scratch.take_ggsw(res); + for (dst, src) in res.blocks.iter_mut().zip(bits.blocks.iter()) { + lwe.from_glwe(self, src, &key.ks, scratch_1); + key.cbt + .execute_to_constant(self, &mut tmp_ggsw, &lwe, 1, 1, scratch_1); + dst.prepare(self, &tmp_ggsw, scratch_1); } } } -pub trait FheUintPrepare { - fn prepare(&self, module: &Module, out: &mut OUT, bits: &IN, scratch: &mut Scratch); -} - -impl FheUintPrepare, FheUintBlocks> - for BDDKeyPrepared -where - T: UnsignedInteger, - CBT: DataRef, - OUT: DataMut, - IN: DataRef, - LWE: DataRef, - BRA: BlindRotationAlgo, - BE: Backend, - Module: VmpPrepare - + VecZnxRotate - + VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeGLWECt + TakeVecZnx + TakeGGSW, - CircuitBootstrappingKeyPrepared: CirtuitBootstrappingExecute, -{ - fn prepare( - &self, - module: &Module, - out: &mut FheUintBlocksPrep, - bits: &FheUintBlocks, +impl FheUintBlocksPrepared { + pub fn prepare( + &mut self, + module: &M, + other: &FheUintBlocks, + key: &BDDKeyPrepared, scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!(out.blocks.len(), bits.blocks.len()); - } - let mut lwe: LWECiphertext> = LWECiphertext::alloc(&bits.blocks[0]); //TODO: add TakeLWE - let (mut tmp_ggsw, scratch_1) = scratch.take_ggsw(out); - for (dst, src) in out.blocks.iter_mut().zip(bits.blocks.iter()) { - lwe.from_glwe(module, src, &self.ks, scratch_1); - self.cbt - .execute_to_constant(module, &mut tmp_ggsw, &lwe, 1, 1, scratch_1); - dst.prepare(module, &tmp_ggsw, scratch_1); - } + ) where + BRA: BlindRotationAlgo, + O: DataRef, + K: DataRef, + M: FheUintBlocksPrepare, + Scratch: ScratchTakeCore, + { + module.fhe_uint_blocks_prepare(self, other, key, scratch); } } #[cfg(test)] -impl FheUintPrepare, FheUintBlocks> - for BDDKeyPrepared -where - T: UnsignedInteger, - CBT: DataRef, - OUT: DataMut, - IN: DataRef, - LWE: DataRef, - BRA: BlindRotationAlgo, - BE: Backend, - Module: VmpPrepare - + VecZnxRotate - + VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeGLWECt + TakeVecZnx + TakeGGSW, - CircuitBootstrappingKeyPrepared: CirtuitBootstrappingExecute, -{ - fn prepare( +pub(crate) trait FheUintBlockDebugPrepare { + fn fhe_uint_block_debug_prepare( &self, - module: &Module, - out: &mut FheUintBlocksPrepDebug, - bits: &FheUintBlocks, + res: &mut FheUintBlocksPreparedDebug, + bits: &FheUintBlocks, + key: &BDDKeyPrepared, scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!(out.blocks.len(), bits.blocks.len()); - } - let mut lwe: LWECiphertext> = LWECiphertext::alloc(&bits.blocks[0]); //TODO: add TakeLWE - for (dst, src) in out.blocks.iter_mut().zip(bits.blocks.iter()) { - lwe.from_glwe(module, src, &self.ks, scratch); - self.cbt - .execute_to_constant(module, dst, &lwe, 1, 1, scratch); - } - } + ) where + DM: DataMut, + DR0: DataRef, + DR1: DataRef; } diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/parameters.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/parameters.rs index 6b56f79..f807414 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/parameters.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/parameters.rs @@ -1,7 +1,7 @@ #[cfg(test)] use poulpy_core::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEAutomorphismKeyLayout, GGLWETensorKeyLayout, GGSWCiphertextLayout, GLWECiphertextLayout, - GLWEToLWEKeyLayout, Rank, TorusPrecision, + Base2K, Degree, Dnum, Dsize, GGSWLayout, GLWEAutomorphismKeyLayout, GLWELayout, GLWETensorKeyLayout, GLWEToLWEKeyLayout, + Rank, TorusPrecision, }; #[cfg(test)] @@ -25,7 +25,7 @@ pub(crate) const TEST_BLOCK_SIZE: u32 = 7; pub(crate) const TEST_RANK: u32 = 2; #[cfg(test)] -pub(crate) static TEST_GLWE_INFOS: GLWECiphertextLayout = GLWECiphertextLayout { +pub(crate) static TEST_GLWE_INFOS: GLWELayout = GLWELayout { n: Degree(TEST_N_GLWE), base2k: Base2K(TEST_BASE2K), k: TorusPrecision(TEST_K_GLWE), @@ -33,7 +33,7 @@ pub(crate) static TEST_GLWE_INFOS: GLWECiphertextLayout = GLWECiphertextLayout { }; #[cfg(test)] -pub(crate) static TEST_GGSW_INFOS: GGSWCiphertextLayout = GGSWCiphertextLayout { +pub(crate) static TEST_GGSW_INFOS: GGSWLayout = GGSWLayout { n: Degree(TEST_N_GLWE), base2k: Base2K(TEST_BASE2K), k: TorusPrecision(TEST_K_GGSW), @@ -53,7 +53,7 @@ pub(crate) static TEST_BDD_KEY_LAYOUT: BDDKeyLayout = BDDKeyLayout { dnum: Dnum(3), rank: Rank(TEST_RANK), }, - layout_atk: GGLWEAutomorphismKeyLayout { + layout_atk: GLWEAutomorphismKeyLayout { n: Degree(TEST_N_GLWE), base2k: Base2K(TEST_BASE2K), k: TorusPrecision(52), @@ -61,7 +61,7 @@ pub(crate) static TEST_BDD_KEY_LAYOUT: BDDKeyLayout = BDDKeyLayout { dnum: Dnum(3), dsize: Dsize(1), }, - layout_tsk: GGLWETensorKeyLayout { + layout_tsk: GLWETensorKeyLayout { n: Degree(TEST_N_GLWE), base2k: Base2K(TEST_BASE2K), k: TorusPrecision(52), diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/test.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/test.rs index 1889e02..19155ea 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/test.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/test.rs @@ -2,43 +2,24 @@ use std::time::Instant; use poulpy_backend::FFT64Ref; use poulpy_core::{ - TakeGGSW, TakeGLWEPt, - layouts::{ - GGSWCiphertextLayout, GLWECiphertextLayout, GLWESecret, LWEInfos, LWESecret, - prepared::{GLWESecretPrepared, PrepareAlloc}, - }, + GGSWNoise, GLWEDecrypt, GLWENoise, ScratchTakeCore, + layouts::{GGSWLayout, GLWELayout, GLWESecret, GLWESecretPreparedFactory, LWEInfos, LWESecret, prepared::GLWESecretPrepared}, }; use poulpy_hal::{ - api::{ - ModuleNew, ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, - SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeSlice, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, - VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigAutomorphismInplace, - VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy, VecZnxDftAddInplace, - VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, VecZnxFillUniform, VecZnxIdftApplyConsume, - VecZnxIdftApplyTmpA, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, - VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, - VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, ZnAddNormal, ZnFillUniform, - ZnNormalizeInplace, - }, + api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow}, layouts::{Backend, Module, Scratch, ScratchOwned}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeMatZnxImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, - }, source::Source, }; use rand::RngCore; use crate::tfhe::{ bdd_arithmetic::{ - Add, BDDKey, BDDKeyLayout, BDDKeyPrepared, FheUintBlocks, FheUintBlocksPrep, FheUintBlocksPrepDebug, Sub, - TEST_BDD_KEY_LAYOUT, TEST_BLOCK_SIZE, TEST_GGSW_INFOS, TEST_GLWE_INFOS, TEST_N_LWE, - }, - blind_rotation::{ - BlincRotationExecute, BlindRotationAlgo, BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyEncryptSk, - BlindRotationKeyPrepared, CGGI, + Add, BDDKey, BDDKeyEncryptSk, BDDKeyLayout, BDDKeyPrepared, BDDKeyPreparedFactory, ExecuteBDDCircuit2WTo1W, + FheUintBlockDebugPrepare, FheUintBlocks, FheUintBlocksPrepare, FheUintBlocksPrepared, FheUintBlocksPreparedDebug, + FheUintBlocksPreparedEncryptSk, FheUintBlocksPreparedFactory, Sub, TEST_BDD_KEY_LAYOUT, TEST_BLOCK_SIZE, TEST_GGSW_INFOS, + TEST_GLWE_INFOS, TEST_N_LWE, }, + blind_rotation::{BlindRotationAlgo, BlindRotationKey, BlindRotationKeyFactory, CGGI}, }; #[test] @@ -48,67 +29,24 @@ fn test_bdd_2w_to_1w_fft64_ref() { fn test_bdd_2w_to_1w() where - Module: ModuleNew + SvpPPolAlloc + SvpPrepare + VmpPMatAlloc, + Module: ModuleNew + + GLWESecretPreparedFactory + + GLWEDecrypt + + GLWENoise + + FheUintBlocksPreparedFactory + + FheUintBlocksPreparedEncryptSk + + FheUintBlockDebugPrepare + + BDDKeyEncryptSk + + BDDKeyPreparedFactory + + GGSWNoise + + FheUintBlocksPrepare + + ExecuteBDDCircuit2WTo1W, + BlindRotationKey, BRA>: BlindRotationKeyFactory, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, - Module: VecZnxAddScalarInplace - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + VmpPrepare, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeGGSW + TakeScalarZnx + TakeSlice, - Module: VecZnxCopy + VecZnxNegateInplace + VmpApplyDftToDftTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd, - Module: VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxBigNormalize, - Scratch: TakeVecZnxDft + TakeVecZnxBig + TakeGLWEPt, - Module: VecZnxAutomorphism - + VecZnxSwitchRing - + VecZnxBigAllocBytes - + VecZnxIdftApplyTmpA - + SvpApplyDftToDft - + VecZnxBigAlloc - + VecZnxDftAlloc - + VecZnxBigNormalizeTmpBytes - + SvpPPolAllocBytes - + VecZnxRotateInplace - + VecZnxBigAutomorphismInplace - + VecZnxRshInplace - + VecZnxDftCopy - + VecZnxAutomorphismInplace - + VecZnxBigSubSmallNegateInplace - + VecZnxRotateInplaceTmpBytes - + VecZnxBigAllocBytes - + VecZnxDftAddInplace - + VecZnxRotate - + ZnFillUniform - + ZnAddNormal - + ZnNormalizeInplace, - BE: Backend - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + TakeVecZnxDftImpl - + ScratchAvailableImpl - + TakeVecZnxImpl - + TakeScalarZnxImpl - + TakeSvpPPolImpl - + TakeVecZnxBigImpl - + TakeVecZnxDftSliceImpl - + TakeMatZnxImpl - + TakeVecZnxSliceImpl, - BlindRotationKey, BRA>: PrepareAlloc, BRA, BE>>, - BlindRotationKeyPrepared, BRA, BE>: BlincRotationExecute, - BlindRotationKey, BRA>: BlindRotationKeyAlloc + BlindRotationKeyEncryptSk, + Scratch: ScratchTakeCore, { - let glwe_infos: GLWECiphertextLayout = TEST_GLWE_INFOS; - let ggsw_infos: GGSWCiphertextLayout = TEST_GGSW_INFOS; + let glwe_infos: GLWELayout = TEST_GLWE_INFOS; + let ggsw_infos: GGSWLayout = TEST_GGSW_INFOS; let n_glwe: usize = glwe_infos.n().into(); @@ -120,9 +58,10 @@ where let mut scratch: ScratchOwned = ScratchOwned::alloc(1 << 22); - let mut sk_glwe: GLWESecret> = GLWESecret::alloc(&glwe_infos); + let mut sk_glwe: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_infos); sk_glwe.fill_ternary_prob(0.5, &mut source_xs); - let sk_glwe_prep: GLWESecretPrepared, BE> = sk_glwe.prepare_alloc(&module, scratch.borrow()); + let mut sk_glwe_prep: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc_from_infos(&module, &glwe_infos); + sk_glwe_prep.prepare(&module, &sk_glwe); let a: u32 = source.next_u32(); let b: u32 = source.next_u32(); @@ -130,12 +69,15 @@ where println!("a: {a}"); println!("b: {b}"); - let mut a_enc_prep: FheUintBlocksPrep, BE, u32> = FheUintBlocksPrep::, BE, u32>::alloc(&module, &ggsw_infos); - let mut b_enc_prep: FheUintBlocksPrep, BE, u32> = FheUintBlocksPrep::, BE, u32>::alloc(&module, &ggsw_infos); + let mut a_enc_prep: FheUintBlocksPrepared, u32, BE> = + FheUintBlocksPrepared::, u32, BE>::alloc(&module, &ggsw_infos); + let mut b_enc_prep: FheUintBlocksPrepared, u32, BE> = + FheUintBlocksPrepared::, u32, BE>::alloc(&module, &ggsw_infos); let mut c_enc: FheUintBlocks, u32> = FheUintBlocks::, u32>::alloc(&module, &glwe_infos); - let mut c_enc_prep_debug: FheUintBlocksPrepDebug, u32> = - FheUintBlocksPrepDebug::, u32>::alloc(&module, &ggsw_infos); - let mut c_enc_prep: FheUintBlocksPrep, BE, u32> = FheUintBlocksPrep::, BE, u32>::alloc(&module, &ggsw_infos); + let mut c_enc_prep_debug: FheUintBlocksPreparedDebug, u32> = + FheUintBlocksPreparedDebug::, u32>::alloc(&module, &ggsw_infos); + let mut c_enc_prep: FheUintBlocksPrepared, u32, BE> = + FheUintBlocksPrepared::, u32, BE>::alloc(&module, &ggsw_infos); a_enc_prep.encrypt_sk( &module, @@ -178,17 +120,19 @@ where let bdd_key_infos: BDDKeyLayout = TEST_BDD_KEY_LAYOUT; + let mut bdd_key: BDDKey, BRA> = BDDKey::alloc_from_infos(&bdd_key_infos); + let now: Instant = Instant::now(); - let bdd_key: BDDKey, Vec, BRA> = BDDKey::encrypt_sk( + bdd_key.encrypt_sk( &module, &sk_lwe, &sk_glwe, - &bdd_key_infos, &mut source_xa, &mut source_xe, scratch.borrow(), ); - let bdd_key_prepared: BDDKeyPrepared, Vec, BRA, BE> = bdd_key.prepare_alloc(&module, scratch.borrow()); + let mut bdd_key_prepared: BDDKeyPrepared, BRA, BE> = BDDKeyPrepared::alloc_from_infos(&module, &bdd_key_infos); + bdd_key_prepared.prepare(&module, &bdd_key, scratch.borrow()); println!("BDD-KGEN: {} ms", now.elapsed().as_millis()); let now: Instant = Instant::now(); diff --git a/poulpy-schemes/src/tfhe/blind_rotation/cggi_algo.rs b/poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/algorithm.rs similarity index 52% rename from poulpy-schemes/src/tfhe/blind_rotation/cggi_algo.rs rename to poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/algorithm.rs index 03f36e9..b9ec277 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/cggi_algo.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/algorithm.rs @@ -1,159 +1,142 @@ use itertools::izip; use poulpy_hal::{ api::{ - ScratchAvailable, SvpApplyDftToDft, SvpPPolAllocBytes, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, - TakeVecZnxSlice, VecZnxAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftApply, - VecZnxDftSubInplace, VecZnxDftZero, VecZnxIdftApply, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpBytes, - VecZnxMulXpMinusOneInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, - VecZnxSubInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + ModuleN, ScratchTakeBasic, SvpApplyDftToDft, VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize, + VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAddInplace, VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftSubInplace, + VecZnxDftZero, VecZnxIdftApply, VecZnxIdftApplyTmpBytes, VecZnxRotate, VmpApplyDftToDft, VmpApplyDftToDftTmpBytes, }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, SvpPPol, VecZnx, ZnxView, ZnxZero}, + layouts::{Backend, DataMut, DataRef, Module, Scratch, SvpPPol, VecZnx, ZnxZero}, }; use poulpy_core::{ - Distribution, GLWEOperations, TakeGLWECt, - layouts::{GGSWInfos, GLWECiphertext, GLWECiphertextToMut, GLWEInfos, LWECiphertext, LWECiphertextToRef, LWEInfos}, + Distribution, GLWEAdd, GLWEExternalProduct, GLWEMulXpMinusOne, GLWENormalize, ScratchTakeCore, + layouts::{GGSWInfos, GLWE, GLWEInfos, GLWEToMut, LWE, LWEInfos, LWEToRef}, }; use crate::tfhe::blind_rotation::{ - BlincRotationExecute, BlindRotationKeyInfos, BlindRotationKeyPrepared, CGGI, LookUpTable, LookUpTableRotationDirection, + BlindRotationExecute, BlindRotationKeyInfos, BlindRotationKeyPrepared, CGGI, LookupTable, mod_switch_2n, }; -#[allow(clippy::too_many_arguments)] -pub fn cggi_blind_rotate_scratch_space( - module: &Module, - block_size: usize, - extension_factor: usize, - glwe_infos: &OUT, - brk_infos: &GGSW, -) -> usize +impl BlindRotationExecute for Module where - OUT: GLWEInfos, - GGSW: GGSWInfos, - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxNormalizeTmpBytes - + VecZnxBigAllocBytes - + VecZnxIdftApplyTmpBytes - + VecZnxBigNormalizeTmpBytes, -{ - let brk_size: usize = brk_infos.size(); - - if block_size > 1 { - let cols: usize = (brk_infos.rank() + 1).into(); - let dnum: usize = brk_infos.dnum().into(); - let acc_dft: usize = module.vec_znx_dft_alloc_bytes(cols, dnum) * extension_factor; - let acc_big: usize = module.vec_znx_big_alloc_bytes(1, brk_size); - let vmp_res: usize = module.vec_znx_dft_alloc_bytes(cols, brk_size) * extension_factor; - let vmp_xai: usize = module.vec_znx_dft_alloc_bytes(1, brk_size); - let acc_dft_add: usize = vmp_res; - let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes(brk_size, dnum, dnum, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2) - let acc: usize = if extension_factor > 1 { - VecZnx::alloc_bytes(module.n(), cols, glwe_infos.size()) * extension_factor - } else { - 0 - }; - - acc + acc_dft - + acc_dft_add - + vmp_res - + vmp_xai - + (vmp | (acc_big + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_apply_tmp_bytes()))) - } else { - GLWECiphertext::alloc_bytes(glwe_infos) - + GLWECiphertext::external_product_inplace_scratch_space(module, glwe_infos, brk_infos) - } -} - -impl BlincRotationExecute for BlindRotationKeyPrepared -where - Module: VecZnxBigAllocBytes - + VecZnxDftAllocBytes - + SvpPPolAllocBytes + Self: VecZnxDftBytesOf + + VecZnxBigBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxIdftApplyTmpBytes - + VecZnxIdftApply - + VecZnxDftAdd - + VecZnxDftAddInplace - + VecZnxDftApply - + VecZnxDftZero - + SvpApplyDftToDft - + VecZnxDftSubInplace - + VecZnxBigAddSmallInplace + + GLWEExternalProduct + + ModuleN + VecZnxRotate - + VecZnxAddInplace - + VecZnxSubInplace - + VecZnxNormalize - + VecZnxNormalizeInplace + + VecZnxDftApply + + VecZnxDftZero + + VmpApplyDftToDft + + SvpApplyDftToDft + + VecZnxDftAddInplace + + VecZnxDftSubInplace + + VecZnxIdftApply + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize + VecZnxCopy - + VecZnxMulXpMinusOneInplace - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxIdftApplyConsume - + VecZnxBigNormalize - + VecZnxNormalizeTmpBytes, - Scratch: TakeVecZnxDftSlice + TakeVecZnxDft + TakeVecZnxBig + TakeVecZnxSlice + TakeVecZnx + ScratchAvailable, + + GLWEMulXpMinusOne + + GLWEAdd + + GLWENormalize, + Scratch: ScratchTakeCore, { - fn execute( + fn blind_rotation_execute_tmp_bytes( &self, - module: &Module, - res: &mut GLWECiphertext, - lwe: &LWECiphertext, - lut: &LookUpTable, - scratch: &mut Scratch, - ) { - match self.dist { + block_size: usize, + extension_factor: usize, + glwe_infos: &G, + brk_infos: &B, + ) -> usize + where + G: GLWEInfos, + B: GGSWInfos, + { + let brk_size: usize = brk_infos.size(); + + if block_size > 1 { + let cols: usize = (brk_infos.rank() + 1).into(); + let dnum: usize = brk_infos.dnum().into(); + let acc_dft: usize = self.bytes_of_vec_znx_dft(cols, dnum) * extension_factor; + let acc_big: usize = self.bytes_of_vec_znx_big(1, brk_size); + let vmp_res: usize = self.bytes_of_vec_znx_dft(cols, brk_size) * extension_factor; + let vmp_xai: usize = self.bytes_of_vec_znx_dft(1, brk_size); + let acc_dft_add: usize = vmp_res; + let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes(brk_size, dnum, dnum, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2) + let acc: usize = if extension_factor > 1 { + VecZnx::bytes_of(self.n(), cols, glwe_infos.size()) * extension_factor + } else { + 0 + }; + + acc + acc_dft + + acc_dft_add + + vmp_res + + vmp_xai + + (vmp + | (acc_big + + (self + .vec_znx_big_normalize_tmp_bytes() + .max(self.vec_znx_idft_apply_tmp_bytes())))) + } else { + GLWE::bytes_of_from_infos(glwe_infos) + GLWE::external_product_tmp_bytes(self, glwe_infos, glwe_infos, brk_infos) + } + } + + fn blind_rotation_execute( + &self, + res: &mut GLWE, + lwe: &LWE
, + lut: &LookupTable, + brk: &BlindRotationKeyPrepared, + scratch: &mut Scratch, + ) where + DR: DataMut, + DL: DataRef, + DB: DataRef, + { + match brk.dist { Distribution::BinaryBlock(_) | Distribution::BinaryFixed(_) | Distribution::BinaryProb(_) | Distribution::ZERO => { if lut.extension_factor() > 1 { - execute_block_binary_extended(module, res, lwe, lut, self, scratch) - } else if self.block_size() > 1 { - execute_block_binary(module, res, lwe, lut, self, scratch); + execute_block_binary_extended(self, res, lwe, lut, brk, scratch) + } else if brk.block_size() > 1 { + execute_block_binary(self, res, lwe, lut, brk, scratch); } else { - execute_standard(module, res, lwe, lut, self, scratch); + execute_standard(self, res, lwe, lut, brk, scratch); } } - _ => panic!("invalid CGGI distribution"), + _ => panic!("invalid CGGI distribution (have you prepared the key?)"), } } } -fn execute_block_binary_extended( - module: &Module, - res: &mut GLWECiphertext, - lwe: &LWECiphertext, - lut: &LookUpTable, - brk: &BlindRotationKeyPrepared, - scratch: &mut Scratch, +fn execute_block_binary_extended( + module: &M, + res: &mut GLWE, + lwe: &LWE, + lut: &LookupTable, + brk: &BlindRotationKeyPrepared, + scratch: &mut Scratch, ) where DataRes: DataMut, DataIn: DataRef, DataBrk: DataRef, - Module: VecZnxBigAllocBytes - + VecZnxDftAllocBytes - + SvpPPolAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VecZnxIdftApplyTmpBytes - + VecZnxIdftApply - + VecZnxDftAdd - + VecZnxDftAddInplace - + VecZnxDftApply - + VecZnxDftZero - + SvpApplyDftToDft - + VecZnxDftSubInplace - + VecZnxBigAddSmallInplace + M: VecZnxDftBytesOf + + ModuleN + VecZnxRotate - + VecZnxAddInplace - + VecZnxSubInplace - + VecZnxNormalize - + VecZnxNormalizeInplace + + VecZnxDftApply + + VecZnxDftZero + + VmpApplyDftToDft + + SvpApplyDftToDft + + VecZnxDftAddInplace + + VecZnxDftSubInplace + + VecZnxIdftApply + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize + VecZnxCopy - + VecZnxMulXpMinusOneInplace - + VecZnxBigNormalize - + VmpApplyDftToDft, - Scratch: TakeVecZnxDftSlice + TakeVecZnxDft + TakeVecZnxBig + TakeVecZnxSlice + ScratchAvailable + TakeVecZnx, + + VecZnxBigBytesOf, + Scratch: ScratchTakeCore, { let n_glwe: usize = brk.n_glwe().into(); let extension_factor: usize = lut.extension_factor(); @@ -162,16 +145,16 @@ fn execute_block_binary_extended( let cols: usize = (res.rank() + 1).into(); let (mut acc, scratch_1) = scratch.take_vec_znx_slice(extension_factor, n_glwe, cols, res.size()); - let (mut acc_dft, scratch_2) = scratch_1.take_vec_znx_dft_slice(extension_factor, n_glwe, cols, dnum); - let (mut vmp_res, scratch_3) = scratch_2.take_vec_znx_dft_slice(extension_factor, n_glwe, cols, brk.size()); - let (mut acc_add_dft, scratch_4) = scratch_3.take_vec_znx_dft_slice(extension_factor, n_glwe, cols, brk.size()); - let (mut vmp_xai, scratch_5) = scratch_4.take_vec_znx_dft(n_glwe, 1, brk.size()); + let (mut acc_dft, scratch_2) = scratch_1.take_vec_znx_dft_slice(module, extension_factor, cols, dnum); + let (mut vmp_res, scratch_3) = scratch_2.take_vec_znx_dft_slice(module, extension_factor, cols, brk.size()); + let (mut acc_add_dft, scratch_4) = scratch_3.take_vec_znx_dft_slice(module, extension_factor, cols, brk.size()); + let (mut vmp_xai, scratch_5) = scratch_4.take_vec_znx_dft(module, 1, brk.size()); (0..extension_factor).for_each(|i| { acc[i].zero(); }); - let x_pow_a: &Vec, B>>; + let x_pow_a: &Vec, BE>>; if let Some(b) = &brk.x_pow_a { x_pow_a = b } else { @@ -179,7 +162,7 @@ fn execute_block_binary_extended( } let mut lwe_2n: Vec = vec![0i64; (lwe.n() + 1).as_usize()]; // TODO: from scratch space - let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref(); + let lwe_ref: LWE<&[u8]> = lwe.to_ref(); let two_n: usize = 2 * n_glwe; let two_n_ext: usize = 2 * lut.domain_size(); @@ -269,7 +252,7 @@ fn execute_block_binary_extended( }); { - let (mut acc_add_big, scratch7) = scratch_5.take_vec_znx_big(n_glwe, 1, brk.size()); + let (mut acc_add_big, scratch7) = scratch_5.take_vec_znx_big(module, 1, brk.size()); (0..extension_factor).for_each(|j| { (0..cols).for_each(|i| { @@ -286,46 +269,37 @@ fn execute_block_binary_extended( }); } -fn execute_block_binary( - module: &Module, - res: &mut GLWECiphertext, - lwe: &LWECiphertext, - lut: &LookUpTable, - brk: &BlindRotationKeyPrepared, - scratch: &mut Scratch, +fn execute_block_binary( + module: &M, + res: &mut GLWE, + lwe: &LWE, + lut: &LookupTable, + brk: &BlindRotationKeyPrepared, + scratch: &mut Scratch, ) where DataRes: DataMut, DataIn: DataRef, DataBrk: DataRef, - Module: VecZnxBigAllocBytes - + VecZnxDftAllocBytes - + SvpPPolAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VecZnxIdftApplyTmpBytes - + VecZnxIdftApply - + VecZnxDftAdd - + VecZnxDftAddInplace - + VecZnxDftApply - + VecZnxDftZero - + SvpApplyDftToDft - + VecZnxDftSubInplace - + VecZnxBigAddSmallInplace + M: VecZnxDftBytesOf + + ModuleN + VecZnxRotate - + VecZnxAddInplace - + VecZnxSubInplace - + VecZnxNormalize - + VecZnxNormalizeInplace + + VecZnxDftApply + + VecZnxDftZero + + VmpApplyDftToDft + + SvpApplyDftToDft + + VecZnxDftAddInplace + + VecZnxDftSubInplace + + VecZnxIdftApply + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize + VecZnxCopy - + VecZnxMulXpMinusOneInplace - + VmpApplyDftToDft - + VecZnxBigNormalize, - Scratch: TakeVecZnxDftSlice + TakeVecZnxDft + TakeVecZnxBig + TakeVecZnxSlice + ScratchAvailable + TakeVecZnx, + + VecZnxBigBytesOf, + Scratch: ScratchTakeCore, { let n_glwe: usize = brk.n_glwe().into(); let mut lwe_2n: Vec = vec![0i64; (lwe.n() + 1).into()]; // TODO: from scratch space - let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut(); - let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref(); + let mut out_mut: GLWE<&mut [u8]> = res.to_mut(); + let lwe_ref: LWE<&[u8]> = lwe.to_ref(); let two_n: usize = n_glwe << 1; let base2k: usize = brk.base2k().into(); let dnum: usize = brk.dnum().into(); @@ -351,12 +325,12 @@ fn execute_block_binary( // ACC + [sum DFT(X^ai -1) * (DFT(ACC) x BRKi)] - let (mut acc_dft, scratch_1) = scratch.take_vec_znx_dft(n_glwe, cols, dnum); - let (mut vmp_res, scratch_2) = scratch_1.take_vec_znx_dft(n_glwe, cols, brk.size()); - let (mut acc_add_dft, scratch_3) = scratch_2.take_vec_znx_dft(n_glwe, cols, brk.size()); - let (mut vmp_xai, scratch_4) = scratch_3.take_vec_znx_dft(n_glwe, 1, brk.size()); + let (mut acc_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols, dnum); + let (mut vmp_res, scratch_2) = scratch_1.take_vec_znx_dft(module, cols, brk.size()); + let (mut acc_add_dft, scratch_3) = scratch_2.take_vec_znx_dft(module, cols, brk.size()); + let (mut vmp_xai, scratch_4) = scratch_3.take_vec_znx_dft(module, 1, brk.size()); - let x_pow_a: &Vec, B>>; + let x_pow_a: &Vec, BE>>; if let Some(b) = &brk.x_pow_a { x_pow_a = b } else { @@ -389,7 +363,7 @@ fn execute_block_binary( }); { - let (mut acc_add_big, scratch_5) = scratch_4.take_vec_znx_big(n_glwe, 1, brk.size()); + let (mut acc_add_big, scratch_5) = scratch_4.take_vec_znx_big(module, 1, brk.size()); (0..cols).for_each(|i| { module.vec_znx_idft_apply(&mut acc_add_big, 0, &acc_add_dft, i, scratch_5); @@ -408,44 +382,19 @@ fn execute_block_binary( }); } -fn execute_standard( - module: &Module, - res: &mut GLWECiphertext, - lwe: &LWECiphertext, - lut: &LookUpTable, - brk: &BlindRotationKeyPrepared, - scratch: &mut Scratch, +fn execute_standard( + module: &M, + res: &mut GLWE, + lwe: &LWE, + lut: &LookupTable, + brk: &BlindRotationKeyPrepared, + scratch: &mut Scratch, ) where DataRes: DataMut, DataIn: DataRef, DataBrk: DataRef, - Module: VecZnxBigAllocBytes - + VecZnxDftAllocBytes - + SvpPPolAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VecZnxIdftApplyTmpBytes - + VecZnxIdftApply - + VecZnxDftAdd - + VecZnxDftAddInplace - + VecZnxDftApply - + VecZnxDftZero - + SvpApplyDftToDft - + VecZnxDftSubInplace - + VecZnxBigAddSmallInplace - + VecZnxRotate - + VecZnxAddInplace - + VecZnxSubInplace - + VecZnxNormalize - + VecZnxNormalizeInplace - + VecZnxCopy - + VecZnxMulXpMinusOneInplace - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxIdftApplyConsume - + VecZnxBigNormalize - + VecZnxNormalizeTmpBytes, - Scratch: TakeVecZnxDftSlice + TakeVecZnxDft + TakeVecZnxBig + TakeVecZnxSlice + ScratchAvailable + TakeVecZnx, + M: VecZnxRotate + GLWEExternalProduct + GLWEMulXpMinusOne + GLWEAdd + GLWENormalize, + Scratch: ScratchTakeCore, { #[cfg(debug_assertions)] { @@ -480,8 +429,8 @@ fn execute_standard( } let mut lwe_2n: Vec = vec![0i64; (lwe.n() + 1).into()]; // TODO: from scratch space - let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut(); - let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref(); + let mut out_mut: GLWE<&mut [u8]> = res.to_mut(); + let lwe_ref: LWE<&[u8]> = lwe.to_ref(); mod_switch_2n( 2 * lut.domain_size(), @@ -499,7 +448,7 @@ fn execute_standard( module.vec_znx_rotate(b, out_mut.data_mut(), 0, &lut.data[0], 0); // ACC + [sum DFT(X^ai -1) * (DFT(ACC) x BRKi)] - let (mut acc_tmp, scratch_1) = scratch.take_glwe_ct(&out_mut); + let (mut acc_tmp, scratch_1) = scratch.take_glwe(&out_mut); // TODO: see if faster by skipping normalization in external product and keeping acc in big coeffs // TODO: first iteration can be optimized to be a gglwe product @@ -508,55 +457,13 @@ fn execute_standard( acc_tmp.external_product(module, &out_mut, ski, scratch_1); // acc_tmp = (sk[i] * acc) * (X^{ai} - 1) - acc_tmp.mul_xp_minus_one_inplace(module, *ai, scratch_1); + module.glwe_mul_xp_minus_one_inplace(*ai, &mut acc_tmp, scratch_1); // acc = acc + (sk[i] * acc) * (X^{ai} - 1) - out_mut.add_inplace(module, &acc_tmp); + module.glwe_add_inplace(&mut out_mut, &acc_tmp); }); // We can normalize only at the end because we add normalized values in [-2^{base2k-1}, 2^{base2k-1}] // on top of each others, thus ~ 2^{63-base2k} additions are supported before overflow. - out_mut.normalize_inplace(module, scratch_1); -} - -pub fn mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]>, rot_dir: LookUpTableRotationDirection) { - let base2k: usize = lwe.base2k().into(); - - let log2n: usize = usize::BITS as usize - (n - 1).leading_zeros() as usize + 1; - - res.copy_from_slice(lwe.data().at(0, 0)); - - match rot_dir { - LookUpTableRotationDirection::Left => { - res.iter_mut().for_each(|x| *x = -*x); - } - LookUpTableRotationDirection::Right => {} - } - - if base2k > log2n { - let diff: usize = base2k - (log2n - 1); // additional -1 because we map to [-N/2, N/2) instead of [0, N) - res.iter_mut().for_each(|x| { - *x = div_round_by_pow2(x, diff); - }) - } else { - let rem: usize = base2k - (log2n % base2k); - let size: usize = log2n.div_ceil(base2k); - (1..size).for_each(|i| { - if i == size - 1 && rem != base2k { - let k_rem: usize = base2k - rem; - izip!(lwe.data().at(0, i).iter(), res.iter_mut()).for_each(|(x, y)| { - *y = (*y << k_rem) + (x >> rem); - }); - } else { - izip!(lwe.data().at(0, i).iter(), res.iter_mut()).for_each(|(x, y)| { - *y = (*y << base2k) + x; - }); - } - }) - } -} - -#[inline(always)] -fn div_round_by_pow2(x: &i64, k: usize) -> i64 { - (x + (1 << (k - 1))) >> k + module.glwe_normalize_inplace(&mut out_mut, scratch_1); } diff --git a/poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/key.rs b/poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/key.rs new file mode 100644 index 0000000..ee3f454 --- /dev/null +++ b/poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/key.rs @@ -0,0 +1,79 @@ +use poulpy_hal::{ + layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxView, ZnxViewMut}, + source::Source, +}; + +use std::marker::PhantomData; + +use poulpy_core::{ + Distribution, GGSWEncryptSk, GetDistribution, ScratchTakeCore, + layouts::{GGSW, GGSWInfos, GLWEInfos, GLWESecretPreparedToRef, LWEInfos, LWESecret, LWESecretToRef}, +}; + +use crate::tfhe::blind_rotation::{ + BlindRotationKey, BlindRotationKeyEncryptSk, BlindRotationKeyFactory, BlindRotationKeyInfos, CGGI, +}; + +impl BlindRotationKeyFactory for BlindRotationKey { + fn blind_rotation_key_alloc(infos: &A) -> BlindRotationKey, CGGI> + where + A: BlindRotationKeyInfos, + { + BlindRotationKey { + keys: (0..infos.n_lwe().as_usize()) + .map(|_| GGSW::alloc_from_infos(infos)) + .collect(), + dist: Distribution::NONE, + _phantom: PhantomData, + } + } +} + +impl BlindRotationKeyEncryptSk for Module +where + Self: GGSWEncryptSk, + Scratch: ScratchTakeCore, +{ + fn blind_rotation_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize { + self.ggsw_encrypt_sk_tmp_bytes(infos) + } + + fn blind_rotation_key_encrypt_sk( + &self, + res: &mut BlindRotationKey, + sk_glwe: &S0, + sk_lwe: &S1, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + D: DataMut, + S0: GLWESecretPreparedToRef + GLWEInfos, + S1: LWESecretToRef + LWEInfos + GetDistribution, + { + assert_eq!(res.keys.len() as u32, sk_lwe.n()); + assert!(sk_glwe.n() <= self.n() as u32); + assert_eq!(sk_glwe.rank(), res.rank()); + + match sk_lwe.dist() { + Distribution::BinaryBlock(_) | Distribution::BinaryFixed(_) | Distribution::BinaryProb(_) | Distribution::ZERO => {} + _ => { + panic!("invalid GLWESecret distribution: must be BinaryBlock, BinaryFixed or BinaryProb (or ZERO for debugging)") + } + } + + { + let sk_lwe: &LWESecret<&[u8]> = &sk_lwe.to_ref(); + + res.dist = sk_lwe.dist(); + + let mut pt: ScalarZnx> = ScalarZnx::alloc(sk_glwe.n().into(), 1); + let sk_ref: ScalarZnx<&[u8]> = sk_lwe.data().to_ref(); + + for (i, ggsw) in res.keys.iter_mut().enumerate() { + pt.at_mut(0, 0)[0] = sk_ref.at(0, 0)[i]; + ggsw.encrypt_sk(self, &pt, sk_glwe, source_xa, source_xe, scratch); + } + } + } +} diff --git a/poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/key_compressed.rs b/poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/key_compressed.rs new file mode 100644 index 0000000..255f8f2 --- /dev/null +++ b/poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/key_compressed.rs @@ -0,0 +1,84 @@ +use std::marker::PhantomData; + +use poulpy_core::{ + Distribution, GGSWCompressedEncryptSk, GetDistribution, ScratchTakeCore, + layouts::{GGSWCompressed, GGSWInfos, GLWEInfos, GLWESecretPreparedToRef, LWEInfos, LWESecret, LWESecretToRef}, +}; +use poulpy_hal::{ + layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxView, ZnxViewMut}, + source::Source, +}; + +use crate::tfhe::blind_rotation::{ + BlindRotationKeyCompressed, BlindRotationKeyCompressedEncryptSk, BlindRotationKeyCompressedFactory, BlindRotationKeyInfos, + CGGI, +}; + +impl BlindRotationKeyCompressedFactory for BlindRotationKeyCompressed { + fn blind_rotation_key_compressed_alloc(infos: &A) -> BlindRotationKeyCompressed, CGGI> + where + A: BlindRotationKeyInfos, + { + let mut data: Vec>> = Vec::with_capacity(infos.n_lwe().into()); + (0..infos.n_lwe().as_usize()).for_each(|_| data.push(GGSWCompressed::alloc_from_infos(infos))); + BlindRotationKeyCompressed { + keys: data, + dist: Distribution::NONE, + _phantom: PhantomData, + } + } +} + +impl BlindRotationKeyCompressedEncryptSk for Module +where + Self: GGSWCompressedEncryptSk, + Scratch: ScratchTakeCore, +{ + fn blind_rotation_key_compressed_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGSWInfos, + { + self.ggsw_compressed_encrypt_sk_tmp_bytes(infos) + } + + fn blind_rotation_key_compressed_encrypt_sk( + &self, + res: &mut BlindRotationKeyCompressed, + sk_glwe: &S0, + sk_lwe: &S1, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + D: DataMut, + S0: GLWESecretPreparedToRef + GLWEInfos, + S1: LWESecretToRef + LWEInfos + GetDistribution, + { + assert_eq!(res.keys.len() as u32, sk_lwe.n()); + assert!(sk_glwe.n() <= self.n() as u32); + assert_eq!(sk_glwe.rank(), res.rank()); + + match sk_lwe.dist() { + Distribution::BinaryBlock(_) | Distribution::BinaryFixed(_) | Distribution::BinaryProb(_) | Distribution::ZERO => {} + _ => { + panic!("invalid GLWESecret distribution: must be BinaryBlock, BinaryFixed or BinaryProb (or ZERO for debugging)") + } + } + + { + let sk_lwe: &LWESecret<&[u8]> = &sk_lwe.to_ref(); + + let mut source_xa: Source = Source::new(seed_xa); + + res.dist = sk_lwe.dist(); + + let mut pt: ScalarZnx> = ScalarZnx::alloc(sk_glwe.n().into(), 1); + let sk_ref: ScalarZnx<&[u8]> = sk_lwe.data().to_ref(); + + for (i, ggsw) in res.keys.iter_mut().enumerate() { + pt.at_mut(0, 0)[0] = sk_ref.at(0, 0)[i]; + ggsw.encrypt_sk(self, &pt, sk_glwe, source_xa.new_seed(), source_xe, scratch); + } + } + } +} diff --git a/poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/key_prepared.rs b/poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/key_prepared.rs new file mode 100644 index 0000000..bc1ab98 --- /dev/null +++ b/poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/key_prepared.rs @@ -0,0 +1,76 @@ +use poulpy_hal::{ + api::{SvpPPolAlloc, SvpPrepare}, + layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, SvpPPol}, +}; + +use std::marker::PhantomData; + +use poulpy_core::{ + Distribution, + layouts::{GGSWPreparedFactory, LWEInfos, prepared::GGSWPrepared}, +}; + +use crate::tfhe::blind_rotation::{ + BlindRotationKey, BlindRotationKeyInfos, BlindRotationKeyPrepared, BlindRotationKeyPreparedFactory, CGGI, + utils::set_xai_plus_y, +}; + +impl BlindRotationKeyPreparedFactory for Module +where + Self: GGSWPreparedFactory + SvpPPolAlloc + SvpPrepare, +{ + fn blind_rotation_key_prepared_alloc(&self, infos: &A) -> BlindRotationKeyPrepared, CGGI, BE> + where + A: BlindRotationKeyInfos, + { + BlindRotationKeyPrepared { + data: (0..infos.n_lwe().as_usize()) + .map(|_| GGSWPrepared::alloc_from_infos(self, infos)) + .collect(), + dist: Distribution::NONE, + x_pow_a: None, + _phantom: PhantomData, + } + } + + fn blind_rotation_key_prepare_tmp_bytes(&self, infos: &A) -> usize + where + A: BlindRotationKeyInfos, + { + self.ggsw_prepare_tmp_bytes(infos) + } + + fn prepare_blind_rotation_key( + &self, + res: &mut BlindRotationKeyPrepared, + other: &BlindRotationKey, + scratch: &mut Scratch, + ) where + DM: DataMut, + DR: DataRef, + { + #[cfg(debug_assertions)] + { + assert_eq!(res.data.len(), other.keys.len()); + } + + let n: usize = other.n().as_usize(); + + for (a, b) in res.data.iter_mut().zip(other.keys.iter()) { + a.prepare(self, b, scratch); + } + + res.dist = other.dist; + + if let Distribution::BinaryBlock(_) = other.dist { + let mut x_pow_a: Vec, BE>> = Vec::with_capacity(n << 1); + let mut buf: ScalarZnx> = ScalarZnx::alloc(n, 1); + (0..n << 1).for_each(|i| { + let mut res: SvpPPol, BE> = self.svp_ppol_alloc(1); + set_xai_plus_y(self, i, 0, &mut res, &mut buf); + x_pow_a.push(res); + }); + res.x_pow_a = Some(x_pow_a); + } + } +} diff --git a/poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/mod.rs b/poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/mod.rs new file mode 100644 index 0000000..d67ee45 --- /dev/null +++ b/poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/mod.rs @@ -0,0 +1,10 @@ +mod algorithm; +mod key; +mod key_compressed; +mod key_prepared; + +use crate::tfhe::blind_rotation::BlindRotationAlgo; + +#[derive(Clone)] +pub struct CGGI {} +impl BlindRotationAlgo for CGGI {} diff --git a/poulpy-schemes/src/tfhe/blind_rotation/algorithms/mod.rs b/poulpy-schemes/src/tfhe/blind_rotation/algorithms/mod.rs new file mode 100644 index 0000000..f25fb51 --- /dev/null +++ b/poulpy-schemes/src/tfhe/blind_rotation/algorithms/mod.rs @@ -0,0 +1,116 @@ +mod cggi; + +pub use cggi::*; + +use itertools::izip; +use poulpy_core::{ + ScratchTakeCore, + layouts::{GGSWInfos, GLWE, GLWEInfos, LWE, LWEInfos}, +}; +use poulpy_hal::layouts::{Backend, DataMut, DataRef, Scratch, ZnxView}; + +use crate::tfhe::blind_rotation::{BlindRotationKeyInfos, BlindRotationKeyPrepared, LookUpTableRotationDirection, LookupTable}; + +pub trait BlindRotationAlgo {} + +pub trait BlindRotationExecute { + fn blind_rotation_execute_tmp_bytes( + &self, + block_size: usize, + extension_factor: usize, + glwe_infos: &G, + brk_infos: &B, + ) -> usize + where + G: GLWEInfos, + B: GGSWInfos; + + fn blind_rotation_execute( + &self, + res: &mut GLWE, + lwe: &LWE
, + lut: &LookupTable, + brk: &BlindRotationKeyPrepared, + scratch: &mut Scratch, + ) where + DR: DataMut, + DL: DataRef, + DB: DataRef; +} + +impl BlindRotationKeyPrepared +where + Scratch: ScratchTakeCore, +{ + pub fn execute( + &self, + module: &M, + res: &mut GLWE, + lwe: &LWE, + lut: &LookupTable, + scratch: &mut Scratch, + ) where + M: BlindRotationExecute, + { + module.blind_rotation_execute(res, lwe, lut, self, scratch); + } +} + +impl BlindRotationKeyPrepared, BRA, BE> { + pub fn execute_tmp_bytes( + module: &M, + block_size: usize, + extension_factor: usize, + glwe_infos: &A, + brk_infos: &B, + ) -> usize + where + A: GLWEInfos, + B: BlindRotationKeyInfos, + M: BlindRotationExecute, + { + module.blind_rotation_execute_tmp_bytes(block_size, extension_factor, glwe_infos, brk_infos) + } +} + +pub fn mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWE<&[u8]>, rot_dir: LookUpTableRotationDirection) { + let base2k: usize = lwe.base2k().into(); + + let log2n: usize = usize::BITS as usize - (n - 1).leading_zeros() as usize + 1; + + res.copy_from_slice(lwe.data().at(0, 0)); + + match rot_dir { + LookUpTableRotationDirection::Left => { + res.iter_mut().for_each(|x| *x = -*x); + } + LookUpTableRotationDirection::Right => {} + } + + if base2k > log2n { + let diff: usize = base2k - (log2n - 1); // additional -1 because we map to [-N/2, N/2) instead of [0, N) + res.iter_mut().for_each(|x| { + *x = div_round_by_pow2(x, diff); + }) + } else { + let rem: usize = base2k - (log2n % base2k); + let size: usize = log2n.div_ceil(base2k); + (1..size).for_each(|i| { + if i == size - 1 && rem != base2k { + let k_rem: usize = base2k - rem; + izip!(lwe.data().at(0, i).iter(), res.iter_mut()).for_each(|(x, y)| { + *y = (*y << k_rem) + (x >> rem); + }); + } else { + izip!(lwe.data().at(0, i).iter(), res.iter_mut()).for_each(|(x, y)| { + *y = (*y << base2k) + x; + }); + } + }) + } +} + +#[inline(always)] +fn div_round_by_pow2(x: &i64, k: usize) -> i64 { + (x + (1 << (k - 1))) >> k +} diff --git a/poulpy-schemes/src/tfhe/blind_rotation/cggi_key.rs b/poulpy-schemes/src/tfhe/blind_rotation/cggi_key.rs deleted file mode 100644 index fbf506b..0000000 --- a/poulpy-schemes/src/tfhe/blind_rotation/cggi_key.rs +++ /dev/null @@ -1,224 +0,0 @@ -use poulpy_hal::{ - api::{ - ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, - VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, - VmpPMatAlloc, VmpPrepare, - }, - layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxView, ZnxViewMut}, - source::Source, -}; - -use std::marker::PhantomData; - -use poulpy_core::{ - Distribution, - layouts::{ - GGSWCiphertext, GGSWInfos, LWESecret, - compressed::GGSWCiphertextCompressed, - prepared::{GGSWCiphertextPrepared, GLWESecretPrepared}, - }, -}; - -use crate::tfhe::blind_rotation::{ - BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyCompressed, BlindRotationKeyEncryptSk, BlindRotationKeyInfos, - BlindRotationKeyPrepared, BlindRotationKeyPreparedAlloc, CGGI, -}; - -impl BlindRotationKeyAlloc for BlindRotationKey, CGGI> { - fn alloc(infos: &A) -> Self - where - A: BlindRotationKeyInfos, - { - let mut data: Vec>> = Vec::with_capacity(infos.n_lwe().into()); - for _ in 0..infos.n_lwe().as_usize() { - data.push(GGSWCiphertext::alloc(infos)); - } - - Self { - keys: data, - dist: Distribution::NONE, - _phantom: PhantomData, - } - } -} - -impl BlindRotationKey, CGGI> { - pub fn generate_from_sk_scratch_space(module: &Module, infos: &A) -> usize - where - A: GGSWInfos, - Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes, - { - GGSWCiphertext::encrypt_sk_scratch_space(module, infos) - } -} - -impl BlindRotationKeyEncryptSk for BlindRotationKey -where - Module: VecZnxAddScalarInplace - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, -{ - fn encrypt_sk( - &mut self, - module: &Module, - sk_glwe: &GLWESecretPrepared, - sk_lwe: &LWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - scratch: &mut Scratch, - ) where - DataSkGLWE: DataRef, - DataSkLWE: DataRef, - { - #[cfg(debug_assertions)] - { - use poulpy_core::layouts::{GLWEInfos, LWEInfos}; - - assert_eq!(self.keys.len() as u32, sk_lwe.n()); - assert!(sk_glwe.n() <= module.n() as u32); - assert_eq!(sk_glwe.rank(), self.rank()); - match sk_lwe.dist() { - Distribution::BinaryBlock(_) - | Distribution::BinaryFixed(_) - | Distribution::BinaryProb(_) - | Distribution::ZERO => {} - _ => panic!( - "invalid GLWESecret distribution: must be BinaryBlock, BinaryFixed or BinaryProb (or ZERO for debugging)" - ), - } - } - - self.dist = sk_lwe.dist(); - - let mut pt: ScalarZnx> = ScalarZnx::alloc(sk_glwe.n().into(), 1); - let sk_ref: ScalarZnx<&[u8]> = sk_lwe.data().to_ref(); - - self.keys.iter_mut().enumerate().for_each(|(i, ggsw)| { - pt.at_mut(0, 0)[0] = sk_ref.at(0, 0)[i]; - ggsw.encrypt_sk(module, &pt, sk_glwe, source_xa, source_xe, scratch); - }); - } -} - -impl BlindRotationKeyPreparedAlloc for BlindRotationKeyPrepared, CGGI, B> -where - Module: VmpPMatAlloc + VmpPrepare, -{ - fn alloc(module: &Module, infos: &A) -> Self - where - A: BlindRotationKeyInfos, - { - let mut data: Vec, B>> = Vec::with_capacity(infos.n_lwe().into()); - (0..infos.n_lwe().as_usize()).for_each(|_| data.push(GGSWCiphertextPrepared::alloc(module, infos))); - Self { - data, - dist: Distribution::NONE, - x_pow_a: None, - _phantom: PhantomData, - } - } -} - -impl BlindRotationKeyCompressed, CGGI> { - pub fn alloc(infos: &A) -> Self - where - A: BlindRotationKeyInfos, - { - let mut data: Vec>> = Vec::with_capacity(infos.n_lwe().into()); - (0..infos.n_lwe().as_usize()).for_each(|_| data.push(GGSWCiphertextCompressed::alloc(infos))); - Self { - keys: data, - dist: Distribution::NONE, - _phantom: PhantomData, - } - } - - pub fn generate_from_sk_scratch_space(module: &Module, infos: &A) -> usize - where - A: GGSWInfos, - Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes, - { - GGSWCiphertextCompressed::encrypt_sk_scratch_space(module, infos) - } -} - -impl BlindRotationKeyCompressed { - #[allow(clippy::too_many_arguments)] - pub fn encrypt_sk( - &mut self, - module: &Module, - sk_glwe: &GLWESecretPrepared, - sk_lwe: &LWESecret, - seed_xa: [u8; 32], - source_xe: &mut Source, - scratch: &mut Scratch, - ) where - DataSkGLWE: DataRef, - DataSkLWE: DataRef, - Module: VecZnxAddScalarInplace - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, - { - #[cfg(debug_assertions)] - { - use poulpy_core::layouts::{GLWEInfos, LWEInfos}; - - assert_eq!(self.n_lwe(), sk_lwe.n()); - assert!(sk_glwe.n() <= module.n() as u32); - assert_eq!(sk_glwe.rank(), self.rank()); - match sk_lwe.dist() { - Distribution::BinaryBlock(_) - | Distribution::BinaryFixed(_) - | Distribution::BinaryProb(_) - | Distribution::ZERO => {} - _ => panic!( - "invalid GLWESecret distribution: must be BinaryBlock, BinaryFixed or BinaryProb (or ZERO for debugging)" - ), - } - } - - self.dist = sk_lwe.dist(); - - let mut pt: ScalarZnx> = ScalarZnx::alloc(sk_glwe.n().into(), 1); - let sk_ref: ScalarZnx<&[u8]> = sk_lwe.data().to_ref(); - - let mut source_xa: Source = Source::new(seed_xa); - - self.keys.iter_mut().enumerate().for_each(|(i, ggsw)| { - pt.at_mut(0, 0)[0] = sk_ref.at(0, 0)[i]; - ggsw.encrypt_sk( - module, - &pt, - sk_glwe, - source_xa.new_seed(), - source_xe, - scratch, - ); - }); - } -} diff --git a/poulpy-schemes/src/tfhe/blind_rotation/encryption/key.rs b/poulpy-schemes/src/tfhe/blind_rotation/encryption/key.rs new file mode 100644 index 0000000..b250e07 --- /dev/null +++ b/poulpy-schemes/src/tfhe/blind_rotation/encryption/key.rs @@ -0,0 +1,60 @@ +use poulpy_hal::{ + layouts::{Backend, DataMut, Scratch}, + source::Source, +}; + +use poulpy_core::{ + GetDistribution, ScratchTakeCore, + layouts::{GGSWInfos, GLWEInfos, GLWESecretPreparedToRef, LWEInfos, LWESecretToRef}, +}; + +use crate::tfhe::blind_rotation::{BlindRotationAlgo, BlindRotationKey}; + +pub trait BlindRotationKeyEncryptSk { + fn blind_rotation_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGSWInfos; + + #[allow(clippy::too_many_arguments)] + fn blind_rotation_key_encrypt_sk( + &self, + res: &mut BlindRotationKey, + sk_glwe: &S0, + sk_lwe: &S1, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + D: DataMut, + S0: GLWESecretPreparedToRef + GLWEInfos, + S1: LWESecretToRef + LWEInfos + GetDistribution; +} + +impl BlindRotationKey { + pub fn encrypt_sk( + &mut self, + module: &M, + sk_glwe: &S0, + sk_lwe: &S1, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + S0: GLWESecretPreparedToRef + GLWEInfos, + S1: LWESecretToRef + LWEInfos + GetDistribution, + Scratch: ScratchTakeCore, + M: BlindRotationKeyEncryptSk, + { + module.blind_rotation_key_encrypt_sk(self, sk_glwe, sk_lwe, source_xa, source_xe, scratch); + } +} + +impl BlindRotationKey, BRA> { + pub fn encrypt_sk_tmp_bytes(module: &M, infos: &A) -> usize + where + A: GGSWInfos, + M: BlindRotationKeyEncryptSk, + { + module.blind_rotation_key_encrypt_sk_tmp_bytes(infos) + } +} diff --git a/poulpy-schemes/src/tfhe/blind_rotation/encryption/key_compressed.rs b/poulpy-schemes/src/tfhe/blind_rotation/encryption/key_compressed.rs new file mode 100644 index 0000000..4898365 --- /dev/null +++ b/poulpy-schemes/src/tfhe/blind_rotation/encryption/key_compressed.rs @@ -0,0 +1,30 @@ +use poulpy_core::{ + GetDistribution, + layouts::{GGSWInfos, GLWEInfos, GLWESecretPreparedToRef, LWEInfos, LWESecretToRef}, +}; +use poulpy_hal::{ + layouts::{Backend, DataMut, Scratch}, + source::Source, +}; + +use crate::tfhe::blind_rotation::{BlindRotationAlgo, BlindRotationKeyCompressed}; + +pub trait BlindRotationKeyCompressedEncryptSk { + fn blind_rotation_key_compressed_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGSWInfos; + + #[allow(clippy::too_many_arguments)] + fn blind_rotation_key_compressed_encrypt_sk( + &self, + res: &mut BlindRotationKeyCompressed, + sk_glwe: &S0, + sk_lwe: &S1, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + D: DataMut, + S0: GLWESecretPreparedToRef + GLWEInfos, + S1: LWESecretToRef + LWEInfos + GetDistribution; +} diff --git a/poulpy-schemes/src/tfhe/blind_rotation/encryption/mod.rs b/poulpy-schemes/src/tfhe/blind_rotation/encryption/mod.rs new file mode 100644 index 0000000..62623fc --- /dev/null +++ b/poulpy-schemes/src/tfhe/blind_rotation/encryption/mod.rs @@ -0,0 +1,5 @@ +mod key; +mod key_compressed; + +pub use key::*; +pub use key_compressed::*; diff --git a/poulpy-schemes/src/tfhe/blind_rotation/key_prepared.rs b/poulpy-schemes/src/tfhe/blind_rotation/key_prepared.rs deleted file mode 100644 index 8719de4..0000000 --- a/poulpy-schemes/src/tfhe/blind_rotation/key_prepared.rs +++ /dev/null @@ -1,130 +0,0 @@ -use poulpy_hal::{ - api::{SvpPPolAlloc, SvpPrepare, VmpPMatAlloc, VmpPrepare}, - layouts::{Backend, Data, DataMut, DataRef, Module, ScalarZnx, Scratch, SvpPPol}, -}; - -use std::marker::PhantomData; - -use poulpy_core::{ - Distribution, - layouts::{ - Base2K, Degree, Dnum, Dsize, GGSWInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision, - prepared::{GGSWCiphertextPrepared, Prepare, PrepareAlloc}, - }, -}; - -use crate::tfhe::blind_rotation::{BlindRotationAlgo, BlindRotationKey, BlindRotationKeyInfos, utils::set_xai_plus_y}; - -pub trait BlindRotationKeyPreparedAlloc { - fn alloc(module: &Module, infos: &A) -> Self - where - A: BlindRotationKeyInfos; -} - -#[derive(PartialEq, Eq)] -pub struct BlindRotationKeyPrepared { - pub(crate) data: Vec>, - pub(crate) dist: Distribution, - pub(crate) x_pow_a: Option, B>>>, - pub(crate) _phantom: PhantomData, -} - -impl BlindRotationKeyInfos for BlindRotationKeyPrepared { - fn n_glwe(&self) -> Degree { - self.n() - } - - fn n_lwe(&self) -> Degree { - Degree(self.data.len() as u32) - } -} - -impl LWEInfos for BlindRotationKeyPrepared { - fn base2k(&self) -> Base2K { - self.data[0].base2k() - } - - fn k(&self) -> TorusPrecision { - self.data[0].k() - } - - fn n(&self) -> Degree { - self.data[0].n() - } - - fn size(&self) -> usize { - self.data[0].size() - } -} - -impl GLWEInfos for BlindRotationKeyPrepared { - fn rank(&self) -> Rank { - self.data[0].rank() - } -} -impl GGSWInfos for BlindRotationKeyPrepared { - fn dsize(&self) -> poulpy_core::layouts::Dsize { - Dsize(1) - } - - fn dnum(&self) -> Dnum { - self.data[0].dnum() - } -} - -impl BlindRotationKeyPrepared { - pub fn block_size(&self) -> usize { - match self.dist { - Distribution::BinaryBlock(value) => value, - _ => 1, - } - } -} - -impl PrepareAlloc, BRA, B>> - for BlindRotationKey -where - BlindRotationKeyPrepared, BRA, B>: BlindRotationKeyPreparedAlloc, - BlindRotationKeyPrepared, BRA, B>: Prepare>, -{ - fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> BlindRotationKeyPrepared, BRA, B> { - let mut brk: BlindRotationKeyPrepared, BRA, B> = BlindRotationKeyPrepared::alloc(module, self); - brk.prepare(module, self, scratch); - brk - } -} - -impl Prepare> - for BlindRotationKeyPrepared -where - Module: VmpPMatAlloc + VmpPrepare + SvpPPolAlloc + SvpPrepare, -{ - fn prepare(&mut self, module: &Module, other: &BlindRotationKey, scratch: &mut Scratch) { - #[cfg(debug_assertions)] - { - assert_eq!(self.data.len(), other.keys.len()); - } - - let n: usize = other.n().as_usize(); - - self.data - .iter_mut() - .zip(other.keys.iter()) - .for_each(|(ggsw_prepared, other)| { - ggsw_prepared.prepare(module, other, scratch); - }); - - self.dist = other.dist; - - if let Distribution::BinaryBlock(_) = other.dist { - let mut x_pow_a: Vec, B>> = Vec::with_capacity(n << 1); - let mut buf: ScalarZnx> = ScalarZnx::alloc(n, 1); - (0..n << 1).for_each(|i| { - let mut res: SvpPPol, B> = module.svp_ppol_alloc(1); - set_xai_plus_y(module, i, 0, &mut res, &mut buf); - x_pow_a.push(res); - }); - self.x_pow_a = Some(x_pow_a); - } - } -} diff --git a/poulpy-schemes/src/tfhe/blind_rotation/key.rs b/poulpy-schemes/src/tfhe/blind_rotation/layouts/key.rs similarity index 87% rename from poulpy-schemes/src/tfhe/blind_rotation/key.rs rename to poulpy-schemes/src/tfhe/blind_rotation/layouts/key.rs index 86dbd21..182c973 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/key.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/layouts/key.rs @@ -1,5 +1,5 @@ use poulpy_hal::{ - layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, Scratch, WriterTo}, + layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo}, source::Source, }; @@ -7,10 +7,7 @@ use std::{fmt, marker::PhantomData}; use poulpy_core::{ Distribution, - layouts::{ - Base2K, Degree, Dnum, Dsize, GGSWCiphertext, GGSWInfos, GLWEInfos, LWEInfos, LWESecret, Rank, TorusPrecision, - prepared::GLWESecretPrepared, - }, + layouts::{Base2K, Degree, Dnum, Dsize, GGSW, GGSWInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision}, }; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; @@ -81,28 +78,31 @@ pub trait BlindRotationKeyAlloc { A: BlindRotationKeyInfos; } -pub trait BlindRotationKeyEncryptSk { - #[allow(clippy::too_many_arguments)] - fn encrypt_sk( - &mut self, - module: &Module, - sk_glwe: &GLWESecretPrepared, - sk_lwe: &LWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - scratch: &mut Scratch, - ) where - DataSkGLWE: DataRef, - DataSkLWE: DataRef; -} - #[derive(Clone)] pub struct BlindRotationKey { - pub(crate) keys: Vec>, + pub(crate) keys: Vec>, pub(crate) dist: Distribution, pub(crate) _phantom: PhantomData, } +pub trait BlindRotationKeyFactory { + fn blind_rotation_key_alloc(infos: &A) -> BlindRotationKey, BRA> + where + A: BlindRotationKeyInfos; +} + +impl BlindRotationKey, BRA> +where + Self: BlindRotationKeyFactory, +{ + pub fn alloc(infos: &A) -> BlindRotationKey, BRA> + where + A: BlindRotationKeyInfos, + { + Self::blind_rotation_key_alloc(infos) + } +} + impl fmt::Debug for BlindRotationKey { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{self}") diff --git a/poulpy-schemes/src/tfhe/blind_rotation/key_compressed.rs b/poulpy-schemes/src/tfhe/blind_rotation/layouts/key_compressed.rs similarity index 87% rename from poulpy-schemes/src/tfhe/blind_rotation/key_compressed.rs rename to poulpy-schemes/src/tfhe/blind_rotation/layouts/key_compressed.rs index 51ff139..26539e1 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/key_compressed.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/layouts/key_compressed.rs @@ -8,18 +8,36 @@ use std::{fmt, marker::PhantomData}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use poulpy_core::{ Distribution, - layouts::{Base2K, Degree, Dsize, GGSWInfos, GLWEInfos, LWEInfos, TorusPrecision, compressed::GGSWCiphertextCompressed}, + layouts::{Base2K, Degree, Dsize, GGSWInfos, GLWEInfos, LWEInfos, TorusPrecision, compressed::GGSWCompressed}, }; use crate::tfhe::blind_rotation::{BlindRotationAlgo, BlindRotationKeyInfos}; #[derive(Clone)] pub struct BlindRotationKeyCompressed { - pub(crate) keys: Vec>, + pub(crate) keys: Vec>, pub(crate) dist: Distribution, pub(crate) _phantom: PhantomData, } +pub trait BlindRotationKeyCompressedFactory { + fn blind_rotation_key_compressed_alloc(infos: &A) -> BlindRotationKeyCompressed, BRA> + where + A: BlindRotationKeyInfos; +} + +impl BlindRotationKeyCompressed, BRA> +where + Self: BlindRotationKeyCompressedFactory, +{ + pub fn alloc(infos: &A) -> BlindRotationKeyCompressed, BRA> + where + A: BlindRotationKeyInfos, + { + Self::blind_rotation_key_compressed_alloc(infos) + } +} + impl fmt::Debug for BlindRotationKeyCompressed { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{self}") diff --git a/poulpy-schemes/src/tfhe/blind_rotation/layouts/key_prepared.rs b/poulpy-schemes/src/tfhe/blind_rotation/layouts/key_prepared.rs new file mode 100644 index 0000000..cf533c1 --- /dev/null +++ b/poulpy-schemes/src/tfhe/blind_rotation/layouts/key_prepared.rs @@ -0,0 +1,116 @@ +use poulpy_hal::layouts::{Backend, Data, DataMut, DataRef, Scratch, SvpPPol}; + +use std::marker::PhantomData; + +use poulpy_core::{ + Distribution, + layouts::{Base2K, Degree, Dnum, Dsize, GGSWInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision, prepared::GGSWPrepared}, +}; + +use crate::tfhe::blind_rotation::{BlindRotationAlgo, BlindRotationKey, BlindRotationKeyInfos}; + +pub trait BlindRotationKeyPreparedFactory { + fn blind_rotation_key_prepared_alloc(&self, infos: &A) -> BlindRotationKeyPrepared, BRA, BE> + where + A: BlindRotationKeyInfos; + + fn blind_rotation_key_prepare_tmp_bytes(&self, infos: &A) -> usize + where + A: BlindRotationKeyInfos; + + fn prepare_blind_rotation_key( + &self, + res: &mut BlindRotationKeyPrepared, + other: &BlindRotationKey, + scratch: &mut Scratch, + ) where + DM: DataMut, + DR: DataRef; +} + +impl BlindRotationKeyPrepared, BRA, BE> { + pub fn alloc(module: &M, infos: &A) -> Self + where + A: BlindRotationKeyInfos, + M: BlindRotationKeyPreparedFactory, + { + module.blind_rotation_key_prepared_alloc(infos) + } + + pub fn prepare_tmp_bytes(module: &M, infos: &A) -> usize + where + A: BlindRotationKeyInfos, + M: BlindRotationKeyPreparedFactory, + { + module.blind_rotation_key_prepare_tmp_bytes(infos) + } +} + +impl BlindRotationKeyPrepared { + pub fn prepare(&mut self, module: &M, other: &BlindRotationKey, scratch: &mut Scratch) + where + M: BlindRotationKeyPreparedFactory, + { + module.prepare_blind_rotation_key(self, other, scratch); + } +} + +#[derive(PartialEq, Eq)] +pub struct BlindRotationKeyPrepared { + pub(crate) data: Vec>, + pub(crate) dist: Distribution, + pub(crate) x_pow_a: Option, B>>>, + pub(crate) _phantom: PhantomData, +} + +impl BlindRotationKeyInfos for BlindRotationKeyPrepared { + fn n_glwe(&self) -> Degree { + self.n() + } + + fn n_lwe(&self) -> Degree { + Degree(self.data.len() as u32) + } +} + +impl LWEInfos for BlindRotationKeyPrepared { + fn base2k(&self) -> Base2K { + self.data[0].base2k() + } + + fn k(&self) -> TorusPrecision { + self.data[0].k() + } + + fn n(&self) -> Degree { + self.data[0].n() + } + + fn size(&self) -> usize { + self.data[0].size() + } +} + +impl GLWEInfos for BlindRotationKeyPrepared { + fn rank(&self) -> Rank { + self.data[0].rank() + } +} +impl GGSWInfos for BlindRotationKeyPrepared { + fn dsize(&self) -> poulpy_core::layouts::Dsize { + Dsize(1) + } + + fn dnum(&self) -> Dnum { + self.data[0].dnum() + } +} + +impl BlindRotationKeyPrepared { + pub fn block_size(&self) -> usize { + match self.dist { + Distribution::BinaryBlock(value) => value, + _ => 1, + } + } +} diff --git a/poulpy-schemes/src/tfhe/blind_rotation/layouts/mod.rs b/poulpy-schemes/src/tfhe/blind_rotation/layouts/mod.rs new file mode 100644 index 0000000..f3e285f --- /dev/null +++ b/poulpy-schemes/src/tfhe/blind_rotation/layouts/mod.rs @@ -0,0 +1,6 @@ +mod key; +mod key_compressed; +mod key_prepared; +pub use key::*; +pub use key_compressed::*; +pub use key_prepared::*; diff --git a/poulpy-schemes/src/tfhe/blind_rotation/lut.rs b/poulpy-schemes/src/tfhe/blind_rotation/lut.rs index f8e9006..74c0441 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/lut.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/lut.rs @@ -1,3 +1,4 @@ +use poulpy_core::layouts::{Base2K, Degree, TorusPrecision}; use poulpy_hal::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, TakeSlice, VecZnxCopy, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, @@ -13,32 +14,97 @@ pub enum LookUpTableRotationDirection { Right, } -pub struct LookUpTable { +pub struct LookUpTableLayout { + pub n: Degree, + pub extension_factor: usize, + pub k: TorusPrecision, + pub base2k: Base2K, +} + +pub trait LookupTableInfos { + fn n(&self) -> Degree; + fn extension_factor(&self) -> usize; + fn k(&self) -> TorusPrecision; + fn base2k(&self) -> Base2K; + fn size(&self) -> usize; +} + +impl LookupTableInfos for LookUpTableLayout { + fn base2k(&self) -> Base2K { + self.base2k + } + + fn extension_factor(&self) -> usize { + self.extension_factor + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn size(&self) -> usize { + self.k().as_usize().div_ceil(self.base2k().as_usize()) + } + + fn n(&self) -> Degree { + self.n + } +} + +pub struct LookupTable { pub(crate) data: Vec>>, pub(crate) rot_dir: LookUpTableRotationDirection, - pub(crate) base2k: usize, - pub(crate) k: usize, + pub(crate) base2k: Base2K, + pub(crate) k: TorusPrecision, pub(crate) drift: usize, } -impl LookUpTable { - pub fn alloc(module: &Module, base2k: usize, k: usize, extension_factor: usize) -> Self { +impl LookupTableInfos for LookupTable { + fn base2k(&self) -> Base2K { + self.base2k + } + + fn extension_factor(&self) -> usize { + self.data.len() + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn n(&self) -> Degree { + self.data[0].n().into() + } + + fn size(&self) -> usize { + self.data[0].size() + } +} + +pub trait LookupTableFactory { + fn lookup_table_set(&self, res: &mut LookupTable, f: &[i64], k: usize); + fn lookup_table_rotate(&self, k: i64, res: &mut LookupTable); +} + +impl LookupTable { + pub fn alloc(infos: &A) -> Self + where + A: LookupTableInfos, + { #[cfg(debug_assertions)] { assert!( - extension_factor & (extension_factor - 1) == 0, - "extension_factor must be a power of two but is: {extension_factor}" + infos.extension_factor() & (infos.extension_factor() - 1) == 0, + "extension_factor must be a power of two but is: {}", + infos.extension_factor() ); } - let size: usize = k.div_ceil(base2k); - let mut data: Vec>> = Vec::with_capacity(extension_factor); - (0..extension_factor).for_each(|_| { - data.push(VecZnx::alloc(module.n(), 1, size)); - }); Self { - data, - base2k, - k, + data: (0..infos.extension_factor()) + .map(|_| VecZnx::alloc(infos.n().into(), 1, infos.size())) + .collect(), + base2k: infos.base2k(), + k: infos.k(), drift: 0, rot_dir: LookUpTableRotationDirection::Left, } @@ -68,115 +134,18 @@ impl LookUpTable { self.rot_dir = rot_dir } - pub fn set(&mut self, module: &Module, f: &[i64], k: usize) + pub fn set(&mut self, module: &M, f: &[i64], k: usize) where - Module: VecZnxRotateInplace - + VecZnxNormalizeInplace - + VecZnxNormalizeTmpBytes - + VecZnxSwitchRing - + VecZnxCopy - + VecZnxRotateInplaceTmpBytes, - ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, - Scratch: TakeSlice, + M: LookupTableFactory, { - assert!(f.len() <= module.n()); - - let base2k: usize = self.base2k; - - let mut scratch: ScratchOwned = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes() | (self.domain_size() << 3)); - - // Get the number minimum limb to store the message modulus - let limbs: usize = k.div_ceil(base2k); - - #[cfg(debug_assertions)] - { - assert!(f.len() <= module.n()); - assert!( - (max_bit_size(f) + (k % base2k) as u32) < i64::BITS, - "overflow: max(|f|) << (k%base2k) > i64::BITS" - ); - assert!(limbs <= self.data[0].size()); - } - - // Scaling factor - let mut scale = 1; - if !k.is_multiple_of(base2k) { - scale <<= base2k - (k % base2k); - } - - // #elements in lookup table - let f_len: usize = f.len(); - - // If LUT size > TakeScalarZnx - let domain_size: usize = self.domain_size(); - - let size: usize = self.k.div_ceil(self.base2k); - - // Equivalent to AUTO([f(0), -f(n-1), -f(n-2), ..., -f(1)], -1) - let mut lut_full: VecZnx> = VecZnx::alloc(domain_size, 1, size); - - let lut_at: &mut [i64] = lut_full.at_mut(0, limbs - 1); - - let step: usize = domain_size.div_round(f_len); - - f.iter().enumerate().for_each(|(i, fi)| { - let start: usize = i * step; - let end: usize = start + step; - lut_at[start..end].fill(fi * scale); - }); - - let drift: usize = step >> 1; - - // Rotates half the step to the left - - if self.extension_factor() > 1 { - let (tmp, _) = scratch.borrow().take_slice(lut_full.n()); - - for i in 0..self.extension_factor() { - module.vec_znx_switch_ring(&mut self.data[i], 0, &lut_full, 0); - if i < self.extension_factor() { - vec_znx_rotate_inplace::<_, ZnxRef>(-1, &mut lut_full, 0, tmp); - } - } - } else { - module.vec_znx_copy(&mut self.data[0], 0, &lut_full, 0); - } - - for a in self.data.iter_mut() { - module.vec_znx_normalize_inplace(self.base2k, a, 0, scratch.borrow()); - } - - self.rotate(module, -(drift as i64)); - - self.drift = drift + module.lookup_table_set(self, f, k); } - #[allow(dead_code)] - pub(crate) fn rotate(&mut self, module: &Module, k: i64) + pub(crate) fn rotate(&mut self, module: &M, k: i64) where - Module: VecZnxRotateInplace + VecZnxRotateInplaceTmpBytes, - ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + M: LookupTableFactory, { - let extension_factor: usize = self.extension_factor(); - let two_n: usize = 2 * self.data[0].n(); - let two_n_ext: usize = two_n * extension_factor; - - let mut scratch: ScratchOwned<_> = ScratchOwned::alloc(module.vec_znx_rotate_inplace_tmp_bytes()); - - let k_pos: usize = ((k + two_n_ext as i64) % two_n_ext as i64) as usize; - - let k_hi: usize = k_pos / extension_factor; - let k_lo: usize = k_pos % extension_factor; - - (0..extension_factor - k_lo).for_each(|i| { - module.vec_znx_rotate_inplace(k_hi as i64, &mut self.data[i], 0, scratch.borrow()); - }); - - (extension_factor - k_lo..extension_factor).for_each(|i| { - module.vec_znx_rotate_inplace(k_hi as i64 + 1, &mut self.data[i], 0, scratch.borrow()); - }); - - self.data.rotate_right(k_lo); + module.lookup_table_rotate(k, self); } } @@ -204,3 +173,116 @@ fn max_bit_size(vec: &[i64]) -> u32 { .max() .unwrap_or(0) } + +impl LookupTableFactory for Module +where + Self: VecZnxRotateInplace + + VecZnxNormalizeInplace + + VecZnxNormalizeTmpBytes + + VecZnxSwitchRing + + VecZnxCopy + + VecZnxRotateInplaceTmpBytes + + VecZnxRotateInplace + + VecZnxRotateInplaceTmpBytes, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: TakeSlice, +{ + fn lookup_table_set(&self, res: &mut LookupTable, f: &[i64], k: usize) { + assert!(f.len() <= self.n()); + + let base2k: usize = res.base2k.into(); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + self.vec_znx_normalize_tmp_bytes() + .max(res.domain_size() << 3), + ); + + // Get the number minimum limb to store the message modulus + let limbs: usize = k.div_ceil(base2k); + + #[cfg(debug_assertions)] + { + assert!(f.len() <= self.n()); + assert!( + (max_bit_size(f) + (k % base2k) as u32) < i64::BITS, + "overflow: max(|f|) << (k%base2k) > i64::BITS" + ); + assert!(limbs <= res.data[0].size()); + } + + // Scaling factor + let mut scale = 1; + if !k.is_multiple_of(base2k) { + scale <<= base2k - (k % base2k); + } + + // #elements in lookup table + let f_len: usize = f.len(); + + // If LUT size > TakeScalarZnx + let domain_size: usize = res.domain_size(); + + let size: usize = res.k.div_ceil(res.base2k) as usize; + + // Equivalent to AUTO([f(0), -f(n-1), -f(n-2), ..., -f(1)], -1) + let mut lut_full: VecZnx> = VecZnx::alloc(domain_size, 1, size); + + let lut_at: &mut [i64] = lut_full.at_mut(0, limbs - 1); + + let step: usize = domain_size.div_round(f_len); + + f.iter().enumerate().for_each(|(i, fi)| { + let start: usize = i * step; + let end: usize = start + step; + lut_at[start..end].fill(fi * scale); + }); + + let drift: usize = step >> 1; + + // Rotates half the step to the left + + if res.extension_factor() > 1 { + let (tmp, _) = scratch.borrow().take_slice(lut_full.n()); + + for i in 0..res.extension_factor() { + self.vec_znx_switch_ring(&mut res.data[i], 0, &lut_full, 0); + if i < res.extension_factor() { + vec_znx_rotate_inplace::<_, ZnxRef>(-1, &mut lut_full, 0, tmp); + } + } + } else { + self.vec_znx_copy(&mut res.data[0], 0, &lut_full, 0); + } + + for a in res.data.iter_mut() { + self.vec_znx_normalize_inplace(res.base2k.into(), a, 0, scratch.borrow()); + } + + res.rotate(self, -(drift as i64)); + + res.drift = drift + } + + fn lookup_table_rotate(&self, k: i64, res: &mut LookupTable) { + let extension_factor: usize = res.extension_factor(); + let two_n: usize = 2 * res.data[0].n(); + let two_n_ext: usize = two_n * extension_factor; + + let mut scratch: ScratchOwned<_> = ScratchOwned::alloc(self.vec_znx_rotate_inplace_tmp_bytes()); + + let k_pos: usize = ((k + two_n_ext as i64) % two_n_ext as i64) as usize; + + let k_hi: usize = k_pos / extension_factor; + let k_lo: usize = k_pos % extension_factor; + + (0..extension_factor - k_lo).for_each(|i| { + self.vec_znx_rotate_inplace(k_hi as i64, &mut res.data[i], 0, scratch.borrow()); + }); + + (extension_factor - k_lo..extension_factor).for_each(|i| { + self.vec_znx_rotate_inplace(k_hi as i64 + 1, &mut res.data[i], 0, scratch.borrow()); + }); + + res.data.rotate_right(k_lo); + } +} diff --git a/poulpy-schemes/src/tfhe/blind_rotation/mod.rs b/poulpy-schemes/src/tfhe/blind_rotation/mod.rs index bd83a08..93da18b 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/mod.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/mod.rs @@ -1,35 +1,11 @@ -mod cggi_algo; -mod cggi_key; -mod key; -mod key_compressed; -mod key_prepared; +mod algorithms; +mod encryption; +mod layouts; mod lut; mod utils; -pub use cggi_algo::*; -pub use key::*; -pub use key_compressed::*; -pub use key_prepared::*; +pub use algorithms::*; +pub use encryption::*; +pub use layouts::*; pub use lut::*; - pub mod tests; - -use poulpy_core::layouts::{GLWECiphertext, LWECiphertext}; -use poulpy_hal::layouts::{Backend, DataMut, DataRef, Module, Scratch}; - -pub trait BlindRotationAlgo {} - -#[derive(Clone)] -pub struct CGGI {} -impl BlindRotationAlgo for CGGI {} - -pub trait BlincRotationExecute { - fn execute( - &self, - module: &Module, - res: &mut GLWECiphertext, - lwe: &LWECiphertext, - lut: &LookUpTable, - scratch: &mut Scratch, - ); -} diff --git a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs index 190adff..1651987 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs @@ -1,89 +1,39 @@ use poulpy_hal::{ - api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, - SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAdd, VecZnxDftAddInplace, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftSubInplace, VecZnxDftZero, VecZnxFillUniform, VecZnxIdftApply, - VecZnxIdftApplyConsume, VecZnxIdftApplyTmpBytes, VecZnxMulXpMinusOneInplace, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxSub, VecZnxSubInplace, - VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, ZnAddNormal, - ZnFillUniform, ZnNormalizeInplace, - }, - layouts::{Backend, Module, ScratchOwned, ZnxView}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeSliceImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl, - TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl, - }, + api::{ScratchOwnedAlloc, ScratchOwnedBorrow}, + layouts::{Backend, Scratch, ScratchOwned, ZnxView}, source::Source, }; use crate::tfhe::blind_rotation::{ - BlincRotationExecute, BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyEncryptSk, BlindRotationKeyLayout, - BlindRotationKeyPrepared, CGGI, LookUpTable, cggi_blind_rotate_scratch_space, mod_switch_2n, + BlindRotationAlgo, BlindRotationExecute, BlindRotationKey, BlindRotationKeyEncryptSk, BlindRotationKeyFactory, + BlindRotationKeyLayout, BlindRotationKeyPrepared, BlindRotationKeyPreparedFactory, LookUpTableLayout, LookupTable, + LookupTableFactory, mod_switch_2n, }; -use poulpy_core::layouts::{ - GLWECiphertext, GLWECiphertextLayout, GLWEPlaintext, GLWESecret, LWECiphertext, LWECiphertextLayout, LWECiphertextToRef, - LWEInfos, LWEPlaintext, LWESecret, - prepared::{GLWESecretPrepared, PrepareAlloc}, +use poulpy_core::{ + GLWEDecrypt, LWEEncryptSk, ScratchTakeCore, + layouts::{ + GLWE, GLWELayout, GLWEPlaintext, GLWESecret, GLWESecretPreparedFactory, LWE, LWEInfos, LWELayout, LWEPlaintext, + LWESecret, LWEToRef, prepared::GLWESecretPrepared, + }, }; -pub fn test_blind_rotation(module: &Module, n_lwe: usize, block_size: usize, extension_factor: usize) -where - Module: VecZnxBigAllocBytes - + VecZnxDftAllocBytes - + SvpPPolAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VecZnxIdftApplyTmpBytes - + VecZnxIdftApply - + VecZnxDftAdd - + VecZnxDftAddInplace - + VecZnxDftApply - + VecZnxDftZero - + SvpApplyDftToDft - + VecZnxDftSubInplace - + VecZnxBigAddSmallInplace - + VecZnxRotate - + VecZnxAddInplace - + VecZnxSubInplace - + VecZnxNormalize - + VecZnxNormalizeInplace - + VecZnxCopy - + VecZnxMulXpMinusOneInplace - + SvpPrepare - + SvpPPolAlloc - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxBigAddInplace - + VecZnxBigNormalize - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxAddNormal - + VecZnxAddScalarInplace - + VecZnxRotateInplace - + VecZnxSwitchRing - + VecZnxSub - + VmpPMatAlloc - + VmpPrepare - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + ZnFillUniform - + ZnAddNormal - + VecZnxRotateInplaceTmpBytes - + ZnNormalizeInplace, - B: Backend - + VecZnxDftAllocBytesImpl - + VecZnxBigAllocBytesImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeVecZnxDftSliceImpl - + ScratchAvailableImpl - + TakeVecZnxImpl - + TakeVecZnxSliceImpl - + TakeSliceImpl, +pub fn test_blind_rotation( + module: &M, + n_lwe: usize, + block_size: usize, + extension_factor: usize, +) where + M: BlindRotationKeyEncryptSk + + BlindRotationKeyPreparedFactory + + BlindRotationExecute + + LookupTableFactory + + GLWESecretPreparedFactory + + GLWEDecrypt + + LWEEncryptSk, + BlindRotationKey, BRA>: BlindRotationKeyFactory, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchTakeCore, { let n_glwe: usize = module.n(); let base2k: usize = 19; @@ -111,31 +61,30 @@ where rank: rank.into(), }; - let glwe_infos: GLWECiphertextLayout = GLWECiphertextLayout { + let glwe_infos: GLWELayout = GLWELayout { n: n_glwe.into(), base2k: base2k.into(), k: k_res.into(), rank: rank.into(), }; - let lwe_infos: LWECiphertextLayout = LWECiphertextLayout { + let lwe_infos: LWELayout = LWELayout { n: n_lwe.into(), k: k_lwe.into(), base2k: base2k.into(), }; - let mut scratch: ScratchOwned = ScratchOwned::::alloc(BlindRotationKey::generate_from_sk_scratch_space( - module, &brk_infos, - )); + let mut scratch: ScratchOwned = ScratchOwned::::alloc(BlindRotationKey::encrypt_sk_tmp_bytes(module, &brk_infos)); - let mut sk_glwe: GLWESecret> = GLWESecret::alloc(&glwe_infos); + let mut sk_glwe: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_infos); sk_glwe.fill_ternary_prob(0.5, &mut source_xs); - let sk_glwe_dft: GLWESecretPrepared, B> = sk_glwe.prepare_alloc(module, scratch.borrow()); + let mut sk_glwe_dft: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc_from_infos(module, &glwe_infos); + sk_glwe_dft.prepare(module, &sk_glwe); let mut sk_lwe: LWESecret> = LWESecret::alloc(n_lwe.into()); sk_lwe.fill_binary_block(block_size, &mut source_xs); - let mut scratch_br: ScratchOwned = ScratchOwned::::alloc(cggi_blind_rotate_scratch_space( + let mut scratch_br: ScratchOwned = ScratchOwned::::alloc(BlindRotationKeyPrepared::execute_tmp_bytes( module, block_size, extension_factor, @@ -143,7 +92,7 @@ where &brk_infos, )); - let mut brk: BlindRotationKey, CGGI> = BlindRotationKey::, CGGI>::alloc(&brk_infos); + let mut brk: BlindRotationKey, BRA> = BlindRotationKey::, BRA>::alloc(&brk_infos); brk.encrypt_sk( module, @@ -154,9 +103,9 @@ where scratch.borrow(), ); - let mut lwe: LWECiphertext> = LWECiphertext::alloc(&lwe_infos); + let mut lwe: LWE> = LWE::alloc_from_infos(&lwe_infos); - let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc(&lwe_infos); + let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc_from_infos(&lwe_infos); let x: i64 = 15 % (message_modulus as i64); @@ -172,16 +121,24 @@ where .enumerate() .for_each(|(i, x)| *x = f(i as i64)); - let mut lut: LookUpTable = LookUpTable::alloc(module, base2k, k_lut, extension_factor); + let lut_infos = LookUpTableLayout { + n: module.n().into(), + extension_factor, + k: k_lut.into(), + base2k: base2k.into(), + }; + + let mut lut: LookupTable = LookupTable::alloc(&lut_infos); lut.set(module, &f_vec, log_message_modulus + 1); - let mut res: GLWECiphertext> = GLWECiphertext::alloc(&glwe_infos); + let mut res: GLWE> = GLWE::alloc_from_infos(&glwe_infos); - let brk_prepared: BlindRotationKeyPrepared, CGGI, B> = brk.prepare_alloc(module, scratch.borrow()); + let mut brk_prepared: BlindRotationKeyPrepared, BRA, BE> = BlindRotationKeyPrepared::alloc(module, &brk); + brk_prepared.prepare(module, &brk, scratch_br.borrow()); brk_prepared.execute(module, &mut res, &lwe, &lut, scratch_br.borrow()); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_infos); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_infos); res.decrypt(module, &mut pt_have, &sk_glwe_dft, scratch.borrow()); diff --git a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_lut.rs b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_lut.rs index bb6492c..80d7663 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_lut.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_lut.rs @@ -1,25 +1,12 @@ use std::vec; -use poulpy_hal::{ - api::{ - VecZnxCopy, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, - VecZnxSwitchRing, - }, - layouts::{Backend, Module}, - oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeSliceImpl}, -}; +use poulpy_hal::api::ModuleN; -use crate::tfhe::blind_rotation::{DivRound, LookUpTable}; +use crate::tfhe::blind_rotation::{DivRound, LookUpTableLayout, LookupTable, LookupTableFactory}; -pub fn test_lut_standard(module: &Module) +pub fn test_lut_standard(module: &M) where - Module: VecZnxRotateInplace - + VecZnxNormalizeInplace - + VecZnxNormalizeTmpBytes - + VecZnxSwitchRing - + VecZnxCopy - + VecZnxRotateInplaceTmpBytes, - B: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl + TakeSliceImpl, + M: LookupTableFactory + ModuleN, { let base2k: usize = 20; let k_lut: usize = 40; @@ -33,7 +20,14 @@ where .enumerate() .for_each(|(i, x)| *x = (i as i64) - 8); - let mut lut: LookUpTable = LookUpTable::alloc(module, base2k, k_lut, extension_factor); + let lut_infos: LookUpTableLayout = LookUpTableLayout { + n: module.n().into(), + extension_factor, + k: k_lut.into(), + base2k: base2k.into(), + }; + + let mut lut: LookupTable = LookupTable::alloc(&lut_infos); lut.set(module, &f, log_scale); let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64; @@ -51,15 +45,9 @@ where }); } -pub fn test_lut_extended(module: &Module) +pub fn test_lut_extended(module: &M) where - Module: VecZnxRotateInplace - + VecZnxNormalizeInplace - + VecZnxNormalizeTmpBytes - + VecZnxSwitchRing - + VecZnxCopy - + VecZnxRotateInplaceTmpBytes, - B: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl + TakeSliceImpl, + M: LookupTableFactory + ModuleN, { let base2k: usize = 20; let k_lut: usize = 40; @@ -73,7 +61,14 @@ where .enumerate() .for_each(|(i, x)| *x = (i as i64) - 8); - let mut lut: LookUpTable = LookUpTable::alloc(module, base2k, k_lut, extension_factor); + let lut_infos: LookUpTableLayout = LookUpTableLayout { + n: module.n().into(), + extension_factor, + k: k_lut.into(), + base2k: base2k.into(), + }; + + let mut lut: LookupTable = LookupTable::alloc(&lut_infos); lut.set(module, &f, log_scale); let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64; diff --git a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_serialization.rs b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_serialization.rs index f25a236..341dd49 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_serialization.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_serialization.rs @@ -1,8 +1,6 @@ use poulpy_hal::test_suite::serialization::test_reader_writer_interface; -use crate::tfhe::blind_rotation::{ - BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyCompressed, BlindRotationKeyLayout, CGGI, -}; +use crate::tfhe::blind_rotation::{BlindRotationKey, BlindRotationKeyCompressed, BlindRotationKeyLayout, CGGI}; #[test] fn test_cggi_blind_rotation_key_serialization() { @@ -14,7 +12,6 @@ fn test_cggi_blind_rotation_key_serialization() { dnum: 2_usize.into(), rank: 2_usize.into(), }; - let original: BlindRotationKey, CGGI> = BlindRotationKey::alloc(&layout); test_reader_writer_interface(original); } @@ -29,7 +26,6 @@ fn test_cggi_blind_rotation_key_compressed_serialization() { dnum: 2_usize.into(), rank: 2_usize.into(), }; - let original: BlindRotationKeyCompressed, CGGI> = BlindRotationKeyCompressed::alloc(&layout); test_reader_writer_interface(original); } diff --git a/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/cpu_spqlios/fft64.rs b/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/cpu_spqlios/fft64.rs deleted file mode 100644 index ecb2421..0000000 --- a/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/cpu_spqlios/fft64.rs +++ /dev/null @@ -1,37 +0,0 @@ -use poulpy_backend::cpu_spqlios::FFT64Spqlios; -use poulpy_hal::{api::ModuleNew, layouts::Module}; - -use crate::tfhe::blind_rotation::tests::{ - generic_blind_rotation::test_blind_rotation, - generic_lut::{test_lut_extended, test_lut_standard}, -}; - -#[test] -fn lut_standard() { - let module: Module = Module::::new(32); - test_lut_standard(&module); -} - -#[test] -fn lut_extended() { - let module: Module = Module::::new(32); - test_lut_extended(&module); -} - -#[test] -fn standard() { - let module: Module = Module::::new(512); - test_blind_rotation(&module, 224, 1, 1); -} - -#[test] -fn block_binary() { - let module: Module = Module::::new(512); - test_blind_rotation(&module, 224, 7, 1); -} - -#[test] -fn block_binary_extended() { - let module: Module = Module::::new(512); - test_blind_rotation(&module, 224, 7, 2); -} diff --git a/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/cpu_spqlios/mod.rs b/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/cpu_spqlios/mod.rs deleted file mode 100644 index aebaafb..0000000 --- a/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/cpu_spqlios/mod.rs +++ /dev/null @@ -1 +0,0 @@ -mod fft64; diff --git a/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/fft64.rs b/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/fft64.rs new file mode 100644 index 0000000..5a471e2 --- /dev/null +++ b/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/fft64.rs @@ -0,0 +1,40 @@ +use poulpy_backend::cpu_fft64_ref::FFT64Ref; +use poulpy_hal::{api::ModuleNew, layouts::Module}; + +use crate::tfhe::blind_rotation::{ + CGGI, + tests::{ + generic_blind_rotation::test_blind_rotation, + generic_lut::{test_lut_extended, test_lut_standard}, + }, +}; + +#[test] +fn lut_standard() { + let module: Module = Module::::new(32); + test_lut_standard(&module); +} + +#[test] +fn lut_extended() { + let module: Module = Module::::new(32); + test_lut_extended(&module); +} + +#[test] +fn standard() { + let module: Module = Module::::new(512); + test_blind_rotation::(&module, 224, 1, 1); +} + +#[test] +fn block_binary() { + let module: Module = Module::::new(512); + test_blind_rotation::(&module, 224, 7, 1); +} + +#[test] +fn block_binary_extended() { + let module: Module = Module::::new(512); + test_blind_rotation::(&module, 224, 7, 2); +} diff --git a/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/mod.rs b/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/mod.rs index f2bc1d4..aebaafb 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/mod.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/mod.rs @@ -1 +1 @@ -mod cpu_spqlios; +mod fft64; diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs index 45b9717..a627e0c 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs @@ -1,185 +1,231 @@ use std::collections::HashMap; use poulpy_hal::{ - api::{ - ScratchAvailable, TakeMatZnx, TakeSlice, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, TakeVecZnxSlice, - VecZnxAddInplace, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigAutomorphismInplace, - VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy, VecZnxDftAddInplace, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNegateInplace, - VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, - VecZnxRotateInplaceTmpBytes, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, - VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, - }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, ToOwnedDeep}, - oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl}, + api::{ModuleLogN, ModuleN, ScratchOwnedAlloc, ScratchOwnedBorrow}, + layouts::{Backend, DataRef, Module, Scratch, ScratchOwned, ToOwnedDeep}, }; use poulpy_core::{ - GLWEOperations, TakeGGLWE, TakeGLWECt, - layouts::{Dsize, GGLWECiphertextLayout, GGSWInfos, GLWEInfos, LWEInfos}, + GGSWFromGGLWE, GLWEDecrypt, GLWEPacking, GLWETrace, ScratchTakeCore, + layouts::{ + Dsize, GGLWELayout, GGSWInfos, GGSWToMut, GLWEInfos, GLWESecretPreparedFactory, GLWEToMut, GLWEToRef, LWEInfos, LWEToRef, + }, }; -use poulpy_core::glwe_packing; -use poulpy_core::layouts::{GGSWCiphertext, GLWECiphertext, LWECiphertext, prepared::GGLWEAutomorphismKeyPrepared}; +use poulpy_core::layouts::{GGSW, GLWE, LWE, prepared::GLWEAutomorphismKeyPrepared}; use crate::tfhe::{ blind_rotation::{ - BlincRotationExecute, BlindRotationAlgo, BlindRotationKeyPrepared, LookUpTable, LookUpTableRotationDirection, + BlindRotationAlgo, BlindRotationExecute, LookUpTableLayout, LookUpTableRotationDirection, LookupTable, LookupTableFactory, }, - circuit_bootstrapping::{CircuitBootstrappingKeyPrepared, CirtuitBootstrappingExecute}, + circuit_bootstrapping::{CircuitBootstrappingKeyInfos, CircuitBootstrappingKeyPrepared}, }; -impl CirtuitBootstrappingExecute for CircuitBootstrappingKeyPrepared -where - Module: VecZnxRotateInplace - + VecZnxNormalizeInplace - + VecZnxNormalizeTmpBytes - + VecZnxSwitchRing - + VecZnxBigAutomorphismInplace - + VecZnxRshInplace - + VecZnxDftCopy - + VecZnxIdftApplyTmpA - + VecZnxSub - + VecZnxAddInplace - + VecZnxNegateInplace - + VecZnxCopy - + VecZnxSubInplace - + VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxAutomorphismInplace - + VecZnxBigSubSmallNegateInplace - + VecZnxRotateInplaceTmpBytes - + VecZnxBigAllocBytes - + VecZnxDftAddInplace - + VecZnxRotate - + VecZnxNormalize, - B: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, - Scratch: TakeVecZnx - + TakeVecZnxDftSlice - + TakeVecZnxBig - + TakeVecZnxDft - + TakeMatZnx - + ScratchAvailable - + TakeVecZnxSlice - + TakeSlice, - BlindRotationKeyPrepared: BlincRotationExecute, -{ - fn execute_to_constant( +pub trait CirtuitBootstrappingExecute { + fn circuit_bootstrapping_execute_tmp_bytes( &self, - module: &Module, - res: &mut GGSWCiphertext, - lwe: &LWECiphertext, + block_size: usize, + extension_factor: usize, + res_infos: &R, + cbt_infos: &A, + ) -> usize + where + R: GGSWInfos, + A: CircuitBootstrappingKeyInfos; + + fn circuit_bootstrapping_execute_to_constant( + &self, + res: &mut R, + lwe: &L, + key: &CircuitBootstrappingKeyPrepared, log_domain: usize, extension_factor: usize, - scratch: &mut Scratch, - ) { + scratch: &mut Scratch, + ) where + R: GGSWToMut + GGSWInfos, + L: LWEToRef + LWEInfos, + D: DataRef; + + #[allow(clippy::too_many_arguments)] + fn circuit_bootstrapping_execute_to_exponent( + &self, + log_gap_out: usize, + res: &mut R, + lwe: &L, + key: &CircuitBootstrappingKeyPrepared, + log_domain: usize, + extension_factor: usize, + scratch: &mut Scratch, + ) where + R: GGSWToMut + GGSWInfos, + L: LWEToRef + LWEInfos, + D: DataRef; +} + +impl CircuitBootstrappingKeyPrepared { + pub fn execute_to_constant( + &self, + module: &M, + res: &mut R, + lwe: &L, + log_domain: usize, + extension_factor: usize, + scratch: &mut Scratch, + ) where + M: CirtuitBootstrappingExecute, + R: GGSWToMut + GGSWInfos, + L: LWEToRef + LWEInfos, + { + module.circuit_bootstrapping_execute_to_constant(res, lwe, self, log_domain, extension_factor, scratch); + } + + #[allow(clippy::too_many_arguments)] + pub fn execute_to_exponent( + &self, + module: &M, + log_gap_out: usize, + res: &mut R, + lwe: &L, + log_domain: usize, + extension_factor: usize, + scratch: &mut Scratch, + ) where + M: CirtuitBootstrappingExecute, + R: GGSWToMut + GGSWInfos, + L: LWEToRef + LWEInfos, + { + module.circuit_bootstrapping_execute_to_exponent( + log_gap_out, + res, + lwe, + self, + log_domain, + extension_factor, + scratch, + ); + } +} + +impl CirtuitBootstrappingExecute for Module +where + Self: ModuleN + + LookupTableFactory + + BlindRotationExecute + + GLWETrace + + GLWEPacking + + GGSWFromGGLWE + + GLWESecretPreparedFactory + + GLWEDecrypt, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchTakeCore, +{ + fn circuit_bootstrapping_execute_tmp_bytes( + &self, + block_size: usize, + extension_factor: usize, + res_infos: &R, + cbt_infos: &A, + ) -> usize + where + R: GGSWInfos, + A: CircuitBootstrappingKeyInfos, + { + self.blind_rotation_execute_tmp_bytes( + block_size, + extension_factor, + res_infos, + &cbt_infos.brk_infos(), + ) + .max(self.glwe_trace_tmp_bytes(res_infos, res_infos, &cbt_infos.atk_infos())) + .max(self.ggsw_from_gglwe_tmp_bytes(res_infos, &cbt_infos.tsk_infos())) + } + + fn circuit_bootstrapping_execute_to_constant( + &self, + res: &mut R, + lwe: &L, + key: &CircuitBootstrappingKeyPrepared, + log_domain: usize, + extension_factor: usize, + scratch: &mut Scratch, + ) where + R: GGSWToMut + GGSWInfos, + L: LWEToRef + LWEInfos, + D: DataRef, + { circuit_bootstrap_core( false, - module, + self, 0, res, lwe, log_domain, extension_factor, - self, + key, scratch, ); } - fn execute_to_exponent( + fn circuit_bootstrapping_execute_to_exponent( &self, - module: &Module, log_gap_out: usize, - res: &mut GGSWCiphertext, - lwe: &LWECiphertext, + res: &mut R, + lwe: &L, + key: &CircuitBootstrappingKeyPrepared, log_domain: usize, extension_factor: usize, - scratch: &mut Scratch, - ) { + scratch: &mut Scratch, + ) where + R: GGSWToMut + GGSWInfos, + L: LWEToRef + LWEInfos, + D: DataRef, + { circuit_bootstrap_core( true, - module, + self, log_gap_out, res, lwe, log_domain, extension_factor, - self, + key, scratch, ); } } #[allow(clippy::too_many_arguments)] -pub fn circuit_bootstrap_core( +pub fn circuit_bootstrap_core( to_exponent: bool, - module: &Module, + module: &M, log_gap_out: usize, - res: &mut GGSWCiphertext, - lwe: &LWECiphertext, + res: &mut R, + lwe: &L, log_domain: usize, extension_factor: usize, - key: &CircuitBootstrappingKeyPrepared, - scratch: &mut Scratch, + key: &CircuitBootstrappingKeyPrepared, + scratch: &mut Scratch, ) where - DRes: DataMut, - DLwe: DataRef, - DBrk: DataRef, - Module: VecZnxRotateInplace - + VecZnxNormalizeInplace - + VecZnxNormalizeTmpBytes - + VecZnxSwitchRing - + VecZnxBigAutomorphismInplace - + VecZnxRshInplace - + VecZnxDftCopy - + VecZnxIdftApplyTmpA - + VecZnxSub - + VecZnxAddInplace - + VecZnxNegateInplace - + VecZnxCopy - + VecZnxSubInplace - + VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxAutomorphismInplace - + VecZnxBigSubSmallNegateInplace - + VecZnxBigAllocBytes - + VecZnxDftAddInplace - + VecZnxRotateInplaceTmpBytes - + VecZnxRotate - + VecZnxNormalize, - B: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, - Scratch: TakeVecZnxDftSlice - + TakeVecZnxBig - + TakeVecZnxDft - + TakeVecZnx - + ScratchAvailable - + TakeVecZnxSlice - + TakeMatZnx - + TakeSlice, - BlindRotationKeyPrepared: BlincRotationExecute, + R: GGSWToMut, + L: LWEToRef, + D: DataRef, + M: ModuleN + + LookupTableFactory + + BlindRotationExecute + + GLWETrace + + GLWEPacking + + GGSWFromGGLWE + + GLWESecretPreparedFactory + + GLWEDecrypt, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchTakeCore, { - #[cfg(debug_assertions)] - { - use poulpy_core::layouts::LWEInfos; + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + let lwe: &LWE<&[u8]> = &lwe.to_ref(); - assert_eq!(res.n(), key.brk.n()); - assert_eq!(lwe.base2k(), key.brk.base2k()); - assert_eq!(res.base2k(), key.brk.base2k()); - } + assert_eq!(res.n(), key.brk.n()); + assert_eq!(lwe.base2k(), key.brk.base2k()); + assert_eq!(res.base2k(), key.brk.base2k()); let n: usize = res.n().into(); let base2k: usize = res.base2k().into(); @@ -203,8 +249,15 @@ pub fn circuit_bootstrap_core( }); } + let lut_infos: LookUpTableLayout = LookUpTableLayout { + n: module.n().into(), + extension_factor, + k: (base2k * dnum).into(), + base2k: base2k.into(), + }; + // Lut precision, basically must be able to hold the decomposition power basis of the GGSW - let mut lut: LookUpTable = LookUpTable::alloc(module, base2k, base2k * dnum, extension_factor); + let mut lut: LookupTable = LookupTable::alloc(&lut_infos); lut.set(module, &f, base2k * dnum); if to_exponent { @@ -212,9 +265,9 @@ pub fn circuit_bootstrap_core( } // TODO: separate GGSW k from output of blind rotation k - let (mut res_glwe, scratch_1) = scratch.take_glwe_ct(res); + let (mut res_glwe, scratch_1) = scratch.take_glwe(res); - let gglwe_infos: GGLWECiphertextLayout = GGLWECiphertextLayout { + let gglwe_infos: GGLWELayout = GGLWELayout { n: n.into(), base2k: base2k.into(), k: k.into(), @@ -233,7 +286,7 @@ pub fn circuit_bootstrap_core( let log_gap_in: usize = (usize::BITS - (gap * alpha - 1).leading_zeros()) as _; (0..dnum).for_each(|i| { - let mut tmp_glwe: GLWECiphertext<&mut [u8]> = tmp_gglwe.at_mut(i, 0); + let mut tmp_glwe: GLWE<&mut [u8]> = tmp_gglwe.at_mut(i, 0); if to_exponent { // Isolates i-th LUT and moves coefficients according to requested gap. @@ -251,8 +304,14 @@ pub fn circuit_bootstrap_core( tmp_glwe.trace(module, 0, module.log_n(), &res_glwe, &key.atk, scratch_2); } + // let sk_glwe: &poulpy_core::layouts::GLWESecret<&[u8]> = &sk_glwe.to_ref(); + // let sk_glwe_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, sk_glwe.rank()); + // let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&res_glwe); + // res_glwe.decrypt(module, &mut pt, &sk_glwe_prepared, scratch_2); + // println!("pt[{i}]: {}", pt); + if i < dnum { - res_glwe.rotate_inplace(module, -(gap as i64), scratch_2); + module.glwe_rotate_inplace(-(gap as i64), &mut res_glwe, scratch_2); } }); @@ -261,49 +320,27 @@ pub fn circuit_bootstrap_core( } #[allow(clippy::too_many_arguments)] -fn post_process( - module: &Module, - res: &mut GLWECiphertext, - a: &GLWECiphertext, +fn post_process( + module: &M, + res: &mut R, + a: &A, log_gap_in: usize, log_gap_out: usize, log_domain: usize, - auto_keys: &HashMap, B>>, - scratch: &mut Scratch, + auto_keys: &HashMap, BE>>, + scratch: &mut Scratch, ) where - DataRes: DataMut, - DataA: DataRef, - Module: VecZnxRotateInplace - + VecZnxNormalizeInplace - + VecZnxNormalizeTmpBytes - + VecZnxSwitchRing - + VecZnxBigAutomorphismInplace - + VecZnxRshInplace - + VecZnxDftCopy - + VecZnxIdftApplyTmpA - + VecZnxSub - + VecZnxAddInplace - + VecZnxNegateInplace - + VecZnxCopy - + VecZnxSubInplace - + VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxAutomorphismInplace - + VecZnxBigSubSmallNegateInplace - + VecZnxRotate - + VecZnxNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + R: GLWEToMut, + A: GLWEToRef, + M: ModuleLogN + GLWETrace + GLWEPacking, + Scratch: ScratchTakeCore, { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); + let log_n: usize = module.log_n(); - let mut cts: HashMap>> = HashMap::new(); + let mut cts: HashMap>> = HashMap::new(); // First partial trace, vanishes all coefficients which are not multiples of gap_in // [1, 1, 1, 1, 0, 0, 0, ..., 0, 0, -1, -1, -1, -1] -> [1, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0] @@ -322,11 +359,11 @@ fn post_process( let steps: usize = 1 << log_domain; // TODO: from Scratch - let mut cts_vec: Vec>> = Vec::new(); + let mut cts_vec: Vec>> = Vec::new(); for i in 0..steps { if i != 0 { - res.rotate_inplace(module, -(1 << log_gap_in), scratch); + module.glwe_rotate_inplace(-(1 << log_gap_in), res, scratch); } cts_vec.push(res.to_owned_deep()); } @@ -335,8 +372,9 @@ fn post_process( cts.insert(i * (1 << log_gap_out), ct); } - glwe_packing(module, &mut cts, log_gap_out, auto_keys, scratch); - let packed: &mut GLWECiphertext> = cts.remove(&0).unwrap(); + module.glwe_pack(&mut cts, log_gap_out, auto_keys, scratch); + + let packed: &mut GLWE> = cts.remove(&0).unwrap(); res.trace( module, log_n - log_gap_out, diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs index 427cf75..c6b8adc 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs @@ -1,42 +1,38 @@ -use poulpy_core::layouts::{ - GGLWEAutomorphismKey, GGLWEAutomorphismKeyLayout, GGLWEInfos, GGLWETensorKey, GGLWETensorKeyLayout, GGSWInfos, - GLWECiphertext, GLWEInfos, GLWESecret, LWEInfos, LWESecret, - prepared::{GGLWEAutomorphismKeyPrepared, GGLWETensorKeyPrepared, GLWESecretPrepared, PrepareAlloc}, +use poulpy_core::{ + Distribution, GLWEAutomorphismKeyEncryptSk, GLWETensorKeyEncryptSk, GetDistribution, ScratchTakeCore, + layouts::{ + GGLWEInfos, GGSWInfos, GLWEAutomorphismKey, GLWEAutomorphismKeyLayout, GLWEInfos, GLWESecretPreparedFactory, + GLWESecretToRef, GLWETensorKey, GLWETensorKeyLayout, LWEInfos, LWESecretToRef, prepared::GLWESecretPrepared, + }, + trace_galois_elements, }; use std::collections::HashMap; use poulpy_hal::{ - api::{ - ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, - TakeSvpPPol, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, - VecZnxAutomorphism, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, - VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, - VecZnxSwitchRing, VmpPMatAlloc, VmpPrepare, - }, - layouts::{Backend, Data, DataRef, Module, Scratch}, + layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}, source::Source, }; use crate::tfhe::blind_rotation::{ - BlindRotationAlgo, BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyEncryptSk, BlindRotationKeyInfos, - BlindRotationKeyLayout, BlindRotationKeyPrepared, + BlindRotationAlgo, BlindRotationKey, BlindRotationKeyEncryptSk, BlindRotationKeyFactory, BlindRotationKeyInfos, + BlindRotationKeyLayout, }; pub trait CircuitBootstrappingKeyInfos { fn brk_infos(&self) -> BlindRotationKeyLayout; - fn atk_infos(&self) -> GGLWEAutomorphismKeyLayout; - fn tsk_infos(&self) -> GGLWETensorKeyLayout; + fn atk_infos(&self) -> GLWEAutomorphismKeyLayout; + fn tsk_infos(&self) -> GLWETensorKeyLayout; } #[derive(Debug, Clone, Copy)] pub struct CircuitBootstrappingKeyLayout { pub layout_brk: BlindRotationKeyLayout, - pub layout_atk: GGLWEAutomorphismKeyLayout, - pub layout_tsk: GGLWETensorKeyLayout, + pub layout_atk: GLWEAutomorphismKeyLayout, + pub layout_tsk: GLWETensorKeyLayout, } impl CircuitBootstrappingKeyInfos for CircuitBootstrappingKeyLayout { - fn atk_infos(&self) -> GGLWEAutomorphismKeyLayout { + fn atk_infos(&self) -> GLWEAutomorphismKeyLayout { self.layout_atk } @@ -44,96 +40,117 @@ impl CircuitBootstrappingKeyInfos for CircuitBootstrappingKeyLayout { self.layout_brk } - fn tsk_infos(&self) -> GGLWETensorKeyLayout { + fn tsk_infos(&self) -> GLWETensorKeyLayout { self.layout_tsk } } -pub trait CircuitBootstrappingKeyEncryptSk { +pub trait CircuitBootstrappingKeyEncryptSk { #[allow(clippy::too_many_arguments)] - fn encrypt_sk( - module: &Module, - sk_lwe: &LWESecret, - sk_glwe: &GLWESecret, - cbt_infos: &INFOS, + fn circuit_bootstrapping_key_encrypt_sk( + &self, + res: &mut CircuitBootstrappingKey, + sk_lwe: &S0, + sk_glwe: &S1, source_xa: &mut Source, source_xe: &mut Source, - scratch: &mut Scratch, - ) -> Self + scratch: &mut Scratch, + ) where + D: DataMut, + S0: LWESecretToRef + GetDistribution + LWEInfos, + S1: GLWESecretToRef + GLWEInfos + GetDistribution; +} + +impl CircuitBootstrappingKey, BRA> { + pub fn alloc_from_infos(infos: &A) -> Self where - INFOS: CircuitBootstrappingKeyInfos, - DLwe: DataRef, - DGlwe: DataRef; + A: CircuitBootstrappingKeyInfos, + BlindRotationKey, BRA>: BlindRotationKeyFactory, + { + let atk_infos: &GLWEAutomorphismKeyLayout = &infos.atk_infos(); + let brk_infos: &BlindRotationKeyLayout = &infos.brk_infos(); + let trk_infos: &GLWETensorKeyLayout = &infos.tsk_infos(); + let gal_els: Vec = trace_galois_elements(atk_infos.log_n(), 2 * atk_infos.n().as_usize() as i64); + + Self { + brk: , BRA> as BlindRotationKeyFactory>::blind_rotation_key_alloc(brk_infos), + atk: gal_els + .iter() + .map(|&gal_el| { + let key = GLWEAutomorphismKey::alloc_from_infos(atk_infos); + (gal_el, key) + }) + .collect(), + tsk: GLWETensorKey::alloc_from_infos(trk_infos), + } + } } pub struct CircuitBootstrappingKey { pub(crate) brk: BlindRotationKey, - pub(crate) tsk: GGLWETensorKey>, - pub(crate) atk: HashMap>>, + pub(crate) tsk: GLWETensorKey>, + pub(crate) atk: HashMap>>, } -impl CircuitBootstrappingKeyEncryptSk for CircuitBootstrappingKey, BRA> -where - BlindRotationKey, BRA>: BlindRotationKeyAlloc + BlindRotationKeyEncryptSk, - Module: SvpApplyDftToDft - + VecZnxIdftApplyTmpA - + VecZnxAddScalarInplace - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + VecZnxSwitchRing - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxAutomorphism, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeSvpPPol + TakeVecZnxBig, -{ - fn encrypt_sk( - module: &Module, - sk_lwe: &LWESecret, - sk_glwe: &GLWESecret, - cbt_infos: &INFOS, +impl CircuitBootstrappingKey { + pub fn encrypt_sk( + &mut self, + module: &M, + sk_lwe: &S0, + sk_glwe: &S1, source_xa: &mut Source, source_xe: &mut Source, - scratch: &mut Scratch, - ) -> Self - where - INFOS: CircuitBootstrappingKeyInfos, - DLwe: DataRef, - DGlwe: DataRef, + scratch: &mut Scratch, + ) where + S0: LWESecretToRef + GetDistribution + LWEInfos, + S1: GLWESecretToRef + GLWEInfos + GetDistribution, + M: CircuitBootstrappingKeyEncryptSk, { - assert_eq!(sk_lwe.n(), cbt_infos.brk_infos().n_lwe()); - assert_eq!(sk_glwe.n(), cbt_infos.brk_infos().n_glwe()); - assert_eq!(sk_glwe.n(), cbt_infos.atk_infos().n()); - assert_eq!(sk_glwe.n(), cbt_infos.tsk_infos().n()); + module.circuit_bootstrapping_key_encrypt_sk(self, sk_lwe, sk_glwe, source_xa, source_xe, scratch); + } +} - let atk_infos: GGLWEAutomorphismKeyLayout = cbt_infos.atk_infos(); - let brk_infos: BlindRotationKeyLayout = cbt_infos.brk_infos(); - let trk_infos: GGLWETensorKeyLayout = cbt_infos.tsk_infos(); +impl CircuitBootstrappingKeyEncryptSk for Module +where + Self: GLWETensorKeyEncryptSk + + BlindRotationKeyEncryptSk + + GLWEAutomorphismKeyEncryptSk + + GLWESecretPreparedFactory, + Scratch: ScratchTakeCore, +{ + fn circuit_bootstrapping_key_encrypt_sk( + &self, + res: &mut CircuitBootstrappingKey, + sk_lwe: &S0, + sk_glwe: &S1, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + D: DataMut, + S0: LWESecretToRef + GetDistribution + LWEInfos, + S1: GLWESecretToRef + GLWEInfos + GetDistribution, + { + let brk_infos: &BlindRotationKeyLayout = &res.brk_infos(); + let atk_infos: &GLWEAutomorphismKeyLayout = &res.atk_infos(); + let tsk_infos: &GLWETensorKeyLayout = &res.tsk_infos(); - let mut auto_keys: HashMap>> = HashMap::new(); - let gal_els: Vec = GLWECiphertext::trace_galois_elements(module); - gal_els.iter().for_each(|gal_el| { - let mut key: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&atk_infos); - key.encrypt_sk(module, *gal_el, sk_glwe, source_xa, source_xe, scratch); - auto_keys.insert(*gal_el, key); - }); + assert_eq!(sk_lwe.n(), brk_infos.n_lwe()); + assert_eq!(sk_glwe.n(), brk_infos.n_glwe()); + assert_eq!(sk_glwe.n(), atk_infos.n()); + assert_eq!(sk_glwe.n(), tsk_infos.n()); - let sk_glwe_prepared: GLWESecretPrepared, B> = sk_glwe.prepare_alloc(module, scratch); + assert!(sk_glwe.dist() != &Distribution::NONE); - let mut brk: BlindRotationKey, BRA> = BlindRotationKey::, BRA>::alloc(&brk_infos); - brk.encrypt_sk( - module, + for (p, atk) in res.atk.iter_mut() { + atk.encrypt_sk(self, *p, sk_glwe, source_xa, source_xe, scratch); + } + + let mut sk_glwe_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(self, brk_infos.rank()); + sk_glwe_prepared.prepare(self, sk_glwe); + + res.brk.encrypt_sk( + self, &sk_glwe_prepared, sk_lwe, source_xa, @@ -141,27 +158,15 @@ where scratch, ); - let mut tsk: GGLWETensorKey> = GGLWETensorKey::alloc(&trk_infos); - tsk.encrypt_sk(module, sk_glwe, source_xa, source_xe, scratch); - - Self { - brk, - atk: auto_keys, - tsk, - } + res.tsk + .encrypt_sk(self, sk_glwe, source_xa, source_xe, scratch); } } -pub struct CircuitBootstrappingKeyPrepared { - pub(crate) brk: BlindRotationKeyPrepared, - pub(crate) tsk: GGLWETensorKeyPrepared, B>, - pub(crate) atk: HashMap, B>>, -} - -impl CircuitBootstrappingKeyInfos for CircuitBootstrappingKeyPrepared { - fn atk_infos(&self) -> GGLWEAutomorphismKeyLayout { +impl CircuitBootstrappingKeyInfos for CircuitBootstrappingKey { + fn atk_infos(&self) -> GLWEAutomorphismKeyLayout { let (_, atk) = self.atk.iter().next().expect("atk is empty"); - GGLWEAutomorphismKeyLayout { + GLWEAutomorphismKeyLayout { n: atk.n(), base2k: atk.base2k(), k: atk.k(), @@ -182,8 +187,8 @@ impl CircuitBootstrappingKeyInfo } } - fn tsk_infos(&self) -> GGLWETensorKeyLayout { - GGLWETensorKeyLayout { + fn tsk_infos(&self) -> GLWETensorKeyLayout { + GLWETensorKeyLayout { n: self.tsk.n(), base2k: self.tsk.base2k(), k: self.tsk.k(), @@ -193,22 +198,3 @@ impl CircuitBootstrappingKeyInfo } } } - -impl PrepareAlloc, BRA, B>> - for CircuitBootstrappingKey -where - Module: VmpPMatAlloc + VmpPrepare, - BlindRotationKey: PrepareAlloc, BRA, B>>, - GGLWETensorKey: PrepareAlloc, B>>, - GGLWEAutomorphismKey: PrepareAlloc, B>>, -{ - fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> CircuitBootstrappingKeyPrepared, BRA, B> { - let brk: BlindRotationKeyPrepared, BRA, B> = self.brk.prepare_alloc(module, scratch); - let tsk: GGLWETensorKeyPrepared, B> = self.tsk.prepare_alloc(module, scratch); - let mut atk: HashMap, B>> = HashMap::new(); - for (key, value) in &self.atk { - atk.insert(*key, value.prepare_alloc(module, scratch)); - } - CircuitBootstrappingKeyPrepared { brk, tsk, atk } - } -} diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/key_compressed.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/key_compressed.rs new file mode 100644 index 0000000..223c193 --- /dev/null +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/key_compressed.rs @@ -0,0 +1,13 @@ +use std::collections::HashMap; + +use poulpy_core::layouts::{GLWEAutomorphismKeyCompressed, GLWETensorKeyCompressed}; +use poulpy_hal::layouts::Data; + +use crate::tfhe::blind_rotation::{BlindRotationAlgo, BlindRotationKeyCompressed}; + +#[allow(dead_code)] +pub struct CircuitBootstrappingKey { + pub(crate) brk: BlindRotationKeyCompressed, + pub(crate) tsk: GLWETensorKeyCompressed>, + pub(crate) atk: HashMap>>, +} diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/key_prepared.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/key_prepared.rs new file mode 100644 index 0000000..6adca70 --- /dev/null +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/key_prepared.rs @@ -0,0 +1,146 @@ +use poulpy_core::{ + layouts::{ + GGLWEInfos, GGSWInfos, GLWEAutomorphismKeyLayout, GLWEAutomorphismKeyPreparedFactory, GLWEInfos, GLWETensorKeyLayout, + GLWETensorKeyPreparedFactory, LWEInfos, + prepared::{GLWEAutomorphismKeyPrepared, GLWETensorKeyPrepared}, + }, + trace_galois_elements, +}; +use std::collections::HashMap; + +use poulpy_hal::layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}; + +use crate::tfhe::{ + blind_rotation::{ + BlindRotationAlgo, BlindRotationKeyInfos, BlindRotationKeyLayout, BlindRotationKeyPrepared, + BlindRotationKeyPreparedFactory, + }, + circuit_bootstrapping::{CircuitBootstrappingKey, CircuitBootstrappingKeyInfos}, +}; + +impl CircuitBootstrappingKeyPrepared, BRA, BE> { + pub fn alloc_from_infos(module: &M, infos: &A) -> CircuitBootstrappingKeyPrepared, BRA, BE> + where + A: CircuitBootstrappingKeyInfos, + M: CircuitBootstrappingKeyPreparedFactory, + { + module.circuit_bootstrapping_key_prepared_alloc_from_infos(infos) + } +} + +impl CircuitBootstrappingKeyPrepared { + pub fn prepare(&mut self, module: &M, other: &CircuitBootstrappingKey, scratch: &mut Scratch) + where + DR: DataRef, + M: CircuitBootstrappingKeyPreparedFactory, + { + module.circuit_bootstrapping_key_prepare(self, other, scratch); + } +} + +impl CircuitBootstrappingKeyPreparedFactory for Module where + Self: Sized + + BlindRotationKeyPreparedFactory + + GLWETensorKeyPreparedFactory + + GLWEAutomorphismKeyPreparedFactory +{ +} + +pub trait CircuitBootstrappingKeyPreparedFactory +where + Self: Sized + + BlindRotationKeyPreparedFactory + + GLWETensorKeyPreparedFactory + + GLWEAutomorphismKeyPreparedFactory, +{ + fn circuit_bootstrapping_key_prepared_alloc_from_infos( + &self, + infos: &A, + ) -> CircuitBootstrappingKeyPrepared, BRA, BE> + where + A: CircuitBootstrappingKeyInfos, + { + let atk_infos: &GLWEAutomorphismKeyLayout = &infos.atk_infos(); + let gal_els: Vec = trace_galois_elements(atk_infos.log_n(), 2 * atk_infos.n().as_usize() as i64); + + CircuitBootstrappingKeyPrepared { + brk: BlindRotationKeyPrepared::alloc(self, &infos.brk_infos()), + tsk: GLWETensorKeyPrepared::alloc_from_infos(self, &infos.tsk_infos()), + atk: gal_els + .iter() + .map(|&gal_el| { + let key = GLWEAutomorphismKeyPrepared::alloc_from_infos(self, atk_infos); + (gal_el, key) + }) + .collect(), + } + } + + fn circuit_bootstrapping_key_prepare_tmp_bytes(&self, infos: &A) -> usize + where + A: CircuitBootstrappingKeyInfos, + { + self.blind_rotation_key_prepare_tmp_bytes(&infos.brk_infos()) + .max(self.prepare_tensor_key_tmp_bytes(&infos.tsk_infos())) + .max(self.prepare_glwe_automorphism_key_tmp_bytes(&infos.atk_infos())) + } + + fn circuit_bootstrapping_key_prepare( + &self, + res: &mut CircuitBootstrappingKeyPrepared, + other: &CircuitBootstrappingKey, + scratch: &mut Scratch, + ) where + DM: DataMut, + DR: DataRef, + { + res.brk.prepare(self, &other.brk, scratch); + res.tsk.prepare(self, &other.tsk, scratch); + + for (k, a) in res.atk.iter_mut() { + a.prepare(self, other.atk.get(k).unwrap(), scratch); + } + } +} + +pub struct CircuitBootstrappingKeyPrepared { + pub(crate) brk: BlindRotationKeyPrepared, + pub(crate) tsk: GLWETensorKeyPrepared, B>, + pub(crate) atk: HashMap, B>>, +} + +impl CircuitBootstrappingKeyInfos for CircuitBootstrappingKeyPrepared { + fn atk_infos(&self) -> GLWEAutomorphismKeyLayout { + let (_, atk) = self.atk.iter().next().expect("atk is empty"); + GLWEAutomorphismKeyLayout { + n: atk.n(), + base2k: atk.base2k(), + k: atk.k(), + dnum: atk.dnum(), + dsize: atk.dsize(), + rank: atk.rank(), + } + } + + fn brk_infos(&self) -> BlindRotationKeyLayout { + BlindRotationKeyLayout { + n_glwe: self.brk.n_glwe(), + n_lwe: self.brk.n_lwe(), + base2k: self.brk.base2k(), + k: self.brk.k(), + dnum: self.brk.dnum(), + rank: self.brk.rank(), + } + } + + fn tsk_infos(&self) -> GLWETensorKeyLayout { + GLWETensorKeyLayout { + n: self.tsk.n(), + base2k: self.tsk.base2k(), + k: self.tsk.k(), + dnum: self.tsk.dnum(), + dsize: self.tsk.dsize(), + rank: self.tsk.rank(), + } + } +} diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/mod.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/mod.rs index 86dcf5e..e857cd0 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/mod.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/mod.rs @@ -1,34 +1,12 @@ mod circuit; mod key; +mod key_compressed; +mod key_prepared; + +#[cfg(test)] pub mod tests; pub use circuit::*; pub use key::*; - -use poulpy_core::layouts::{GGSWCiphertext, LWECiphertext}; - -use poulpy_hal::layouts::{Backend, DataMut, DataRef, Module, Scratch}; - -pub trait CirtuitBootstrappingExecute { - fn execute_to_constant( - &self, - module: &Module, - res: &mut GGSWCiphertext, - lwe: &LWECiphertext, - log_domain: usize, - extension_factor: usize, - scratch: &mut Scratch, - ); - - #[allow(clippy::too_many_arguments)] - fn execute_to_exponent( - &self, - module: &Module, - log_gap_out: usize, - res: &mut GGSWCiphertext, - lwe: &LWECiphertext, - log_domain: usize, - extension_factor: usize, - scratch: &mut Scratch, - ); -} +// pub use key_compressed::*; +pub use key_prepared::*; diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs index 40f5448..59223ea 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs @@ -1,110 +1,49 @@ use std::time::Instant; use poulpy_hal::{ - api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, - SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, - VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigAutomorphismInplace, - VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy, VecZnxDftAddInplace, - VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, VecZnxFillUniform, VecZnxIdftApplyConsume, - VecZnxIdftApplyTmpA, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, - VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, - VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, ZnAddNormal, ZnFillUniform, - ZnNormalizeInplace, - }, - layouts::{Backend, Module, ScalarZnx, ScratchOwned, ZnxView, ZnxViewMut}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeMatZnxImpl, TakeScalarZnxImpl, TakeSliceImpl, - TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, - }, + api::{ModuleN, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxRotateInplace}, + layouts::{Backend, ScalarZnx, Scratch, ScratchOwned, ZnxView, ZnxViewMut}, source::Source, }; use crate::tfhe::{ - blind_rotation::{ - BlincRotationExecute, BlindRotationAlgo, BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyEncryptSk, - BlindRotationKeyLayout, BlindRotationKeyPrepared, - }, + blind_rotation::{BlindRotationAlgo, BlindRotationKey, BlindRotationKeyFactory, BlindRotationKeyLayout}, circuit_bootstrapping::{ CircuitBootstrappingKey, CircuitBootstrappingKeyEncryptSk, CircuitBootstrappingKeyLayout, - CircuitBootstrappingKeyPrepared, CirtuitBootstrappingExecute, + CircuitBootstrappingKeyPrepared, CircuitBootstrappingKeyPreparedFactory, CirtuitBootstrappingExecute, + }, +}; + +use poulpy_core::{ + GGSWNoise, GLWEDecrypt, GLWEEncryptSk, GLWEExternalProduct, LWEEncryptSk, ScratchTakeCore, + layouts::{ + Dsize, GGSWLayout, GGSWPreparedFactory, GLWEAutomorphismKeyLayout, GLWESecretPreparedFactory, GLWETensorKeyLayout, + LWELayout, }, }; use poulpy_core::layouts::{ - Dsize, GGLWEAutomorphismKeyLayout, GGLWETensorKeyLayout, GGSWCiphertextLayout, LWECiphertextLayout, prepared::PrepareAlloc, + GGSW, GLWE, GLWEPlaintext, GLWESecret, LWE, LWEPlaintext, LWESecret, + prepared::{GGSWPrepared, GLWESecretPrepared}, }; -use poulpy_core::layouts::{ - GGSWCiphertext, GLWECiphertext, GLWEPlaintext, GLWESecret, LWECiphertext, LWEPlaintext, LWESecret, - prepared::{GGSWCiphertextPrepared, GLWESecretPrepared}, -}; - -pub fn test_circuit_bootstrapping_to_exponent(module: &Module) +pub fn test_circuit_bootstrapping_to_exponent(module: &M) where - Module: VecZnxFillUniform - + VecZnxAddNormal - + VecZnxNormalizeInplace - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalize - + VecZnxSub - + VecZnxAddScalarInplace - + VecZnxAutomorphism - + VecZnxSwitchRing - + VecZnxBigAllocBytes - + VecZnxIdftApplyTmpA - + SvpApplyDftToDft - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxBigAlloc - + VecZnxDftAlloc - + VecZnxBigNormalizeTmpBytes - + VmpPMatAlloc - + VmpPrepare - + SvpPrepare - + SvpPPolAlloc - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + SvpPPolAllocBytes - + VecZnxRotateInplace - + VecZnxBigAutomorphismInplace - + VecZnxRshInplace - + VecZnxDftCopy - + VecZnxNegateInplace - + VecZnxCopy - + VecZnxAutomorphismInplace - + VecZnxBigSubSmallNegateInplace - + VecZnxRotateInplaceTmpBytes - + VecZnxBigAllocBytes - + VecZnxDftAddInplace - + VecZnxRotate - + ZnFillUniform - + ZnAddNormal - + ZnNormalizeInplace, - B: Backend - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + TakeVecZnxDftImpl - + ScratchAvailableImpl - + TakeVecZnxImpl - + TakeScalarZnxImpl - + TakeSvpPPolImpl - + TakeVecZnxBigImpl - + TakeVecZnxDftSliceImpl - + TakeMatZnxImpl - + TakeVecZnxSliceImpl - + TakeSliceImpl, - BlindRotationKey, BRA>: PrepareAlloc, BRA, B>>, - BlindRotationKeyPrepared, BRA, B>: BlincRotationExecute, - BlindRotationKey, BRA>: BlindRotationKeyAlloc + BlindRotationKeyEncryptSk, + M: ModuleN + + GLWESecretPreparedFactory + + GLWEExternalProduct + + GLWEDecrypt + + LWEEncryptSk + + CircuitBootstrappingKeyEncryptSk + + CircuitBootstrappingKeyPreparedFactory + + CirtuitBootstrappingExecute + + GGSWPreparedFactory + + GGSWNoise + + GLWEEncryptSk + + VecZnxRotateInplace, + BlindRotationKey, BRA>: BlindRotationKeyFactory, // TODO find a way to remove this bound or move it to CBT KEY + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchTakeCore, { let n_glwe: usize = module.n(); let base2k: usize = 17; @@ -128,7 +67,7 @@ where let k_ggsw_res: usize = 4 * base2k; let rows_ggsw_res: usize = 2; - let lwe_infos: LWECiphertextLayout = LWECiphertextLayout { + let lwe_infos: LWELayout = LWELayout { n: n_lwe.into(), k: k_lwe_ct.into(), base2k: base2k.into(), @@ -143,7 +82,7 @@ where dnum: rows_brk.into(), rank: rank.into(), }, - layout_atk: GGLWEAutomorphismKeyLayout { + layout_atk: GLWEAutomorphismKeyLayout { n: n_glwe.into(), base2k: base2k.into(), k: k_atk.into(), @@ -151,7 +90,7 @@ where rank: rank.into(), dsize: Dsize(1), }, - layout_tsk: GGLWETensorKeyLayout { + layout_tsk: GLWETensorKeyLayout { n: n_glwe.into(), base2k: base2k.into(), k: k_tsk.into(), @@ -161,7 +100,7 @@ where }, }; - let ggsw_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + let ggsw_infos: GGSWLayout = GGSWLayout { n: n_glwe.into(), base2k: base2k.into(), k: k_ggsw_res.into(), @@ -170,7 +109,7 @@ where rank: rank.into(), }; - let mut scratch: ScratchOwned = ScratchOwned::alloc(1 << 23); + let mut scratch: ScratchOwned = ScratchOwned::alloc(1 << 23); let mut source_xs: Source = Source::new([1u8; 32]); let mut source_xa: Source = Source::new([1u8; 32]); @@ -179,38 +118,44 @@ where let mut sk_lwe: LWESecret> = LWESecret::alloc(n_lwe.into()); sk_lwe.fill_binary_block(block_size, &mut source_xs); - let mut sk_glwe: GLWESecret> = GLWESecret::alloc_with(n_glwe.into(), rank.into()); + let mut sk_glwe: GLWESecret> = GLWESecret::alloc(n_glwe.into(), rank.into()); sk_glwe.fill_ternary_prob(0.5, &mut source_xs); - let sk_glwe_prepared: GLWESecretPrepared, B> = sk_glwe.prepare_alloc(module, scratch.borrow()); + let mut sk_glwe_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank.into()); + sk_glwe_prepared.prepare(module, &sk_glwe); let data: i64 = 1; - let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc_with(base2k.into(), k_lwe_pt.into()); + let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc(base2k.into(), k_lwe_pt.into()); pt_lwe.encode_i64(data, (k_lwe_pt + 1).into()); println!("pt_lwe: {pt_lwe}"); - let mut ct_lwe: LWECiphertext> = LWECiphertext::alloc(&lwe_infos); + let mut ct_lwe: LWE> = LWE::alloc_from_infos(&lwe_infos); ct_lwe.encrypt_sk(module, &pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe); let now: Instant = Instant::now(); - let cbt_key: CircuitBootstrappingKey, BRA> = CircuitBootstrappingKey::encrypt_sk( + let mut cbt_key: CircuitBootstrappingKey, BRA> = CircuitBootstrappingKey::alloc_from_infos(&cbt_infos); + println!("CBT-ALLOC: {} ms", now.elapsed().as_millis()); + + let now: Instant = Instant::now(); + cbt_key.encrypt_sk( module, &sk_lwe, &sk_glwe, - &cbt_infos, &mut source_xa, &mut source_xe, scratch.borrow(), ); - println!("CBT-KGEN: {} ms", now.elapsed().as_millis()); + println!("CBT-ENCRYPT: {} ms", now.elapsed().as_millis()); - let mut res: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_infos); + let mut res: GGSW> = GGSW::alloc_from_infos(&ggsw_infos); let log_gap_out = 1; - let cbt_prepared: CircuitBootstrappingKeyPrepared, BRA, B> = cbt_key.prepare_alloc(module, scratch.borrow()); + let mut cbt_prepared: CircuitBootstrappingKeyPrepared, BRA, BE> = + CircuitBootstrappingKeyPrepared::alloc_from_infos(module, &cbt_infos); + cbt_prepared.prepare(module, &cbt_key, scratch.borrow()); let now: Instant = Instant::now(); cbt_prepared.execute_to_exponent( @@ -236,8 +181,8 @@ where res.print_noise(module, &sk_glwe_prepared, &pt_ggsw); - let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&ggsw_infos); - let mut pt_glwe: GLWEPlaintext> = GLWEPlaintext::alloc(&ggsw_infos); + let mut ct_glwe: GLWE> = GLWE::alloc_from_infos(&ggsw_infos); + let mut pt_glwe: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&ggsw_infos); pt_glwe.data.at_mut(0, 0)[0] = 1 << (base2k - 2); ct_glwe.encrypt_sk( @@ -249,11 +194,12 @@ where scratch.borrow(), ); - let res_prepared: GGSWCiphertextPrepared, B> = res.prepare_alloc(module, scratch.borrow()); + let mut res_prepared: GGSWPrepared, BE> = GGSWPrepared::alloc_from_infos(module, &res); + res_prepared.prepare(module, &res, scratch.borrow()); ct_glwe.external_product_inplace(module, &res_prepared, scratch.borrow()); - let mut pt_res: GLWEPlaintext> = GLWEPlaintext::alloc(&ggsw_infos); + let mut pt_res: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&ggsw_infos); ct_glwe.decrypt(module, &mut pt_res, &sk_glwe_prepared, scratch.borrow()); // Parameters are set such that the first limb should be noiseless. @@ -262,76 +208,28 @@ where assert_eq!(pt_res.data.at(0, 0), pt_want); } -pub fn test_circuit_bootstrapping_to_constant(module: &Module) +pub fn test_circuit_bootstrapping_to_constant(module: &M) where - Module: VecZnxFillUniform - + VecZnxAddNormal - + VecZnxNormalizeInplace - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalize - + VecZnxSub - + VecZnxAddScalarInplace - + VecZnxAutomorphism - + VecZnxSwitchRing - + VecZnxBigAllocBytes - + VecZnxIdftApplyTmpA - + SvpApplyDftToDft - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxBigAlloc - + VecZnxDftAlloc - + VecZnxBigNormalizeTmpBytes - + VmpPMatAlloc - + VmpPrepare - + SvpPrepare - + SvpPPolAlloc - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + SvpPPolAllocBytes - + VecZnxRotateInplace - + VecZnxBigAutomorphismInplace - + VecZnxRotateInplaceTmpBytes - + VecZnxRshInplace - + VecZnxDftCopy - + VecZnxNegateInplace - + VecZnxCopy - + VecZnxAutomorphismInplace - + VecZnxBigSubSmallNegateInplace - + VecZnxBigAllocBytes - + VecZnxDftAddInplace - + VecZnxRotate - + ZnFillUniform - + ZnAddNormal - + ZnNormalizeInplace, - B: Backend - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + TakeVecZnxDftImpl - + ScratchAvailableImpl - + TakeVecZnxImpl - + TakeScalarZnxImpl - + TakeSvpPPolImpl - + TakeVecZnxBigImpl - + TakeVecZnxDftSliceImpl - + TakeMatZnxImpl - + TakeVecZnxSliceImpl - + TakeSliceImpl, - BlindRotationKey, BRA>: PrepareAlloc, BRA, B>>, - BlindRotationKeyPrepared, BRA, B>: BlincRotationExecute, - BlindRotationKey, BRA>: BlindRotationKeyAlloc + BlindRotationKeyEncryptSk, + M: ModuleN + + GLWESecretPreparedFactory + + GLWEExternalProduct + + GLWEDecrypt + + LWEEncryptSk + + CircuitBootstrappingKeyEncryptSk + + CircuitBootstrappingKeyPreparedFactory + + CirtuitBootstrappingExecute + + GGSWPreparedFactory + + GGSWNoise + + GLWEEncryptSk + + VecZnxRotateInplace, + BlindRotationKey, BRA>: BlindRotationKeyFactory, // TODO find a way to remove this bound or move it to CBT KEY + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchTakeCore, { let n_glwe: usize = module.n(); let base2k: usize = 14; let extension_factor: usize = 1; - let rank: usize = 2; + let rank: usize = 1; let n_lwe: usize = 77; let k_lwe_pt: usize = 1; @@ -350,7 +248,7 @@ where let k_ggsw_res: usize = 4 * base2k; let rows_ggsw_res: usize = 3; - let lwe_infos: LWECiphertextLayout = LWECiphertextLayout { + let lwe_infos: LWELayout = LWELayout { n: n_lwe.into(), k: k_lwe_ct.into(), base2k: base2k.into(), @@ -365,7 +263,7 @@ where dnum: rows_brk.into(), rank: rank.into(), }, - layout_atk: GGLWEAutomorphismKeyLayout { + layout_atk: GLWEAutomorphismKeyLayout { n: n_glwe.into(), base2k: base2k.into(), k: k_atk.into(), @@ -373,7 +271,7 @@ where rank: rank.into(), dsize: Dsize(1), }, - layout_tsk: GGLWETensorKeyLayout { + layout_tsk: GLWETensorKeyLayout { n: n_glwe.into(), base2k: base2k.into(), k: k_tsk.into(), @@ -383,7 +281,7 @@ where }, }; - let ggsw_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + let ggsw_infos: GGSWLayout = GGSWLayout { n: n_glwe.into(), base2k: base2k.into(), k: k_ggsw_res.into(), @@ -392,7 +290,7 @@ where rank: rank.into(), }; - let mut scratch: ScratchOwned = ScratchOwned::alloc(1 << 23); + let mut scratch: ScratchOwned = ScratchOwned::alloc(1 << 23); let mut source_xs: Source = Source::new([1u8; 32]); let mut source_xa: Source = Source::new([1u8; 32]); @@ -401,36 +299,42 @@ where let mut sk_lwe: LWESecret> = LWESecret::alloc(n_lwe.into()); sk_lwe.fill_binary_block(block_size, &mut source_xs); - let mut sk_glwe: GLWESecret> = GLWESecret::alloc_with(n_glwe.into(), rank.into()); + let mut sk_glwe: GLWESecret> = GLWESecret::alloc(n_glwe.into(), rank.into()); sk_glwe.fill_ternary_prob(0.5, &mut source_xs); - let sk_glwe_prepared: GLWESecretPrepared, B> = sk_glwe.prepare_alloc(module, scratch.borrow()); + let mut sk_glwe_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank.into()); + sk_glwe_prepared.prepare(module, &sk_glwe); let data: i64 = 1; - let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc_with(base2k.into(), k_lwe_pt.into()); + let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc(base2k.into(), k_lwe_pt.into()); pt_lwe.encode_i64(data, (k_lwe_pt + 1).into()); println!("pt_lwe: {pt_lwe}"); - let mut ct_lwe: LWECiphertext> = LWECiphertext::alloc(&lwe_infos); + let mut ct_lwe: LWE> = LWE::alloc_from_infos(&lwe_infos); ct_lwe.encrypt_sk(module, &pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe); let now: Instant = Instant::now(); - let cbt_key: CircuitBootstrappingKey, BRA> = CircuitBootstrappingKey::encrypt_sk( + let mut cbt_key: CircuitBootstrappingKey, BRA> = CircuitBootstrappingKey::alloc_from_infos(&cbt_infos); + println!("CBT-ALLOC: {} ms", now.elapsed().as_millis()); + + let now: Instant = Instant::now(); + cbt_key.encrypt_sk( module, &sk_lwe, &sk_glwe, - &cbt_infos, &mut source_xa, &mut source_xe, scratch.borrow(), ); - println!("CBT-KGEN: {} ms", now.elapsed().as_millis()); + println!("CBT-ENCRYPT: {} ms", now.elapsed().as_millis()); - let mut res: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_infos); + let mut res: GGSW> = GGSW::alloc_from_infos(&ggsw_infos); - let cbt_prepared: CircuitBootstrappingKeyPrepared, BRA, B> = cbt_key.prepare_alloc(module, scratch.borrow()); + let mut cbt_prepared: CircuitBootstrappingKeyPrepared, BRA, BE> = + CircuitBootstrappingKeyPrepared::alloc_from_infos(module, &cbt_infos); + cbt_prepared.prepare(module, &cbt_key, scratch.borrow()); let now: Instant = Instant::now(); cbt_prepared.execute_to_constant( @@ -449,8 +353,8 @@ where res.print_noise(module, &sk_glwe_prepared, &pt_ggsw); - let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&ggsw_infos); - let mut pt_glwe: GLWEPlaintext> = GLWEPlaintext::alloc(&ggsw_infos); + let mut ct_glwe: GLWE> = GLWE::alloc_from_infos(&ggsw_infos); + let mut pt_glwe: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&ggsw_infos); pt_glwe.data.at_mut(0, 0)[0] = 1 << (base2k - k_lwe_pt - 1); ct_glwe.encrypt_sk( @@ -462,11 +366,12 @@ where scratch.borrow(), ); - let res_prepared: GGSWCiphertextPrepared, B> = res.prepare_alloc(module, scratch.borrow()); + let mut res_prepared: GGSWPrepared, BE> = GGSWPrepared::alloc_from_infos(module, &res); + res_prepared.prepare(module, &res, scratch.borrow()); ct_glwe.external_product_inplace(module, &res_prepared, scratch.borrow()); - let mut pt_res: GLWEPlaintext> = GLWEPlaintext::alloc(&ggsw_infos); + let mut pt_res: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&ggsw_infos); ct_glwe.decrypt(module, &mut pt_res, &sk_glwe_prepared, scratch.borrow()); // Parameters are set such that the first limb should be noiseless. diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/fft64.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/fft64.rs new file mode 100644 index 0000000..a1fdab2 --- /dev/null +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/fft64.rs @@ -0,0 +1,21 @@ +use poulpy_backend::cpu_fft64_ref::FFT64Ref; +use poulpy_hal::{api::ModuleNew, layouts::Module}; + +use crate::tfhe::{ + blind_rotation::CGGI, + circuit_bootstrapping::tests::circuit_bootstrapping::{ + test_circuit_bootstrapping_to_constant, test_circuit_bootstrapping_to_exponent, + }, +}; + +#[test] +fn test_to_constant_cggi() { + let module: Module = Module::::new(256); + test_circuit_bootstrapping_to_constant::(&module); +} + +#[test] +fn test_to_exponent_cggi() { + let module: Module = Module::::new(256); + test_circuit_bootstrapping_to_exponent::(&module); +} diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/implementation/cpu_spqlios/fft64.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/implementation/cpu_spqlios/fft64.rs deleted file mode 100644 index 3661f81..0000000 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/implementation/cpu_spqlios/fft64.rs +++ /dev/null @@ -1,21 +0,0 @@ -use poulpy_backend::cpu_spqlios::FFT64Spqlios; -use poulpy_hal::{api::ModuleNew, layouts::Module}; - -use crate::tfhe::{ - blind_rotation::CGGI, - circuit_bootstrapping::tests::circuit_bootstrapping::{ - test_circuit_bootstrapping_to_constant, test_circuit_bootstrapping_to_exponent, - }, -}; - -#[test] -fn test_to_constant() { - let module: Module = Module::::new(256); - test_circuit_bootstrapping_to_constant::(&module); -} - -#[test] -fn test_to_exponent() { - let module: Module = Module::::new(256); - test_circuit_bootstrapping_to_exponent::(&module); -} diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/implementation/cpu_spqlios/mod.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/implementation/cpu_spqlios/mod.rs deleted file mode 100644 index aebaafb..0000000 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/implementation/cpu_spqlios/mod.rs +++ /dev/null @@ -1 +0,0 @@ -mod fft64; diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/implementation/mod.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/implementation/mod.rs deleted file mode 100644 index f2bc1d4..0000000 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/implementation/mod.rs +++ /dev/null @@ -1 +0,0 @@ -mod cpu_spqlios; diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/mod.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/mod.rs index f9bc7d9..893a6be 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/mod.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/mod.rs @@ -1,4 +1,3 @@ pub mod circuit_bootstrapping; -#[cfg(test)] -mod implementation; +mod fft64;