diff --git a/backend/examples/rlwe_encrypt.rs b/backend/examples/rlwe_encrypt.rs index 56eedf8..58f2183 100644 --- a/backend/examples/rlwe_encrypt.rs +++ b/backend/examples/rlwe_encrypt.rs @@ -1,10 +1,10 @@ use backend::{ hal::{ api::{ - ModuleNew, ScalarZnxAlloc, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPrepare, - VecZnxAddNormal, VecZnxAlloc, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, VecZnxDecodeVeci64, VecZnxDftAlloc, VecZnxDftFromVecZnx, - VecZnxDftToVecZnxBigTmpA, VecZnxEncodeVeci64, VecZnxFillUniform, VecZnxNormalizeInplace, ZnxInfos, + ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPrepare, VecZnxAddNormal, + VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, + VecZnxDecodeVeci64, VecZnxDftAlloc, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigTmpA, VecZnxEncodeVeci64, + VecZnxFillUniform, VecZnxNormalizeInplace, ZnxInfos, }, layouts::{Module, ScalarZnx, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft}, }, @@ -27,17 +27,18 @@ fn main() { let mut source: Source = Source::new(seed); // s <- Z_{-1, 0, 1}[X]/(X^{N}+1) - let mut s: ScalarZnx> = module.scalar_znx_alloc(1); + let mut s: ScalarZnx> = ScalarZnx::alloc(module.n(), 1); s.fill_ternary_prob(0, 0.5, &mut source); // Buffer to store s in the DFT domain - let mut s_dft: SvpPPol, FFT64> = module.svp_ppol_alloc(s.cols()); + let mut s_dft: SvpPPol, FFT64> = module.svp_ppol_alloc(n, s.cols()); // s_dft <- DFT(s) module.svp_prepare(&mut s_dft, 0, &s, 0); // Allocates a VecZnx with two columns: ct=(0, 0) - let mut ct: VecZnx> = module.vec_znx_alloc( + let mut ct: VecZnx> = VecZnx::alloc( + module.n(), 2, // Number of columns ct_size, // Number of small poly per column ); @@ -45,7 +46,7 @@ fn main() { // Fill the second column with random values: ct = (0, a) module.vec_znx_fill_uniform(basek, &mut ct, 1, ct_size * basek, &mut source); - let mut buf_dft: VecZnxDft, FFT64> = module.vec_znx_dft_alloc(1, ct_size); + let mut buf_dft: VecZnxDft, FFT64> = module.vec_znx_dft_alloc(n, 1, ct_size); module.vec_znx_dft_from_vec_znx(1, 0, &mut buf_dft, 0, &ct, 1); @@ -60,11 +61,12 @@ fn main() { // Alias scratch space (VecZnxDft is always at least as big as VecZnxBig) // BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)) (not normalized) - let mut buf_big: VecZnxBig, FFT64> = module.vec_znx_big_alloc(1, ct_size); + let mut buf_big: VecZnxBig, FFT64> = module.vec_znx_big_alloc(n, 1, ct_size); module.vec_znx_dft_to_vec_znx_big_tmp_a(&mut buf_big, 0, &mut buf_dft, 0); // Creates a plaintext: VecZnx with 1 column - let mut m = module.vec_znx_alloc( + let mut m = VecZnx::alloc( + module.n(), 1, // Number of columns msg_size, // Number of small polynomials ); @@ -125,7 +127,7 @@ fn main() { module.vec_znx_big_add_small_inplace(&mut buf_big, 0, &ct, 0); // m + e <- BIG(ct[1] * s + ct[0]) - let mut res = module.vec_znx_alloc(1, ct_size); + let mut res = VecZnx::alloc(module.n(), 1, ct_size); module.vec_znx_big_normalize(basek, &mut res, 0, &buf_big, 0, scratch.borrow()); // have = m * 2^{log_scale} + e diff --git a/backend/src/hal/api/mat_znx.rs b/backend/src/hal/api/mat_znx.rs deleted file mode 100644 index 3579e7d..0000000 --- a/backend/src/hal/api/mat_znx.rs +++ /dev/null @@ -1,17 +0,0 @@ -use crate::hal::layouts::MatZnxOwned; - -/// Allocates as [crate::hal::layouts::MatZnx]. -pub trait MatZnxAlloc { - fn mat_znx_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxOwned; -} - -/// Returns the size in bytes to allocate a [crate::hal::layouts::MatZnx]. -pub trait MatZnxAllocBytes { - fn mat_znx_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; -} - -/// Consume a vector of bytes into a [crate::hal::layouts::MatZnx]. -/// User must ensure that bytes is memory aligned and that it length is equal to [MatZnxAllocBytes]. -pub trait MatZnxFromBytes { - fn mat_znx_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec) -> MatZnxOwned; -} diff --git a/backend/src/hal/api/mod.rs b/backend/src/hal/api/mod.rs index cb806c9..e518fcd 100644 --- a/backend/src/hal/api/mod.rs +++ b/backend/src/hal/api/mod.rs @@ -1,6 +1,4 @@ -mod mat_znx; mod module; -mod scalar_znx; mod scratch; mod svp_ppol; mod vec_znx; @@ -9,9 +7,7 @@ mod vec_znx_dft; mod vmp_pmat; mod znx_base; -pub use mat_znx::*; pub use module::*; -pub use scalar_znx::*; pub use scratch::*; pub use svp_ppol::*; pub use vec_znx::*; diff --git a/backend/src/hal/api/scalar_znx.rs b/backend/src/hal/api/scalar_znx.rs deleted file mode 100644 index 6343a7e..0000000 --- a/backend/src/hal/api/scalar_znx.rs +++ /dev/null @@ -1,17 +0,0 @@ -use crate::hal::layouts::ScalarZnxOwned; - -/// Allocates as [crate::hal::layouts::ScalarZnx]. -pub trait ScalarZnxAlloc { - fn scalar_znx_alloc(&self, cols: usize) -> ScalarZnxOwned; -} - -/// Returns the size in bytes to allocate a [crate::hal::layouts::ScalarZnx]. -pub trait ScalarZnxAllocBytes { - fn scalar_znx_alloc_bytes(&self, cols: usize) -> usize; -} - -/// Consume a vector of bytes into a [crate::hal::layouts::ScalarZnx]. -/// User must ensure that bytes is memory aligned and that it length is equal to [ScalarZnxAllocBytes]. -pub trait ScalarZnxFromBytes { - fn scalar_znx_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxOwned; -} diff --git a/backend/src/hal/api/scratch.rs b/backend/src/hal/api/scratch.rs index 12b856f..812ed30 100644 --- a/backend/src/hal/api/scratch.rs +++ b/backend/src/hal/api/scratch.rs @@ -1,4 +1,4 @@ -use crate::hal::layouts::{Backend, MatZnx, Module, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}; +use crate::hal::layouts::{Backend, MatZnx, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}; /// Allocates a new [crate::hal::layouts::ScratchOwned] of `size` aligned bytes. pub trait ScratchOwnedAlloc { @@ -27,44 +27,38 @@ pub trait TakeSlice { /// Take a slice of bytes from a [Scratch], wraps it into a [ScalarZnx] and returns it /// as well as a new [Scratch] minus the taken array of bytes. -pub trait TakeScalarZnx { - fn take_scalar_znx(&mut self, module: &Module, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self); +pub trait TakeScalarZnx { + fn take_scalar_znx(&mut self, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self); } /// Take a slice of bytes from a [Scratch], wraps it into a [SvpPPol] and returns it /// as well as a new [Scratch] minus the taken array of bytes. pub trait TakeSvpPPol { - fn take_svp_ppol(&mut self, module: &Module, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self); + fn take_svp_ppol(&mut self, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self); } /// Take a slice of bytes from a [Scratch], wraps it into a [VecZnx] and returns it /// as well as a new [Scratch] minus the taken array of bytes. -pub trait TakeVecZnx { - fn take_vec_znx(&mut self, module: &Module, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self); +pub trait TakeVecZnx { + fn take_vec_znx(&mut self, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self); } /// Take a slice of bytes from a [Scratch], slices it into a vector of [VecZnx] aand returns it /// as well as a new [Scratch] minus the taken array of bytes. -pub trait TakeVecZnxSlice { - fn take_vec_znx_slice( - &mut self, - len: usize, - module: &Module, - cols: usize, - size: usize, - ) -> (Vec>, &mut Self); +pub trait TakeVecZnxSlice { + fn take_vec_znx_slice(&mut self, len: usize, n: usize, cols: usize, size: usize) -> (Vec>, &mut Self); } /// Take a slice of bytes from a [Scratch], wraps it into a [VecZnxBig] and returns it /// as well as a new [Scratch] minus the taken array of bytes. pub trait TakeVecZnxBig { - fn take_vec_znx_big(&mut self, module: &Module, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self); + fn take_vec_znx_big(&mut self, n: usize, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self); } /// Take a slice of bytes from a [Scratch], wraps it into a [VecZnxDft] and returns it /// as well as a new [Scratch] minus the taken array of bytes. pub trait TakeVecZnxDft { - fn take_vec_znx_dft(&mut self, module: &Module, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self); + fn take_vec_znx_dft(&mut self, n: usize, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self); } /// Take a slice of bytes from a [Scratch], slices it into a vector of [VecZnxDft] and returns it @@ -73,7 +67,7 @@ pub trait TakeVecZnxDftSlice { fn take_vec_znx_dft_slice( &mut self, len: usize, - module: &Module, + n: usize, cols: usize, size: usize, ) -> (Vec>, &mut Self); @@ -84,7 +78,7 @@ pub trait TakeVecZnxDftSlice { pub trait TakeVmpPMat { fn take_vmp_pmat( &mut self, - module: &Module, + n: usize, rows: usize, cols_in: usize, cols_out: usize, @@ -94,10 +88,10 @@ pub trait TakeVmpPMat { /// Take a slice of bytes from a [Scratch], wraps it into a [MatZnx] and returns it /// as well as a new [Scratch] minus the taken array of bytes. -pub trait TakeMatZnx { +pub trait TakeMatZnx { fn take_mat_znx( &mut self, - module: &Module, + n: usize, rows: usize, cols_in: usize, cols_out: usize, diff --git a/backend/src/hal/api/svp_ppol.rs b/backend/src/hal/api/svp_ppol.rs index f500923..05aa189 100644 --- a/backend/src/hal/api/svp_ppol.rs +++ b/backend/src/hal/api/svp_ppol.rs @@ -2,18 +2,18 @@ use crate::hal::layouts::{Backend, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, S /// Allocates as [crate::hal::layouts::SvpPPol]. pub trait SvpPPolAlloc { - fn svp_ppol_alloc(&self, cols: usize) -> SvpPPolOwned; + fn svp_ppol_alloc(&self, n: usize, cols: usize) -> SvpPPolOwned; } /// Returns the size in bytes to allocate a [crate::hal::layouts::SvpPPol]. pub trait SvpPPolAllocBytes { - fn svp_ppol_alloc_bytes(&self, cols: usize) -> usize; + fn svp_ppol_alloc_bytes(&self, n: usize, cols: usize) -> usize; } /// Consume a vector of bytes into a [crate::hal::layouts::MatZnx]. /// User must ensure that bytes is memory aligned and that it length is equal to [SvpPPolAllocBytes]. pub trait SvpPPolFromBytes { - fn svp_ppol_from_bytes(&self, cols: usize, bytes: Vec) -> SvpPPolOwned; + fn svp_ppol_from_bytes(&self, n: usize, cols: usize, bytes: Vec) -> SvpPPolOwned; } /// Prepare a [crate::hal::layouts::ScalarZnx] into an [crate::hal::layouts::SvpPPol]. diff --git a/backend/src/hal/api/vec_znx.rs b/backend/src/hal/api/vec_znx.rs index 413b90b..7f6ce6b 100644 --- a/backend/src/hal/api/vec_znx.rs +++ b/backend/src/hal/api/vec_znx.rs @@ -2,33 +2,7 @@ use rand_distr::Distribution; use rug::Float; use sampling::source::Source; -use crate::hal::layouts::{Backend, ScalarZnxToRef, Scratch, VecZnxOwned, VecZnxToMut, VecZnxToRef}; - -pub trait VecZnxAlloc { - /// Allocates a new [crate::hal::layouts::VecZnx]. - /// - /// # Arguments - /// - /// * `cols`: the number of polynomials. - /// * `size`: the number small polynomials per column. - fn vec_znx_alloc(&self, cols: usize, size: usize) -> VecZnxOwned; -} - -pub trait VecZnxFromBytes { - /// Instantiates a new [crate::hal::layouts::VecZnx] from a slice of bytes. - /// The returned [crate::hal::layouts::VecZnx] takes ownership of the slice of bytes. - /// - /// # Arguments - /// - /// * `cols`: the number of polynomials. - /// * `size`: the number small polynomials per column. - fn vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxOwned; -} - -pub trait VecZnxAllocBytes { - /// Returns the number of bytes necessary to allocate a new [crate::hal::layouts::VecZnx]. - fn vec_znx_alloc_bytes(&self, cols: usize, size: usize) -> usize; -} +use crate::hal::layouts::{Backend, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef}; pub trait VecZnxNormalizeTmpBytes { /// Returns the minimum number of bytes necessary for normalization. diff --git a/backend/src/hal/api/vec_znx_big.rs b/backend/src/hal/api/vec_znx_big.rs index d32f67f..249e078 100644 --- a/backend/src/hal/api/vec_znx_big.rs +++ b/backend/src/hal/api/vec_znx_big.rs @@ -5,18 +5,18 @@ use crate::hal::layouts::{Backend, Scratch, VecZnxBigOwned, VecZnxBigToMut, VecZ /// Allocates as [crate::hal::layouts::VecZnxBig]. pub trait VecZnxBigAlloc { - fn vec_znx_big_alloc(&self, cols: usize, size: usize) -> VecZnxBigOwned; + fn vec_znx_big_alloc(&self, n: usize, cols: usize, size: usize) -> VecZnxBigOwned; } /// Returns the size in bytes to allocate a [crate::hal::layouts::VecZnxBig]. pub trait VecZnxBigAllocBytes { - fn vec_znx_big_alloc_bytes(&self, cols: usize, size: usize) -> usize; + fn vec_znx_big_alloc_bytes(&self, n: usize, cols: usize, size: usize) -> usize; } /// Consume a vector of bytes into a [crate::hal::layouts::VecZnxBig]. /// User must ensure that bytes is memory aligned and that it length is equal to [VecZnxBigAllocBytes]. pub trait VecZnxBigFromBytes { - fn vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxBigOwned; + fn vec_znx_big_from_bytes(&self, n: usize, cols: usize, size: usize, bytes: Vec) -> VecZnxBigOwned; } /// Add a discrete normal distribution on res. diff --git a/backend/src/hal/api/vec_znx_dft.rs b/backend/src/hal/api/vec_znx_dft.rs index 4aeac2f..8efe892 100644 --- a/backend/src/hal/api/vec_znx_dft.rs +++ b/backend/src/hal/api/vec_znx_dft.rs @@ -3,19 +3,19 @@ use crate::hal::layouts::{ }; pub trait VecZnxDftAlloc { - fn vec_znx_dft_alloc(&self, cols: usize, size: usize) -> VecZnxDftOwned; + fn vec_znx_dft_alloc(&self, n: usize, cols: usize, size: usize) -> VecZnxDftOwned; } pub trait VecZnxDftFromBytes { - fn vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxDftOwned; + fn vec_znx_dft_from_bytes(&self, n: usize, cols: usize, size: usize, bytes: Vec) -> VecZnxDftOwned; } pub trait VecZnxDftAllocBytes { - fn vec_znx_dft_alloc_bytes(&self, cols: usize, size: usize) -> usize; + fn vec_znx_dft_alloc_bytes(&self, n: usize, cols: usize, size: usize) -> usize; } pub trait VecZnxDftToVecZnxBigTmpBytes { - fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self) -> usize; + fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self, n: usize) -> usize; } pub trait VecZnxDftToVecZnxBig { diff --git a/backend/src/hal/api/vmp_pmat.rs b/backend/src/hal/api/vmp_pmat.rs index 8b0ade7..64fd8a5 100644 --- a/backend/src/hal/api/vmp_pmat.rs +++ b/backend/src/hal/api/vmp_pmat.rs @@ -3,19 +3,27 @@ use crate::hal::layouts::{ }; pub trait VmpPMatAlloc { - fn vmp_pmat_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned; + fn vmp_pmat_alloc(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned; } pub trait VmpPMatAllocBytes { - fn vmp_pmat_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; + fn vmp_pmat_alloc_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; } pub trait VmpPMatFromBytes { - fn vmp_pmat_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec) -> VmpPMatOwned; + fn vmp_pmat_from_bytes( + &self, + n: usize, + rows: usize, + cols_in: usize, + cols_out: usize, + size: usize, + bytes: Vec, + ) -> VmpPMatOwned; } pub trait VmpPrepareTmpBytes { - fn vmp_prepare_tmp_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; + fn vmp_prepare_tmp_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; } pub trait VmpPMatPrepare { @@ -28,6 +36,7 @@ pub trait VmpPMatPrepare { pub trait VmpApplyTmpBytes { fn vmp_apply_tmp_bytes( &self, + n: usize, res_size: usize, a_size: usize, b_rows: usize, @@ -72,6 +81,7 @@ pub trait VmpApply { pub trait VmpApplyAddTmpBytes { fn vmp_apply_add_tmp_bytes( &self, + n: usize, res_size: usize, a_size: usize, b_rows: usize, diff --git a/backend/src/hal/api/znx_base.rs b/backend/src/hal/api/znx_base.rs index 0af86e8..bd94512 100644 --- a/backend/src/hal/api/znx_base.rs +++ b/backend/src/hal/api/znx_base.rs @@ -113,3 +113,7 @@ where pub trait FillUniform { fn fill_uniform(&mut self, source: &mut Source); } + +pub trait Reset { + fn reset(&mut self); +} diff --git a/backend/src/hal/delegates/mat_znx.rs b/backend/src/hal/delegates/mat_znx.rs deleted file mode 100644 index 1f63cae..0000000 --- a/backend/src/hal/delegates/mat_znx.rs +++ /dev/null @@ -1,32 +0,0 @@ -use crate::hal::{ - api::{MatZnxAlloc, MatZnxAllocBytes, MatZnxFromBytes}, - layouts::{Backend, MatZnxOwned, Module}, - oep::{MatZnxAllocBytesImpl, MatZnxAllocImpl, MatZnxFromBytesImpl}, -}; - -impl MatZnxAlloc for Module -where - B: Backend + MatZnxAllocImpl, -{ - fn mat_znx_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxOwned { - B::mat_znx_alloc_impl(self, rows, cols_in, cols_out, size) - } -} - -impl MatZnxAllocBytes for Module -where - B: Backend + MatZnxAllocBytesImpl, -{ - fn mat_znx_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { - B::mat_znx_alloc_bytes_impl(self, rows, cols_in, cols_out, size) - } -} - -impl MatZnxFromBytes for Module -where - B: Backend + MatZnxFromBytesImpl, -{ - fn mat_znx_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec) -> MatZnxOwned { - B::mat_znx_from_bytes_impl(self, rows, cols_in, cols_out, size, bytes) - } -} diff --git a/backend/src/hal/delegates/mod.rs b/backend/src/hal/delegates/mod.rs index f02a59b..595a641 100644 --- a/backend/src/hal/delegates/mod.rs +++ b/backend/src/hal/delegates/mod.rs @@ -1,6 +1,4 @@ -mod mat_znx; mod module; -mod scalar_znx; mod scratch; mod svp_ppol; mod vec_znx; diff --git a/backend/src/hal/delegates/scalar_znx.rs b/backend/src/hal/delegates/scalar_znx.rs deleted file mode 100644 index 644362d..0000000 --- a/backend/src/hal/delegates/scalar_znx.rs +++ /dev/null @@ -1,23 +0,0 @@ -use crate::hal::{ - api::{ScalarZnxAlloc, ScalarZnxAllocBytes}, - layouts::{Backend, Module, ScalarZnxOwned}, - oep::{ScalarZnxAllocBytesImpl, ScalarZnxAllocImpl}, -}; - -impl ScalarZnxAllocBytes for Module -where - B: Backend + ScalarZnxAllocBytesImpl, -{ - fn scalar_znx_alloc_bytes(&self, cols: usize) -> usize { - B::scalar_znx_alloc_bytes_impl(self.n(), cols) - } -} - -impl ScalarZnxAlloc for Module -where - B: Backend + ScalarZnxAllocImpl, -{ - fn scalar_znx_alloc(&self, cols: usize) -> ScalarZnxOwned { - B::scalar_znx_alloc_impl(self.n(), cols) - } -} diff --git a/backend/src/hal/delegates/scratch.rs b/backend/src/hal/delegates/scratch.rs index 350c6a9..a5f6b58 100644 --- a/backend/src/hal/delegates/scratch.rs +++ b/backend/src/hal/delegates/scratch.rs @@ -3,9 +3,7 @@ use crate::hal::{ ScratchAvailable, ScratchFromBytes, ScratchOwnedAlloc, ScratchOwnedBorrow, TakeLike, TakeMatZnx, TakeScalarZnx, TakeSlice, TakeSvpPPol, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, TakeVecZnxSlice, TakeVmpPMat, }, - layouts::{ - Backend, DataRef, MatZnx, Module, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, - }, + layouts::{Backend, DataRef, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}, oep::{ ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeLikeImpl, TakeMatZnxImpl, TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxDftSliceImpl, @@ -58,12 +56,12 @@ where } } -impl TakeScalarZnx for Scratch +impl TakeScalarZnx for Scratch where B: Backend + TakeScalarZnxImpl, { - fn take_scalar_znx(&mut self, module: &Module, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) { - B::take_scalar_znx_impl(self, module.n(), cols) + fn take_scalar_znx(&mut self, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) { + B::take_scalar_znx_impl(self, n, cols) } } @@ -71,32 +69,26 @@ impl TakeSvpPPol for Scratch where B: Backend + TakeSvpPPolImpl, { - fn take_svp_ppol(&mut self, module: &Module, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self) { - B::take_svp_ppol_impl(self, module.n(), cols) + fn take_svp_ppol(&mut self, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self) { + B::take_svp_ppol_impl(self, n, cols) } } -impl TakeVecZnx for Scratch +impl TakeVecZnx for Scratch where B: Backend + TakeVecZnxImpl, { - fn take_vec_znx(&mut self, module: &Module, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) { - B::take_vec_znx_impl(self, module.n(), cols, size) + fn take_vec_znx(&mut self, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) { + B::take_vec_znx_impl(self, n, cols, size) } } -impl TakeVecZnxSlice for Scratch +impl TakeVecZnxSlice for Scratch where B: Backend + TakeVecZnxSliceImpl, { - fn take_vec_znx_slice( - &mut self, - len: usize, - module: &Module, - cols: usize, - size: usize, - ) -> (Vec>, &mut Self) { - B::take_vec_znx_slice_impl(self, len, module.n(), cols, size) + fn take_vec_znx_slice(&mut self, len: usize, n: usize, cols: usize, size: usize) -> (Vec>, &mut Self) { + B::take_vec_znx_slice_impl(self, len, n, cols, size) } } @@ -104,8 +96,8 @@ impl TakeVecZnxBig for Scratch where B: Backend + TakeVecZnxBigImpl, { - fn take_vec_znx_big(&mut self, module: &Module, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self) { - B::take_vec_znx_big_impl(self, module.n(), cols, size) + fn take_vec_znx_big(&mut self, n: usize, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self) { + B::take_vec_znx_big_impl(self, n, cols, size) } } @@ -113,8 +105,8 @@ impl TakeVecZnxDft for Scratch where B: Backend + TakeVecZnxDftImpl, { - fn take_vec_znx_dft(&mut self, module: &Module, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self) { - B::take_vec_znx_dft_impl(self, module.n(), cols, size) + fn take_vec_znx_dft(&mut self, n: usize, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self) { + B::take_vec_znx_dft_impl(self, n, cols, size) } } @@ -125,11 +117,11 @@ where fn take_vec_znx_dft_slice( &mut self, len: usize, - module: &Module, + n: usize, cols: usize, size: usize, ) -> (Vec>, &mut Self) { - B::take_vec_znx_dft_slice_impl(self, len, module.n(), cols, size) + B::take_vec_znx_dft_slice_impl(self, len, n, cols, size) } } @@ -139,29 +131,29 @@ where { fn take_vmp_pmat( &mut self, - module: &Module, + n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize, ) -> (VmpPMat<&mut [u8], B>, &mut Self) { - B::take_vmp_pmat_impl(self, module.n(), rows, cols_in, cols_out, size) + B::take_vmp_pmat_impl(self, n, rows, cols_in, cols_out, size) } } -impl TakeMatZnx for Scratch +impl TakeMatZnx for Scratch where B: Backend + TakeMatZnxImpl, { fn take_mat_znx( &mut self, - module: &Module, + n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize, ) -> (MatZnx<&mut [u8]>, &mut Self) { - B::take_mat_znx_impl(self, module.n(), rows, cols_in, cols_out, size) + B::take_mat_znx_impl(self, n, rows, cols_in, cols_out, size) } } diff --git a/backend/src/hal/delegates/svp_ppol.rs b/backend/src/hal/delegates/svp_ppol.rs index e968f8d..e47e474 100644 --- a/backend/src/hal/delegates/svp_ppol.rs +++ b/backend/src/hal/delegates/svp_ppol.rs @@ -8,8 +8,8 @@ impl SvpPPolFromBytes for Module where B: Backend + SvpPPolFromBytesImpl, { - fn svp_ppol_from_bytes(&self, cols: usize, bytes: Vec) -> SvpPPolOwned { - B::svp_ppol_from_bytes_impl(self.n(), cols, bytes) + fn svp_ppol_from_bytes(&self, n: usize, cols: usize, bytes: Vec) -> SvpPPolOwned { + B::svp_ppol_from_bytes_impl(n, cols, bytes) } } @@ -17,8 +17,8 @@ impl SvpPPolAlloc for Module where B: Backend + SvpPPolAllocImpl, { - fn svp_ppol_alloc(&self, cols: usize) -> SvpPPolOwned { - B::svp_ppol_alloc_impl(self.n(), cols) + fn svp_ppol_alloc(&self, n: usize, cols: usize) -> SvpPPolOwned { + B::svp_ppol_alloc_impl(n, cols) } } @@ -26,8 +26,8 @@ impl SvpPPolAllocBytes for Module where B: Backend + SvpPPolAllocBytesImpl, { - fn svp_ppol_alloc_bytes(&self, cols: usize) -> usize { - B::svp_ppol_alloc_bytes_impl(self.n(), cols) + fn svp_ppol_alloc_bytes(&self, n: usize, cols: usize) -> usize { + B::svp_ppol_alloc_bytes_impl(n, cols) } } diff --git a/backend/src/hal/delegates/vec_znx.rs b/backend/src/hal/delegates/vec_znx.rs index b10d8cc..3eac8a6 100644 --- a/backend/src/hal/delegates/vec_znx.rs +++ b/backend/src/hal/delegates/vec_znx.rs @@ -2,54 +2,26 @@ use sampling::source::Source; use crate::hal::{ api::{ - VecZnxAdd, VecZnxAddDistF64, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAlloc, VecZnxAllocBytes, - VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxCopy, VecZnxDecodeCoeffsi64, VecZnxDecodeVecFloat, - VecZnxDecodeVeci64, VecZnxEncodeCoeffsi64, VecZnxEncodeVeci64, VecZnxFillDistF64, VecZnxFillNormal, VecZnxFillUniform, - VecZnxFromBytes, VecZnxLshInplace, VecZnxMerge, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegate, - VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, - VecZnxRshInplace, VecZnxSplit, VecZnxStd, VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace, VecZnxSubScalarInplace, - VecZnxSwithcDegree, + VecZnxAdd, VecZnxAddDistF64, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, + VecZnxAutomorphismInplace, VecZnxCopy, VecZnxDecodeCoeffsi64, VecZnxDecodeVecFloat, VecZnxDecodeVeci64, + VecZnxEncodeCoeffsi64, VecZnxEncodeVeci64, VecZnxFillDistF64, VecZnxFillNormal, VecZnxFillUniform, VecZnxLshInplace, + VecZnxMerge, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegate, VecZnxNegateInplace, VecZnxNormalize, + VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSplit, + VecZnxStd, VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace, VecZnxSubScalarInplace, VecZnxSwithcDegree, }, - layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxOwned, VecZnxToMut, VecZnxToRef}, + layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef}, oep::{ VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarInplaceImpl, - VecZnxAllocBytesImpl, VecZnxAllocImpl, VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, - VecZnxDecodeCoeffsi64Impl, VecZnxDecodeVecFloatImpl, VecZnxDecodeVeci64Impl, VecZnxEncodeCoeffsi64Impl, - VecZnxEncodeVeci64Impl, VecZnxFillDistF64Impl, VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxFromBytesImpl, - VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, - VecZnxNegateInplaceImpl, VecZnxNormalizeImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, - VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxStdImpl, VecZnxSubABInplaceImpl, - VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl, + VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxDecodeCoeffsi64Impl, + VecZnxDecodeVecFloatImpl, VecZnxDecodeVeci64Impl, VecZnxEncodeCoeffsi64Impl, VecZnxEncodeVeci64Impl, + VecZnxFillDistF64Impl, VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl, + VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl, + VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl, + VecZnxSplitImpl, VecZnxStdImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, + VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl, }, }; -impl VecZnxAlloc for Module -where - B: Backend + VecZnxAllocImpl, -{ - fn vec_znx_alloc(&self, cols: usize, size: usize) -> VecZnxOwned { - B::vec_znx_alloc_impl(self.n(), cols, size) - } -} - -impl VecZnxFromBytes for Module -where - B: Backend + VecZnxFromBytesImpl, -{ - fn vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxOwned { - B::vec_znx_from_bytes_impl(self.n(), cols, size, bytes) - } -} - -impl VecZnxAllocBytes for Module -where - B: Backend + VecZnxAllocBytesImpl, -{ - fn vec_znx_alloc_bytes(&self, cols: usize, size: usize) -> usize { - B::vec_znx_alloc_bytes_impl(self.n(), cols, size) - } -} - impl VecZnxNormalizeTmpBytes for Module where B: Backend + VecZnxNormalizeTmpBytesImpl, diff --git a/backend/src/hal/delegates/vec_znx_big.rs b/backend/src/hal/delegates/vec_znx_big.rs index 378e718..fc48d42 100644 --- a/backend/src/hal/delegates/vec_znx_big.rs +++ b/backend/src/hal/delegates/vec_znx_big.rs @@ -24,8 +24,8 @@ impl VecZnxBigAlloc for Module where B: Backend + VecZnxBigAllocImpl, { - fn vec_znx_big_alloc(&self, cols: usize, size: usize) -> VecZnxBigOwned { - B::vec_znx_big_alloc_impl(self.n(), cols, size) + fn vec_znx_big_alloc(&self, n: usize, cols: usize, size: usize) -> VecZnxBigOwned { + B::vec_znx_big_alloc_impl(n, cols, size) } } @@ -33,8 +33,8 @@ impl VecZnxBigFromBytes for Module where B: Backend + VecZnxBigFromBytesImpl, { - fn vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxBigOwned { - B::vec_znx_big_from_bytes_impl(self.n(), cols, size, bytes) + fn vec_znx_big_from_bytes(&self, n: usize, cols: usize, size: usize, bytes: Vec) -> VecZnxBigOwned { + B::vec_znx_big_from_bytes_impl(n, cols, size, bytes) } } @@ -42,8 +42,8 @@ impl VecZnxBigAllocBytes for Module where B: Backend + VecZnxBigAllocBytesImpl, { - fn vec_znx_big_alloc_bytes(&self, cols: usize, size: usize) -> usize { - B::vec_znx_big_alloc_bytes_impl(self.n(), cols, size) + fn vec_znx_big_alloc_bytes(&self, n: usize, cols: usize, size: usize) -> usize { + B::vec_znx_big_alloc_bytes_impl(n, cols, size) } } diff --git a/backend/src/hal/delegates/vec_znx_dft.rs b/backend/src/hal/delegates/vec_znx_dft.rs index 6877dad..75744f4 100644 --- a/backend/src/hal/delegates/vec_znx_dft.rs +++ b/backend/src/hal/delegates/vec_znx_dft.rs @@ -20,8 +20,8 @@ impl VecZnxDftFromBytes for Module where B: Backend + VecZnxDftFromBytesImpl, { - fn vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxDftOwned { - B::vec_znx_dft_from_bytes_impl(self.n(), cols, size, bytes) + fn vec_znx_dft_from_bytes(&self, n: usize, cols: usize, size: usize, bytes: Vec) -> VecZnxDftOwned { + B::vec_znx_dft_from_bytes_impl(n, cols, size, bytes) } } @@ -29,8 +29,8 @@ impl VecZnxDftAllocBytes for Module where B: Backend + VecZnxDftAllocBytesImpl, { - fn vec_znx_dft_alloc_bytes(&self, cols: usize, size: usize) -> usize { - B::vec_znx_dft_alloc_bytes_impl(self.n(), cols, size) + fn vec_znx_dft_alloc_bytes(&self, n: usize, cols: usize, size: usize) -> usize { + B::vec_znx_dft_alloc_bytes_impl(n, cols, size) } } @@ -38,8 +38,8 @@ impl VecZnxDftAlloc for Module where B: Backend + VecZnxDftAllocImpl, { - fn vec_znx_dft_alloc(&self, cols: usize, size: usize) -> VecZnxDftOwned { - B::vec_znx_dft_alloc_impl(self.n(), cols, size) + fn vec_znx_dft_alloc(&self, n: usize, cols: usize, size: usize) -> VecZnxDftOwned { + B::vec_znx_dft_alloc_impl(n, cols, size) } } @@ -47,8 +47,8 @@ impl VecZnxDftToVecZnxBigTmpBytes for Module where B: Backend + VecZnxDftToVecZnxBigTmpBytesImpl, { - fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self) -> usize { - B::vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(self) + fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self, n: usize) -> usize { + B::vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(self, n) } } diff --git a/backend/src/hal/delegates/vmp_pmat.rs b/backend/src/hal/delegates/vmp_pmat.rs index 89fca8f..0d9a501 100644 --- a/backend/src/hal/delegates/vmp_pmat.rs +++ b/backend/src/hal/delegates/vmp_pmat.rs @@ -14,8 +14,8 @@ impl VmpPMatAlloc for Module where B: Backend + VmpPMatAllocImpl, { - fn vmp_pmat_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned { - B::vmp_pmat_alloc_impl(self.n(), rows, cols_in, cols_out, size) + fn vmp_pmat_alloc(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned { + B::vmp_pmat_alloc_impl(n, rows, cols_in, cols_out, size) } } @@ -23,8 +23,8 @@ impl VmpPMatAllocBytes for Module where B: Backend + VmpPMatAllocBytesImpl, { - fn vmp_pmat_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { - B::vmp_pmat_alloc_bytes_impl(self.n(), rows, cols_in, cols_out, size) + fn vmp_pmat_alloc_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + B::vmp_pmat_alloc_bytes_impl(n, rows, cols_in, cols_out, size) } } @@ -32,8 +32,16 @@ impl VmpPMatFromBytes for Module where B: Backend + VmpPMatFromBytesImpl, { - fn vmp_pmat_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec) -> VmpPMatOwned { - B::vmp_pmat_from_bytes_impl(self.n(), rows, cols_in, cols_out, size, bytes) + fn vmp_pmat_from_bytes( + &self, + n: usize, + rows: usize, + cols_in: usize, + cols_out: usize, + size: usize, + bytes: Vec, + ) -> VmpPMatOwned { + B::vmp_pmat_from_bytes_impl(n, rows, cols_in, cols_out, size, bytes) } } @@ -41,8 +49,8 @@ impl VmpPrepareTmpBytes for Module where B: Backend + VmpPrepareTmpBytesImpl, { - fn vmp_prepare_tmp_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { - B::vmp_prepare_tmp_bytes_impl(self, rows, cols_in, cols_out, size) + fn vmp_prepare_tmp_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + B::vmp_prepare_tmp_bytes_impl(self, n, rows, cols_in, cols_out, size) } } @@ -65,6 +73,7 @@ where { fn vmp_apply_tmp_bytes( &self, + n: usize, res_size: usize, a_size: usize, b_rows: usize, @@ -73,7 +82,7 @@ where b_size: usize, ) -> usize { B::vmp_apply_tmp_bytes_impl( - self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size, + self, n, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size, ) } } @@ -98,6 +107,7 @@ where { fn vmp_apply_add_tmp_bytes( &self, + n: usize, res_size: usize, a_size: usize, b_rows: usize, @@ -106,7 +116,7 @@ where b_size: usize, ) -> usize { B::vmp_apply_add_tmp_bytes_impl( - self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size, + self, n, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size, ) } } diff --git a/backend/src/hal/layouts/mat_znx.rs b/backend/src/hal/layouts/mat_znx.rs index 78693cc..7021d02 100644 --- a/backend/src/hal/layouts/mat_znx.rs +++ b/backend/src/hal/layouts/mat_znx.rs @@ -1,7 +1,7 @@ use crate::{ alloc_aligned, hal::{ - api::{DataView, DataViewMut, FillUniform, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero}, + api::{DataView, DataViewMut, FillUniform, Reset, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero}, layouts::{Data, DataMut, DataRef, ReaderFrom, VecZnx, WriterTo}, }, }; @@ -78,15 +78,13 @@ impl MatZnx { } } -impl MatZnx { - pub fn bytes_of(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { - rows * cols_in * VecZnx::>::alloc_bytes::(n, cols_out, size) +impl MatZnx> { + pub fn alloc_bytes(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + rows * cols_in * VecZnx::>::alloc_bytes(n, cols_out, size) } -} -impl>> MatZnx { - pub(crate) fn alloc(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { - let data: Vec = alloc_aligned(Self::bytes_of(n, rows, cols_in, cols_out, size)); + pub fn alloc(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { + let data: Vec = alloc_aligned(Self::alloc_bytes(n, rows, cols_in, cols_out, size)); Self { data: data.into(), n, @@ -97,16 +95,9 @@ impl>> MatZnx { } } - pub(crate) fn from_bytes( - n: usize, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - bytes: impl Into>, - ) -> Self { + pub fn from_bytes(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); - assert!(data.len() == Self::bytes_of(n, rows, cols_in, cols_out, size)); + assert!(data.len() == Self::alloc_bytes(n, rows, cols_in, cols_out, size)); Self { data: data.into(), n, @@ -127,7 +118,7 @@ impl MatZnx { } let self_ref: MatZnx<&[u8]> = self.to_ref(); - let nb_bytes: usize = VecZnx::>::alloc_bytes::(self.n, self.cols_out, self.size); + let nb_bytes: usize = VecZnx::>::alloc_bytes(self.n, self.cols_out, self.size); let start: usize = nb_bytes * self.cols() * row + col * nb_bytes; let end: usize = start + nb_bytes; @@ -155,7 +146,7 @@ impl MatZnx { let size: usize = self.size(); let self_ref: MatZnx<&mut [u8]> = self.to_mut(); - let nb_bytes: usize = VecZnx::>::alloc_bytes::(n, cols_out, size); + let nb_bytes: usize = VecZnx::>::alloc_bytes(n, cols_out, size); let start: usize = nb_bytes * cols_in * row + col * nb_bytes; let end: usize = start + nb_bytes; @@ -175,6 +166,17 @@ impl FillUniform for MatZnx { } } +impl Reset for MatZnx { + fn reset(&mut self) { + self.zero(); + self.n = 0; + self.size = 0; + self.rows = 0; + self.cols_in = 0; + self.cols_out = 0; + } +} + pub type MatZnxOwned = MatZnx>; pub type MatZnxMut<'a> = MatZnx<&'a mut [u8]>; pub type MatZnxRef<'a> = MatZnx<&'a [u8]>; diff --git a/backend/src/hal/layouts/scalar_znx.rs b/backend/src/hal/layouts/scalar_znx.rs index 4e23939..068b116 100644 --- a/backend/src/hal/layouts/scalar_znx.rs +++ b/backend/src/hal/layouts/scalar_znx.rs @@ -6,7 +6,7 @@ use sampling::source::Source; use crate::{ alloc_aligned, hal::{ - api::{DataView, DataViewMut, FillUniform, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero}, + api::{DataView, DataViewMut, FillUniform, Reset, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero}, layouts::{Data, DataMut, DataRef, ReaderFrom, VecZnx, WriterTo}, }, }; @@ -107,15 +107,13 @@ impl ScalarZnx { } } -impl ScalarZnx { - pub fn bytes_of(n: usize, cols: usize) -> usize { +impl ScalarZnx> { + pub fn alloc_bytes(n: usize, cols: usize) -> usize { n * cols * size_of::() } -} -impl>> ScalarZnx { pub fn alloc(n: usize, cols: usize) -> Self { - let data: Vec = alloc_aligned::(Self::bytes_of(n, cols)); + let data: Vec = alloc_aligned::(Self::alloc_bytes(n, cols)); Self { data: data.into(), n, @@ -123,9 +121,9 @@ impl>> ScalarZnx { } } - pub(crate) fn from_bytes(n: usize, cols: usize, bytes: impl Into>) -> Self { + pub fn from_bytes(n: usize, cols: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); - assert!(data.len() == Self::bytes_of(n, cols)); + assert!(data.len() == Self::alloc_bytes(n, cols)); Self { data: data.into(), n, @@ -149,6 +147,14 @@ impl FillUniform for ScalarZnx { } } +impl Reset for ScalarZnx { + fn reset(&mut self) { + self.zero(); + self.n = 0; + self.cols = 0; + } +} + pub type ScalarZnxOwned = ScalarZnx>; impl ScalarZnx { diff --git a/backend/src/hal/layouts/vec_znx.rs b/backend/src/hal/layouts/vec_znx.rs index 2d4557a..b9afc1b 100644 --- a/backend/src/hal/layouts/vec_znx.rs +++ b/backend/src/hal/layouts/vec_znx.rs @@ -3,7 +3,7 @@ use std::fmt; use crate::{ alloc_aligned, hal::{ - api::{DataView, DataViewMut, FillUniform, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero}, + api::{DataView, DataViewMut, FillUniform, Reset, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero}, layouts::{Data, DataMut, DataRef, ReaderFrom, WriterTo}, }, }; @@ -79,15 +79,13 @@ impl ZnxZero for VecZnx { } } -impl VecZnx { - pub fn alloc_bytes(n: usize, cols: usize, size: usize) -> usize { - n * cols * size * size_of::() +impl VecZnx> { + pub fn alloc_bytes(n: usize, cols: usize, size: usize) -> usize { + n * cols * size * size_of::() } -} -impl>> VecZnx { - pub fn alloc(n: usize, cols: usize, size: usize) -> Self { - let data: Vec = alloc_aligned::(Self::alloc_bytes::(n, cols, size)); + pub fn alloc(n: usize, cols: usize, size: usize) -> Self { + let data: Vec = alloc_aligned::(Self::alloc_bytes(n, cols, size)); Self { data: data.into(), n, @@ -99,7 +97,7 @@ impl>> VecZnx { pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); - assert!(data.len() == Self::alloc_bytes::(n, cols, size)); + assert!(data.len() == Self::alloc_bytes(n, cols, size)); Self { data: data.into(), n, @@ -163,6 +161,16 @@ impl FillUniform for VecZnx { } } +impl Reset for VecZnx { + fn reset(&mut self) { + self.zero(); + self.n = 0; + self.cols = 0; + self.size = 0; + self.max_size = 0; + } +} + pub type VecZnxOwned = VecZnx>; pub type VecZnxMut<'a> = VecZnx<&'a mut [u8]>; pub type VecZnxRef<'a> = VecZnx<&'a [u8]>; diff --git a/backend/src/hal/oep/mat_znx.rs b/backend/src/hal/oep/mat_znx.rs deleted file mode 100644 index 87fd16a..0000000 --- a/backend/src/hal/oep/mat_znx.rs +++ /dev/null @@ -1,20 +0,0 @@ -use crate::hal::layouts::{Backend, MatZnxOwned, Module}; - -pub unsafe trait MatZnxAllocImpl { - fn mat_znx_alloc_impl(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxOwned; -} - -pub unsafe trait MatZnxAllocBytesImpl { - fn mat_znx_alloc_bytes_impl(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; -} - -pub unsafe trait MatZnxFromBytesImpl { - fn mat_znx_from_bytes_impl( - module: &Module, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - bytes: Vec, - ) -> MatZnxOwned; -} diff --git a/backend/src/hal/oep/mod.rs b/backend/src/hal/oep/mod.rs index ef1ee02..bc53c0e 100644 --- a/backend/src/hal/oep/mod.rs +++ b/backend/src/hal/oep/mod.rs @@ -1,6 +1,4 @@ -mod mat_znx; mod module; -mod scalar_znx; mod scratch; mod svp_ppol; mod vec_znx; @@ -8,9 +6,7 @@ mod vec_znx_big; mod vec_znx_dft; mod vmp_pmat; -pub use mat_znx::*; pub use module::*; -pub use scalar_znx::*; pub use scratch::*; pub use svp_ppol::*; pub use vec_znx::*; diff --git a/backend/src/hal/oep/scalar_znx.rs b/backend/src/hal/oep/scalar_znx.rs deleted file mode 100644 index cad0dd2..0000000 --- a/backend/src/hal/oep/scalar_znx.rs +++ /dev/null @@ -1,13 +0,0 @@ -use crate::hal::layouts::{Backend, ScalarZnxOwned}; - -pub unsafe trait ScalarZnxFromBytesImpl { - fn scalar_znx_from_bytes_impl(n: usize, cols: usize, bytes: Vec) -> ScalarZnxOwned; -} - -pub unsafe trait ScalarZnxAllocBytesImpl { - fn scalar_znx_alloc_bytes_impl(n: usize, cols: usize) -> usize; -} - -pub unsafe trait ScalarZnxAllocImpl { - fn scalar_znx_alloc_impl(n: usize, cols: usize) -> ScalarZnxOwned; -} diff --git a/backend/src/hal/oep/vec_znx.rs b/backend/src/hal/oep/vec_znx.rs index 3d61fc0..ae9ce44 100644 --- a/backend/src/hal/oep/vec_znx.rs +++ b/backend/src/hal/oep/vec_znx.rs @@ -2,32 +2,7 @@ use rand_distr::Distribution; use rug::Float; use sampling::source::Source; -use crate::hal::layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxOwned, VecZnxToMut, VecZnxToRef}; - -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [crate::hal::layouts::VecZnx::new] for reference code. -/// * See [crate::hal::api::VecZnxAlloc] for corresponding public API. -/// * See [crate::doc::backend_safety] for safety contract. -/// * See test \[TODO\] -pub unsafe trait VecZnxAllocImpl { - fn vec_znx_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxOwned; -} - -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [crate::hal::layouts::VecZnx::from_bytes] for reference code. -/// * See [crate::hal::api::VecZnxFromBytes] for corresponding public API. -/// * See [crate::doc::backend_safety] for safety contract. -pub unsafe trait VecZnxFromBytesImpl { - fn vec_znx_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec) -> VecZnxOwned; -} - -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [crate::hal::layouts::VecZnx::alloc_bytes] for reference code. -/// * See [crate::hal::api::VecZnxAllocBytes] for corresponding public API. -/// * See [crate::doc::backend_safety] for safety contract. -pub unsafe trait VecZnxAllocBytesImpl { - fn vec_znx_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize; -} +use crate::hal::layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef}; /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See [vec_znx_normalize_base2k_tmp_bytes_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L245C17-L245C55) for reference code. diff --git a/backend/src/hal/oep/vec_znx_dft.rs b/backend/src/hal/oep/vec_znx_dft.rs index 3f2aa0a..e20c710 100644 --- a/backend/src/hal/oep/vec_znx_dft.rs +++ b/backend/src/hal/oep/vec_znx_dft.rs @@ -16,7 +16,7 @@ pub unsafe trait VecZnxDftAllocBytesImpl { } pub unsafe trait VecZnxDftToVecZnxBigTmpBytesImpl { - fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module) -> usize; + fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module, n: usize) -> usize; } pub unsafe trait VecZnxDftToVecZnxBigImpl { diff --git a/backend/src/hal/oep/vmp_pmat.rs b/backend/src/hal/oep/vmp_pmat.rs index 56e7299..926ae27 100644 --- a/backend/src/hal/oep/vmp_pmat.rs +++ b/backend/src/hal/oep/vmp_pmat.rs @@ -22,7 +22,14 @@ pub unsafe trait VmpPMatFromBytesImpl { } pub unsafe trait VmpPrepareTmpBytesImpl { - fn vmp_prepare_tmp_bytes_impl(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; + fn vmp_prepare_tmp_bytes_impl( + module: &Module, + n: usize, + rows: usize, + cols_in: usize, + cols_out: usize, + size: usize, + ) -> usize; } pub unsafe trait VmpPMatPrepareImpl { @@ -35,6 +42,7 @@ pub unsafe trait VmpPMatPrepareImpl { pub unsafe trait VmpApplyTmpBytesImpl { fn vmp_apply_tmp_bytes_impl( module: &Module, + n: usize, res_size: usize, a_size: usize, b_rows: usize, @@ -55,6 +63,7 @@ pub unsafe trait VmpApplyImpl { pub unsafe trait VmpApplyAddTmpBytesImpl { fn vmp_apply_add_tmp_bytes_impl( module: &Module, + n: usize, res_size: usize, a_size: usize, b_rows: usize, diff --git a/backend/src/hal/tests/serialization.rs b/backend/src/hal/tests/serialization.rs index 3d78a4f..b052eb0 100644 --- a/backend/src/hal/tests/serialization.rs +++ b/backend/src/hal/tests/serialization.rs @@ -3,7 +3,7 @@ use std::fmt::Debug; use sampling::source::Source; use crate::hal::{ - api::{FillUniform, ZnxZero}, + api::{FillUniform, Reset}, layouts::{ReaderFrom, WriterTo}, }; @@ -12,7 +12,7 @@ use crate::hal::{ /// - `T` must implement I/O traits, zeroing, cloning, and random filling. pub fn test_reader_writer_interface(mut original: T) where - T: WriterTo + ReaderFrom + PartialEq + Eq + Debug + Clone + ZnxZero + FillUniform, + T: WriterTo + ReaderFrom + PartialEq + Eq + Debug + Clone + Reset + FillUniform, { // Fill original with uniform random data let mut source = Source::new([0u8; 32]); @@ -24,7 +24,7 @@ where // Prepare receiver: same shape, but zeroed let mut receiver = original.clone(); - receiver.zero(); + receiver.reset(); // Deserialize from buffer let mut reader: &[u8] = &buffer; @@ -45,7 +45,7 @@ fn scalar_znx_serialize() { #[test] fn vec_znx_serialize() { - let original: crate::hal::layouts::VecZnx> = crate::hal::layouts::VecZnx::alloc::(1024, 3, 4); + let original: crate::hal::layouts::VecZnx> = crate::hal::layouts::VecZnx::alloc(1024, 3, 4); test_reader_writer_interface(original); } diff --git a/backend/src/hal/tests/vec_znx/generics.rs b/backend/src/hal/tests/vec_znx/generics.rs index 626c170..cedb3a7 100644 --- a/backend/src/hal/tests/vec_znx/generics.rs +++ b/backend/src/hal/tests/vec_znx/generics.rs @@ -2,25 +2,23 @@ use itertools::izip; use sampling::source::Source; use crate::hal::{ - api::{ - VecZnxAddNormal, VecZnxAlloc, VecZnxDecodeVeci64, VecZnxEncodeVeci64, VecZnxFillUniform, VecZnxStd, ZnxInfos, ZnxView, - ZnxViewMut, - }, + api::{VecZnxAddNormal, VecZnxDecodeVeci64, VecZnxEncodeVeci64, VecZnxFillUniform, VecZnxStd, ZnxInfos, ZnxView, ZnxViewMut}, layouts::{Backend, Module, VecZnx}, }; pub fn test_vec_znx_fill_uniform(module: &Module) where - Module: VecZnxFillUniform + VecZnxStd + VecZnxAlloc, + Module: VecZnxFillUniform + VecZnxStd, { + let n: usize = module.n(); let basek: usize = 17; let size: usize = 5; let mut source: Source = Source::new([0u8; 32]); let cols: usize = 2; - let zero: Vec = vec![0; module.n()]; + let zero: Vec = vec![0; n]; let one_12_sqrt: f64 = 0.28867513459481287; (0..cols).for_each(|col_i| { - let mut a: VecZnx<_> = module.vec_znx_alloc(cols, size); + let mut a: VecZnx<_> = VecZnx::alloc(n, cols, size); module.vec_znx_fill_uniform(basek, &mut a, col_i, size * basek, &mut source); (0..cols).for_each(|col_j| { if col_j != col_i { @@ -42,8 +40,9 @@ where pub fn test_vec_znx_add_normal(module: &Module) where - Module: VecZnxAddNormal + VecZnxStd + VecZnxAlloc, + Module: VecZnxAddNormal + VecZnxStd, { + let n: usize = module.n(); let basek: usize = 17; let k: usize = 2 * 17; let size: usize = 5; @@ -51,10 +50,10 @@ where let bound: f64 = 6.0 * sigma; let mut source: Source = Source::new([0u8; 32]); let cols: usize = 2; - let zero: Vec = vec![0; module.n()]; + let zero: Vec = vec![0; n]; let k_f64: f64 = (1u64 << k as u64) as f64; (0..cols).for_each(|col_i| { - let mut a: VecZnx<_> = module.vec_znx_alloc(cols, size); + let mut a: VecZnx<_> = VecZnx::alloc(n, cols, size); module.vec_znx_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound); (0..cols).for_each(|col_j| { if col_j != col_i { @@ -71,21 +70,22 @@ where pub fn test_vec_znx_encode_vec_i64_lo_norm(module: &Module) where - Module: VecZnxEncodeVeci64 + VecZnxDecodeVeci64 + VecZnxAlloc, + Module: VecZnxEncodeVeci64 + VecZnxDecodeVeci64, { + let n: usize = module.n(); let basek: usize = 17; let size: usize = 5; let k: usize = size * basek - 5; - let mut a: VecZnx<_> = module.vec_znx_alloc(2, size); + let mut a: VecZnx> = VecZnx::alloc(n, 2, size); let mut source: Source = Source::new([0u8; 32]); let raw: &mut [i64] = a.raw_mut(); raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); (0..a.cols()).for_each(|col_i| { - let mut have: Vec = vec![i64::default(); module.n()]; + let mut have: Vec = vec![i64::default(); n]; have.iter_mut() .for_each(|x| *x = (source.next_i64() << 56) >> 56); module.encode_vec_i64(basek, &mut a, col_i, k, &have, 10); - let mut want: Vec = vec![i64::default(); module.n()]; + let mut want: Vec = vec![i64::default(); n]; module.decode_vec_i64(basek, &a, col_i, k, &mut want); izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); }); @@ -93,17 +93,18 @@ where pub fn test_vec_znx_encode_vec_i64_hi_norm(module: &Module) where - Module: VecZnxEncodeVeci64 + VecZnxDecodeVeci64 + VecZnxAlloc, + Module: VecZnxEncodeVeci64 + VecZnxDecodeVeci64, { + let n: usize = module.n(); let basek: usize = 17; let size: usize = 5; for k in [1, basek / 2, size * basek - 5] { - let mut a: VecZnx<_> = module.vec_znx_alloc(2, size); + let mut a: VecZnx> = VecZnx::alloc(n, 2, size); let mut source = Source::new([0u8; 32]); let raw: &mut [i64] = a.raw_mut(); raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); (0..a.cols()).for_each(|col_i| { - let mut have: Vec = vec![i64::default(); module.n()]; + let mut have: Vec = vec![i64::default(); n]; have.iter_mut().for_each(|x| { if k < 64 { *x = source.next_u64n(1 << k, (1 << k) - 1) as i64; @@ -112,7 +113,7 @@ where } }); module.encode_vec_i64(basek, &mut a, col_i, k, &have, 63); - let mut want: Vec = vec![i64::default(); module.n()]; + let mut want: Vec = vec![i64::default(); n]; module.decode_vec_i64(basek, &a, col_i, k, &mut want); izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); }) diff --git a/backend/src/implementation/cpu_spqlios/ffi/vec_znx_dft.rs b/backend/src/implementation/cpu_spqlios/ffi/vec_znx_dft.rs index 00bb2cd..fbf1e49 100644 --- a/backend/src/implementation/cpu_spqlios/ffi/vec_znx_dft.rs +++ b/backend/src/implementation/cpu_spqlios/ffi/vec_znx_dft.rs @@ -56,7 +56,7 @@ unsafe extern "C" { ); } unsafe extern "C" { - pub unsafe fn vec_znx_idft_tmp_bytes(module: *const MODULE) -> u64; + pub unsafe fn vec_znx_idft_tmp_bytes(module: *const MODULE, n: u64) -> u64; } unsafe extern "C" { pub unsafe fn vec_znx_idft_tmp_a( @@ -67,19 +67,3 @@ unsafe extern "C" { a_size: u64, ); } - -unsafe extern "C" { - pub unsafe fn vec_znx_dft_automorphism( - module: *const MODULE, - d: i64, - res_dft: *mut VEC_ZNX_DFT, - res_size: u64, - a_dft: *const VEC_ZNX_DFT, - a_size: u64, - tmp: *mut u8, - ); -} - -unsafe extern "C" { - pub unsafe fn vec_znx_dft_automorphism_tmp_bytes(module: *const MODULE) -> u64; -} diff --git a/backend/src/implementation/cpu_spqlios/ffi/vmp.rs b/backend/src/implementation/cpu_spqlios/ffi/vmp.rs index c742cea..b9ae29a 100644 --- a/backend/src/implementation/cpu_spqlios/ffi/vmp.rs +++ b/backend/src/implementation/cpu_spqlios/ffi/vmp.rs @@ -86,6 +86,7 @@ unsafe extern "C" { unsafe extern "C" { pub unsafe fn vmp_apply_dft_to_dft_tmp_bytes( module: *const MODULE, + nn: u64, res_size: u64, a_size: u64, nrows: u64, @@ -109,5 +110,5 @@ unsafe extern "C" { } unsafe extern "C" { - pub unsafe fn vmp_prepare_tmp_bytes(module: *const MODULE, nrows: u64, ncols: u64) -> u64; + pub unsafe fn vmp_prepare_tmp_bytes(module: *const MODULE, nn: u64, nrows: u64, ncols: u64) -> u64; } diff --git a/backend/src/implementation/cpu_spqlios/mat_znx.rs b/backend/src/implementation/cpu_spqlios/mat_znx.rs deleted file mode 100644 index f6b7294..0000000 --- a/backend/src/implementation/cpu_spqlios/mat_znx.rs +++ /dev/null @@ -1,41 +0,0 @@ -use crate::{ - hal::{ - layouts::{Backend, MatZnxOwned, Module}, - oep::{MatZnxAllocBytesImpl, MatZnxAllocImpl, MatZnxFromBytesImpl}, - }, - implementation::cpu_spqlios::CPUAVX, -}; - -unsafe impl MatZnxAllocImpl for B -where - B: CPUAVX, -{ - fn mat_znx_alloc_impl(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxOwned { - MatZnxOwned::alloc(module.n(), rows, cols_in, cols_out, size) - } -} - -unsafe impl MatZnxAllocBytesImpl for B -where - B: CPUAVX, -{ - fn mat_znx_alloc_bytes_impl(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { - MatZnxOwned::bytes_of(module.n(), rows, cols_in, cols_out, size) - } -} - -unsafe impl MatZnxFromBytesImpl for B -where - B: CPUAVX, -{ - fn mat_znx_from_bytes_impl( - module: &Module, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - bytes: Vec, - ) -> MatZnxOwned { - MatZnxOwned::from_bytes(module.n(), rows, cols_in, cols_out, size, bytes) - } -} diff --git a/backend/src/implementation/cpu_spqlios/mod.rs b/backend/src/implementation/cpu_spqlios/mod.rs index dcb5b9c..570a23b 100644 --- a/backend/src/implementation/cpu_spqlios/mod.rs +++ b/backend/src/implementation/cpu_spqlios/mod.rs @@ -1,8 +1,6 @@ mod ffi; -mod mat_znx; mod module_fft64; mod module_ntt120; -mod scalar_znx; mod scratch; mod svp_ppol_fft64; mod svp_ppol_ntt120; diff --git a/backend/src/implementation/cpu_spqlios/scalar_znx.rs b/backend/src/implementation/cpu_spqlios/scalar_znx.rs deleted file mode 100644 index e83c722..0000000 --- a/backend/src/implementation/cpu_spqlios/scalar_znx.rs +++ /dev/null @@ -1,34 +0,0 @@ -use crate::{ - hal::{ - layouts::{Backend, ScalarZnxOwned}, - oep::{ScalarZnxAllocBytesImpl, ScalarZnxAllocImpl, ScalarZnxFromBytesImpl}, - }, - implementation::cpu_spqlios::CPUAVX, -}; - -unsafe impl ScalarZnxAllocBytesImpl for B -where - B: CPUAVX, -{ - fn scalar_znx_alloc_bytes_impl(n: usize, cols: usize) -> usize { - ScalarZnxOwned::bytes_of(n, cols) - } -} - -unsafe impl ScalarZnxAllocImpl for B -where - B: CPUAVX, -{ - fn scalar_znx_alloc_impl(n: usize, cols: usize) -> ScalarZnxOwned { - ScalarZnxOwned::alloc(n, cols) - } -} - -unsafe impl ScalarZnxFromBytesImpl for B -where - B: CPUAVX, -{ - fn scalar_znx_from_bytes_impl(n: usize, cols: usize, bytes: Vec) -> ScalarZnxOwned { - ScalarZnxOwned::from_bytes(n, cols, bytes) - } -} diff --git a/backend/src/implementation/cpu_spqlios/scratch.rs b/backend/src/implementation/cpu_spqlios/scratch.rs index 1f234c4..9dad5e7 100644 --- a/backend/src/implementation/cpu_spqlios/scratch.rs +++ b/backend/src/implementation/cpu_spqlios/scratch.rs @@ -6,10 +6,10 @@ use crate::{ api::ScratchFromBytes, layouts::{Backend, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}, oep::{ - ScalarZnxAllocBytesImpl, ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, - SvpPPolAllocBytesImpl, TakeMatZnxImpl, TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, - TakeVecZnxDftImpl, TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl, - VecZnxAllocBytesImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl, VmpPMatAllocBytesImpl, + ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, SvpPPolAllocBytesImpl, + TakeMatZnxImpl, TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl, + TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl, VecZnxBigAllocBytesImpl, + VecZnxDftAllocBytesImpl, VmpPMatAllocBytesImpl, }, }, implementation::cpu_spqlios::CPUAVX, @@ -76,10 +76,10 @@ where unsafe impl TakeScalarZnxImpl for B where - B: CPUAVX + ScalarZnxAllocBytesImpl, + B: CPUAVX, { fn take_scalar_znx_impl(scratch: &mut Scratch, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, B::scalar_znx_alloc_bytes_impl(n, cols)); + let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, ScalarZnx::alloc_bytes(n, cols)); ( ScalarZnx::from_data(take_slice, n, cols), Scratch::from_bytes(rem_slice), @@ -102,13 +102,10 @@ where unsafe impl TakeVecZnxImpl for B where - B: CPUAVX + VecZnxAllocBytesImpl, + B: CPUAVX, { fn take_vec_znx_impl(scratch: &mut Scratch, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Scratch) { - let (take_slice, rem_slice) = take_slice_aligned( - &mut scratch.data, - B::vec_znx_alloc_bytes_impl(n, cols, size), - ); + let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, VecZnx::alloc_bytes(n, cols, size)); ( VecZnx::from_data(take_slice, n, cols, size), Scratch::from_bytes(rem_slice), @@ -240,7 +237,7 @@ where ) -> (MatZnx<&mut [u8]>, &mut Scratch) { let (take_slice, rem_slice) = take_slice_aligned( &mut scratch.data, - MatZnx::>::bytes_of(n, rows, cols_in, cols_out, size), + MatZnx::alloc_bytes(n, rows, cols_in, cols_out, size), ); ( MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size), diff --git a/backend/src/implementation/cpu_spqlios/spqlios-arithmetic b/backend/src/implementation/cpu_spqlios/spqlios-arithmetic index 7160f58..de62af3 160000 --- a/backend/src/implementation/cpu_spqlios/spqlios-arithmetic +++ b/backend/src/implementation/cpu_spqlios/spqlios-arithmetic @@ -1 +1 @@ -Subproject commit 7160f588da49712a042931ea247b4259b95cefcc +Subproject commit de62af3507776597231e0c0d2b26495a0c92d207 diff --git a/backend/src/implementation/cpu_spqlios/vec_znx.rs b/backend/src/implementation/cpu_spqlios/vec_znx.rs index faecee1..e1da258 100644 --- a/backend/src/implementation/cpu_spqlios/vec_znx.rs +++ b/backend/src/implementation/cpu_spqlios/vec_znx.rs @@ -14,16 +14,16 @@ use crate::{ VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxSwithcDegree, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, }, - layouts::{Backend, Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxOwned, VecZnxToMut, VecZnxToRef}, + layouts::{Backend, Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxToMut, VecZnxToRef}, oep::{ VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarInplaceImpl, - VecZnxAllocBytesImpl, VecZnxAllocImpl, VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, - VecZnxDecodeCoeffsi64Impl, VecZnxDecodeVecFloatImpl, VecZnxDecodeVeci64Impl, VecZnxEncodeCoeffsi64Impl, - VecZnxEncodeVeci64Impl, VecZnxFillDistF64Impl, VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxFromBytesImpl, - VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, - VecZnxNegateInplaceImpl, VecZnxNormalizeImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, - VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxStdImpl, - VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl, + VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxDecodeCoeffsi64Impl, + VecZnxDecodeVecFloatImpl, VecZnxDecodeVeci64Impl, VecZnxEncodeCoeffsi64Impl, VecZnxEncodeVeci64Impl, + VecZnxFillDistF64Impl, VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl, + VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, + VecZnxNormalizeImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, + VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxStdImpl, VecZnxSubABInplaceImpl, + VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl, }, }, implementation::cpu_spqlios::{ @@ -32,33 +32,6 @@ use crate::{ }, }; -unsafe impl VecZnxAllocImpl for B -where - B: CPUAVX, -{ - fn vec_znx_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxOwned { - VecZnxOwned::alloc::(n, cols, size) - } -} - -unsafe impl VecZnxFromBytesImpl for B -where - B: CPUAVX, -{ - fn vec_znx_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec) -> VecZnxOwned { - VecZnxOwned::from_bytes::(n, cols, size, bytes) - } -} - -unsafe impl VecZnxAllocBytesImpl for B -where - B: CPUAVX, -{ - fn vec_znx_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { - VecZnxOwned::alloc_bytes::(n, cols, size) - } -} - unsafe impl VecZnxNormalizeTmpBytesImpl for B where B: CPUAVX, @@ -156,9 +129,8 @@ where #[cfg(debug_assertions)] { - assert_eq!(a.n(), module.n()); - assert_eq!(b.n(), module.n()); - assert_eq!(res.n(), module.n()); + assert_eq!(a.n(), res.n()); + assert_eq!(b.n(), res.n()); assert_ne!(a.as_ptr(), b.as_ptr()); } unsafe { @@ -192,8 +164,7 @@ where #[cfg(debug_assertions)] { - assert_eq!(a.n(), module.n()); - assert_eq!(res.n(), module.n()); + assert_eq!(a.n(), res.n()); } unsafe { vec_znx::vec_znx_add( @@ -232,8 +203,7 @@ where #[cfg(debug_assertions)] { - assert_eq!(a.n(), module.n()); - assert_eq!(res.n(), module.n()); + assert_eq!(a.n(), res.n()); } unsafe { @@ -269,9 +239,8 @@ where #[cfg(debug_assertions)] { - assert_eq!(a.n(), module.n()); - assert_eq!(b.n(), module.n()); - assert_eq!(res.n(), module.n()); + assert_eq!(a.n(), res.n()); + assert_eq!(b.n(), res.n()); assert_ne!(a.as_ptr(), b.as_ptr()); } unsafe { @@ -304,8 +273,7 @@ where let mut res: VecZnx<&mut [u8]> = res.to_mut(); #[cfg(debug_assertions)] { - assert_eq!(a.n(), module.n()); - assert_eq!(res.n(), module.n()); + assert_eq!(a.n(), res.n()); } unsafe { vec_znx::vec_znx_sub( @@ -337,8 +305,7 @@ where let mut res: VecZnx<&mut [u8]> = res.to_mut(); #[cfg(debug_assertions)] { - assert_eq!(a.n(), module.n()); - assert_eq!(res.n(), module.n()); + assert_eq!(a.n(), res.n()); } unsafe { vec_znx::vec_znx_sub( @@ -377,8 +344,7 @@ where #[cfg(debug_assertions)] { - assert_eq!(a.n(), module.n()); - assert_eq!(res.n(), module.n()); + assert_eq!(a.n(), res.n()); } unsafe { @@ -411,8 +377,7 @@ where let mut res: VecZnx<&mut [u8]> = res.to_mut(); #[cfg(debug_assertions)] { - assert_eq!(a.n(), module.n()); - assert_eq!(res.n(), module.n()); + assert_eq!(a.n(), res.n()); } unsafe { vec_znx::vec_znx_negate( @@ -437,10 +402,6 @@ where A: VecZnxToMut, { let mut a: VecZnx<&mut [u8]> = a.to_mut(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), module.n()); - } unsafe { vec_znx::vec_znx_negate( module.ptr() as *const module_info_t, @@ -604,8 +565,7 @@ where let mut res: VecZnx<&mut [u8]> = res.to_mut(); #[cfg(debug_assertions)] { - assert_eq!(a.n(), module.n()); - assert_eq!(res.n(), module.n()); + assert_eq!(a.n(), res.n()); } unsafe { vec_znx::vec_znx_automorphism( @@ -633,7 +593,6 @@ where let mut a: VecZnx<&mut [u8]> = a.to_mut(); #[cfg(debug_assertions)] { - assert_eq!(a.n(), module.n()); assert!( k & 1 != 0, "invalid galois element: must be odd but is {}", @@ -668,8 +627,8 @@ where let mut res: VecZnx<&mut [u8]> = res.to_mut(); #[cfg(debug_assertions)] { - assert_eq!(a.n(), module.n()); - assert_eq!(res.n(), module.n()); + assert_eq!(a.n(), res.n()); + assert_eq!(res.n(), res.n()); } unsafe { vec_znx::vec_znx_mul_xp_minus_one( @@ -697,7 +656,7 @@ where let mut res: VecZnx<&mut [u8]> = res.to_mut(); #[cfg(debug_assertions)] { - assert_eq!(res.n(), module.n()); + assert_eq!(res.n(), res.n()); } unsafe { vec_znx::vec_znx_mul_xp_minus_one( @@ -749,7 +708,7 @@ pub fn vec_znx_split_ref( let (n_in, n_out) = (a.n(), res[0].to_mut().n()); - let (mut buf, _) = scratch.take_vec_znx(module, 1, a.size()); + let (mut buf, _) = scratch.take_vec_znx(n_in.max(n_out), 1, a.size()); debug_assert!( n_out < n_in, diff --git a/backend/src/implementation/cpu_spqlios/vec_znx_big_fft64.rs b/backend/src/implementation/cpu_spqlios/vec_znx_big_fft64.rs index 86fe6cc..e4edf5c 100644 --- a/backend/src/implementation/cpu_spqlios/vec_znx_big_fft64.rs +++ b/backend/src/implementation/cpu_spqlios/vec_znx_big_fft64.rs @@ -210,9 +210,8 @@ unsafe impl VecZnxBigAddImpl for FFT64 { #[cfg(debug_assertions)] { - assert_eq!(a.n(), module.n()); - assert_eq!(b.n(), module.n()); - assert_eq!(res.n(), module.n()); + assert_eq!(a.n(), res.n()); + assert_eq!(b.n(), res.n()); assert_ne!(a.as_ptr(), b.as_ptr()); } unsafe { @@ -244,8 +243,7 @@ unsafe impl VecZnxBigAddInplaceImpl for FFT64 { #[cfg(debug_assertions)] { - assert_eq!(a.n(), module.n()); - assert_eq!(res.n(), module.n()); + assert_eq!(a.n(), res.n()); } unsafe { vec_znx::vec_znx_add( @@ -285,9 +283,8 @@ unsafe impl VecZnxBigAddSmallImpl for FFT64 { #[cfg(debug_assertions)] { - assert_eq!(a.n(), module.n()); - assert_eq!(b.n(), module.n()); - assert_eq!(res.n(), module.n()); + assert_eq!(a.n(), res.n()); + assert_eq!(b.n(), res.n()); assert_ne!(a.as_ptr(), b.as_ptr()); } unsafe { @@ -319,8 +316,7 @@ unsafe impl VecZnxBigAddSmallInplaceImpl for FFT64 { #[cfg(debug_assertions)] { - assert_eq!(a.n(), module.n()); - assert_eq!(res.n(), module.n()); + assert_eq!(a.n(), res.n()); } unsafe { vec_znx::vec_znx_add( @@ -360,9 +356,8 @@ unsafe impl VecZnxBigSubImpl for FFT64 { #[cfg(debug_assertions)] { - assert_eq!(a.n(), module.n()); - assert_eq!(b.n(), module.n()); - assert_eq!(res.n(), module.n()); + assert_eq!(a.n(), res.n()); + assert_eq!(b.n(), res.n()); assert_ne!(a.as_ptr(), b.as_ptr()); } unsafe { @@ -394,8 +389,7 @@ unsafe impl VecZnxBigSubABInplaceImpl for FFT64 { #[cfg(debug_assertions)] { - assert_eq!(a.n(), module.n()); - assert_eq!(res.n(), module.n()); + assert_eq!(a.n(), res.n()); } unsafe { vec_znx::vec_znx_sub( @@ -426,8 +420,7 @@ unsafe impl VecZnxBigSubBAInplaceImpl for FFT64 { #[cfg(debug_assertions)] { - assert_eq!(a.n(), module.n()); - assert_eq!(res.n(), module.n()); + assert_eq!(a.n(), res.n()); } unsafe { vec_znx::vec_znx_sub( @@ -467,9 +460,8 @@ unsafe impl VecZnxBigSubSmallAImpl for FFT64 { #[cfg(debug_assertions)] { - assert_eq!(a.n(), module.n()); - assert_eq!(b.n(), module.n()); - assert_eq!(res.n(), module.n()); + assert_eq!(a.n(), res.n()); + assert_eq!(b.n(), res.n()); assert_ne!(a.as_ptr(), b.as_ptr()); } unsafe { @@ -501,8 +493,7 @@ unsafe impl VecZnxBigSubSmallAInplaceImpl for FFT64 { #[cfg(debug_assertions)] { - assert_eq!(a.n(), module.n()); - assert_eq!(res.n(), module.n()); + assert_eq!(a.n(), res.n()); } unsafe { vec_znx::vec_znx_sub( @@ -542,9 +533,8 @@ unsafe impl VecZnxBigSubSmallBImpl for FFT64 { #[cfg(debug_assertions)] { - assert_eq!(a.n(), module.n()); - assert_eq!(b.n(), module.n()); - assert_eq!(res.n(), module.n()); + assert_eq!(a.n(), res.n()); + assert_eq!(b.n(), res.n()); assert_ne!(a.as_ptr(), b.as_ptr()); } unsafe { @@ -576,8 +566,7 @@ unsafe impl VecZnxBigSubSmallBInplaceImpl for FFT64 { #[cfg(debug_assertions)] { - assert_eq!(a.n(), module.n()); - assert_eq!(res.n(), module.n()); + assert_eq!(a.n(), res.n()); } unsafe { vec_znx::vec_znx_sub( @@ -602,10 +591,6 @@ unsafe impl VecZnxBigNegateInplaceImpl for FFT64 { A: VecZnxBigToMut, { let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), module.n()); - } unsafe { vec_znx::vec_znx_negate( module.ptr(), @@ -677,8 +662,7 @@ unsafe impl VecZnxBigAutomorphismImpl for FFT64 { #[cfg(debug_assertions)] { - assert_eq!(a.n(), module.n()); - assert_eq!(res.n(), module.n()); + assert_eq!(a.n(), res.n()); } unsafe { vec_znx::vec_znx_automorphism( @@ -702,11 +686,6 @@ unsafe impl VecZnxBigAutomorphismInplaceImpl for FFT64 { A: VecZnxBigToMut, { let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut(); - - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), module.n()); - } unsafe { vec_znx::vec_znx_automorphism( module.ptr(), diff --git a/backend/src/implementation/cpu_spqlios/vec_znx_dft_fft64.rs b/backend/src/implementation/cpu_spqlios/vec_znx_dft_fft64.rs index 2d9559b..768e5c4 100644 --- a/backend/src/implementation/cpu_spqlios/vec_znx_dft_fft64.rs +++ b/backend/src/implementation/cpu_spqlios/vec_znx_dft_fft64.rs @@ -57,8 +57,8 @@ unsafe impl VecZnxDftAllocImpl for FFT64 { } unsafe impl VecZnxDftToVecZnxBigTmpBytesImpl for FFT64 { - fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module) -> usize { - unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(module.ptr()) as usize } + fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module, n: usize) -> usize { + unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(module.ptr(), n as u64) as usize } } } @@ -74,26 +74,31 @@ unsafe impl VecZnxDftToVecZnxBigImpl for FFT64 { R: VecZnxBigToMut, A: VecZnxDftToRef, { - let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); - let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + let a: VecZnxDft<&[u8], FFT64> = a.to_ref(); - let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_dft_to_vec_znx_big_tmp_bytes()); + #[cfg(debug_assertions)] + { + assert_eq!(res.n(), a.n()) + } - let min_size: usize = res_mut.size().min(a_ref.size()); + let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_dft_to_vec_znx_big_tmp_bytes(a.n())); + + let min_size: usize = res.size().min(a.size()); unsafe { (0..min_size).for_each(|j| { vec_znx_dft::vec_znx_idft( module.ptr(), - res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t, + res.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t, 1 as u64, - a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, + a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, 1 as u64, tmp_bytes.as_mut_ptr(), ) }); - (min_size..res_mut.size()).for_each(|j| { - res_mut.zero_at(res_col, j); + (min_size..res.size()).for_each(|j| { + res.zero_at(res_col, j); }); } } diff --git a/backend/src/implementation/cpu_spqlios/vmp_pmat_fft64.rs b/backend/src/implementation/cpu_spqlios/vmp_pmat_fft64.rs index 8b1106f..8730880 100644 --- a/backend/src/implementation/cpu_spqlios/vmp_pmat_fft64.rs +++ b/backend/src/implementation/cpu_spqlios/vmp_pmat_fft64.rs @@ -57,10 +57,18 @@ unsafe impl VmpPMatAllocImpl for FFT64 { } unsafe impl VmpPrepareTmpBytesImpl for FFT64 { - fn vmp_prepare_tmp_bytes_impl(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + fn vmp_prepare_tmp_bytes_impl( + module: &Module, + n: usize, + rows: usize, + cols_in: usize, + cols_out: usize, + size: usize, + ) -> usize { unsafe { vmp::vmp_prepare_tmp_bytes( module.ptr(), + n as u64, (rows * cols_in) as u64, (cols_out * size) as u64, ) as usize @@ -79,8 +87,7 @@ unsafe impl VmpPMatPrepareImpl for FFT64 { #[cfg(debug_assertions)] { - assert_eq!(res.n(), module.n()); - assert_eq!(a.n(), module.n()); + assert_eq!(a.n(), res.n()); assert_eq!( res.cols_in(), a.cols_in(), @@ -111,7 +118,8 @@ unsafe impl VmpPMatPrepareImpl for FFT64 { ); } - let (tmp_bytes, _) = scratch.take_slice(module.vmp_prepare_tmp_bytes(a.rows(), a.cols_in(), a.cols_out(), a.size())); + let (tmp_bytes, _) = + scratch.take_slice(module.vmp_prepare_tmp_bytes(res.n(), a.rows(), a.cols_in(), a.cols_out(), a.size())); unsafe { vmp::vmp_prepare_contiguous( @@ -129,6 +137,7 @@ unsafe impl VmpPMatPrepareImpl for FFT64 { unsafe impl VmpApplyTmpBytesImpl for FFT64 { fn vmp_apply_tmp_bytes_impl( module: &Module, + n: usize, res_size: usize, a_size: usize, b_rows: usize, @@ -139,6 +148,7 @@ unsafe impl VmpApplyTmpBytesImpl for FFT64 { unsafe { vmp::vmp_apply_dft_to_dft_tmp_bytes( module.ptr(), + n as u64, (res_size * b_cols_out) as u64, (a_size * b_cols_in) as u64, (b_rows * b_cols_in) as u64, @@ -161,9 +171,8 @@ unsafe impl VmpApplyImpl for FFT64 { #[cfg(debug_assertions)] { - assert_eq!(res.n(), module.n()); - assert_eq!(b.n(), module.n()); - assert_eq!(a.n(), module.n()); + assert_eq!(b.n(), res.n()); + assert_eq!(a.n(), res.n()); assert_eq!( res.cols(), b.cols_out(), @@ -181,6 +190,7 @@ unsafe impl VmpApplyImpl for FFT64 { } let (tmp_bytes, _) = scratch.take_slice(module.vmp_apply_tmp_bytes( + res.n(), res.size(), a.size(), b.rows(), @@ -207,6 +217,7 @@ unsafe impl VmpApplyImpl for FFT64 { unsafe impl VmpApplyAddTmpBytesImpl for FFT64 { fn vmp_apply_add_tmp_bytes_impl( module: &Module, + n: usize, res_size: usize, a_size: usize, b_rows: usize, @@ -217,6 +228,7 @@ unsafe impl VmpApplyAddTmpBytesImpl for FFT64 { unsafe { vmp::vmp_apply_dft_to_dft_tmp_bytes( module.ptr(), + n as u64, (res_size * b_cols_out) as u64, (a_size * b_cols_in) as u64, (b_rows * b_cols_in) as u64, @@ -241,9 +253,8 @@ unsafe impl VmpApplyAddImpl for FFT64 { { use crate::hal::api::ZnxInfos; - assert_eq!(res.n(), module.n()); - assert_eq!(b.n(), module.n()); - assert_eq!(a.n(), module.n()); + assert_eq!(b.n(), res.n()); + assert_eq!(a.n(), res.n()); assert_eq!( res.cols(), b.cols_out(), @@ -261,6 +272,7 @@ unsafe impl VmpApplyAddImpl for FFT64 { } let (tmp_bytes, _) = scratch.take_slice(module.vmp_apply_tmp_bytes( + res.n(), res.size(), a.size(), b.rows(), diff --git a/core/benches/external_product_glwe_fft64.rs b/core/benches/external_product_glwe_fft64.rs index 6fa0f0d..b0c9c1b 100644 --- a/core/benches/external_product_glwe_fft64.rs +++ b/core/benches/external_product_glwe_fft64.rs @@ -3,7 +3,7 @@ use std::hint::black_box; use backend::{ hal::{ - api::{ModuleNew, ScalarZnxAlloc, ScratchOwnedAlloc, ScratchOwnedBorrow}, + api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow}, layouts::{Module, ScalarZnx, ScratchOwned}, }, implementation::cpu_spqlios::FFT64, @@ -26,6 +26,7 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { fn runner(p: Params) -> impl FnMut() { let module: Module = Module::::new(1 << p.log_n); + let n: usize = module.n(); let basek: usize = p.basek; let k_ct_in: usize = p.k_ct_in; let k_ct_out: usize = p.k_ct_out; @@ -36,16 +37,17 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { let rows: usize = 1; //(p.k_ct_in.div_ceil(p.basek); let sigma: f64 = 3.2; - let mut ct_ggsw: GGSWCiphertext> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); - let mut ct_glwe_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct_in, rank); - let mut ct_glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct_out, rank); - let pt_rgsw: ScalarZnx> = module.scalar_znx_alloc(1); + let mut ct_ggsw: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ggsw, rows, digits, rank); + let mut ct_glwe_in: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct_in, rank); + let mut ct_glwe_out: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct_out, rank); + let pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n, 1); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe_in.k()) + GGSWCiphertext::encrypt_sk_scratch_space(&module, n, basek, ct_ggsw.k(), rank) + | GLWECiphertext::encrypt_sk_scratch_space(&module, n, basek, ct_glwe_in.k()) | GLWECiphertext::external_product_scratch_space( &module, + n, basek, ct_glwe_out.k(), ct_glwe_in.k(), @@ -59,7 +61,7 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { let mut source_xe = Source::new([0u8; 32]); let mut source_xa = Source::new([0u8; 32]); - let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_dft: GLWESecretExec, FFT64> = GLWESecretExec::from(&module, &sk); @@ -121,6 +123,7 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { fn runner(p: Params) -> impl FnMut() { let module: Module = Module::::new(1 << p.log_n); + let n = module.n(); let basek: usize = p.basek; let k_glwe: usize = p.k_ct; let k_ggsw: usize = p.k_ggsw; @@ -130,21 +133,29 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { let rows: usize = p.k_ct.div_ceil(p.basek); let sigma: f64 = 3.2; - let mut ct_ggsw: GGSWCiphertext> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); - let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_glwe, rank); - let pt_rgsw: ScalarZnx> = module.scalar_znx_alloc(1); + let mut ct_ggsw: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ggsw, rows, digits, rank); + let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_glwe, rank); + let pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n, 1); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe.k()) - | GLWECiphertext::external_product_inplace_scratch_space(&module, basek, ct_glwe.k(), ct_ggsw.k(), digits, rank), + GGSWCiphertext::encrypt_sk_scratch_space(&module, n, basek, ct_ggsw.k(), rank) + | GLWECiphertext::encrypt_sk_scratch_space(&module, n, basek, ct_glwe.k()) + | GLWECiphertext::external_product_inplace_scratch_space( + &module, + n, + basek, + ct_glwe.k(), + ct_ggsw.k(), + digits, + rank, + ), ); let mut source_xs = Source::new([0u8; 32]); let mut source_xe = Source::new([0u8; 32]); let mut source_xa = Source::new([0u8; 32]); - let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_dft: GLWESecretExec, FFT64> = GLWESecretExec::from(&module, &sk); diff --git a/core/benches/keyswitch_glwe_fft64.rs b/core/benches/keyswitch_glwe_fft64.rs index 696615d..2a41836 100644 --- a/core/benches/keyswitch_glwe_fft64.rs +++ b/core/benches/keyswitch_glwe_fft64.rs @@ -31,6 +31,7 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { fn runner(p: Params) -> impl FnMut() { let module: Module = Module::::new(1 << p.log_n); + let n = module.n(); let basek: usize = p.basek; let k_rlwe_in: usize = p.k_ct_in; let k_rlwe_out: usize = p.k_ct_out; @@ -42,15 +43,16 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { let rows: usize = p.k_ct_in.div_ceil(p.basek * digits); let sigma: f64 = 3.2; - let mut ksk: AutomorphismKey> = AutomorphismKey::alloc(&module, basek, k_grlwe, rows, digits, rank_out); - let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_rlwe_in, rank_in); - let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_rlwe_out, rank_out); + let mut ksk: AutomorphismKey> = AutomorphismKey::alloc(n, basek, k_grlwe, rows, digits, rank_out); + let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_rlwe_in, rank_in); + let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_rlwe_out, rank_out); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank_in, rank_out) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_in.k()) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, n, basek, ksk.k(), rank_in, rank_out) + | GLWECiphertext::encrypt_sk_scratch_space(&module, n, basek, ct_in.k()) | GLWECiphertext::keyswitch_scratch_space( &module, + n, basek, ct_out.k(), ct_in.k(), @@ -65,11 +67,11 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { let mut source_xe = Source::new([0u8; 32]); let mut source_xa = Source::new([0u8; 32]); - let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank_in); + let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank_in); sk_in.fill_ternary_prob(0.5, &mut source_xs); let sk_in_dft: GLWESecretExec, FFT64> = GLWESecretExec::from(&module, &sk_in); - let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank_out); + let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank_out); sk_out.fill_ternary_prob(0.5, &mut source_xs); ksk.encrypt_sk( @@ -137,6 +139,7 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { fn runner(p: Params) -> impl FnMut() { let module: Module = Module::::new(1 << p.log_n); + let n = module.n(); let basek: usize = p.basek; let k_ct: usize = p.k_ct; let k_ksk: usize = p.k_ksk; @@ -146,24 +149,24 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { let rows: usize = p.k_ct.div_ceil(p.basek); let sigma: f64 = 3.2; - let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank, rank); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); + let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc(n, basek, k_ksk, rows, digits, rank, rank); + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank, rank) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) - | GLWECiphertext::keyswitch_inplace_scratch_space(&module, basek, ct.k(), ksk.k(), digits, rank), + GLWESwitchingKey::encrypt_sk_scratch_space(&module, n, basek, ksk.k(), rank, rank) + | GLWECiphertext::encrypt_sk_scratch_space(&module, n, basek, ct.k()) + | GLWECiphertext::keyswitch_inplace_scratch_space(&module, n, basek, ct.k(), ksk.k(), digits, rank), ); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank); + let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank); sk_in.fill_ternary_prob(0.5, &mut source_xs); let sk_in_dft: GLWESecretExec, FFT64> = GLWESecretExec::from(&module, &sk_in); - let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank); + let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank); sk_out.fill_ternary_prob(0.5, &mut source_xs); ksk.encrypt_sk( diff --git a/core/src/blind_rotation/cggi.rs b/core/src/blind_rotation/cggi.rs index 796daef..644fe65 100644 --- a/core/src/blind_rotation/cggi.rs +++ b/core/src/blind_rotation/cggi.rs @@ -1,20 +1,20 @@ use backend::hal::{ api::{ ScratchAvailable, SvpApply, SvpPPolAllocBytes, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, - TakeVecZnxSlice, VecZnxAddInplace, VecZnxAllocBytes, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, - VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, - VecZnxDftSubABInplace, VecZnxDftToVecZnxBig, VecZnxDftToVecZnxBigTmpBytes, VecZnxDftZero, VecZnxMulXpMinusOneInplace, - VecZnxNormalize, VecZnxNormalizeInplace, VecZnxRotate, VecZnxSubABInplace, VmpApplyTmpBytes, ZnxView, ZnxZero, + TakeVecZnxSlice, VecZnxAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalizeTmpBytes, VecZnxCopy, + VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftSubABInplace, VecZnxDftToVecZnxBig, + VecZnxDftToVecZnxBigTmpBytes, VecZnxDftZero, VecZnxMulXpMinusOneInplace, VecZnxNormalize, VecZnxNormalizeInplace, + VecZnxRotate, VecZnxSubABInplace, VmpApplyTmpBytes, ZnxView, ZnxZero, }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, SvpPPol}, + layouts::{Backend, DataMut, DataRef, Module, Scratch, SvpPPol, VecZnx}, }; use itertools::izip; use crate::{ - GLWECiphertext, GLWECiphertextToMut, GLWEExternalProductFamily, GLWEOps, Infos, LWECiphertext, TakeGLWECt, + GLWECiphertext, GLWECiphertextToMut, GLWEExternalProductFamily, GLWEOps, Infos, LWECiphertext, LWECiphertextToRef, + TakeGLWECt, blind_rotation::{key::BlindRotationKeyCGGIExec, lut::LookUpTable}, dist::Distribution, - lwe::ciphertext::LWECiphertextToRef, }; pub trait CCGIBlindRotationFamily = VecZnxBigAllocBytes @@ -42,6 +42,7 @@ pub trait CCGIBlindRotationFamily = VecZnxBigAllocBytes pub fn cggi_blind_rotate_scratch_space( module: &Module, + n: usize, block_size: usize, extension_factor: usize, basek: usize, @@ -51,22 +52,22 @@ pub fn cggi_blind_rotate_scratch_space( rank: usize, ) -> usize where - Module: CCGIBlindRotationFamily + VecZnxAllocBytes, + Module: CCGIBlindRotationFamily, { let brk_size: usize = k_brk.div_ceil(basek); if block_size > 1 { let cols: usize = rank + 1; - let acc_dft: usize = module.vec_znx_dft_alloc_bytes(cols, rows) * extension_factor; - let acc_big: usize = module.vec_znx_big_alloc_bytes(1, brk_size); - let vmp_res: usize = module.vec_znx_dft_alloc_bytes(cols, brk_size) * extension_factor; - let vmp_xai: usize = module.vec_znx_dft_alloc_bytes(1, brk_size); + let acc_dft: usize = module.vec_znx_dft_alloc_bytes(n, cols, rows) * extension_factor; + let acc_big: usize = module.vec_znx_big_alloc_bytes(n, 1, brk_size); + let vmp_res: usize = module.vec_znx_dft_alloc_bytes(n, cols, brk_size) * extension_factor; + let vmp_xai: usize = module.vec_znx_dft_alloc_bytes(n, 1, brk_size); let acc_dft_add: usize = vmp_res; - let vmp: usize = module.vmp_apply_tmp_bytes(brk_size, rows, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2) + let vmp: usize = module.vmp_apply_tmp_bytes(n, brk_size, rows, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2) let acc: usize; if extension_factor > 1 { - acc = module.vec_znx_alloc_bytes(cols, k_res.div_ceil(basek)) * extension_factor; + acc = VecZnx::alloc_bytes(n, cols, k_res.div_ceil(basek)) * extension_factor; } else { acc = 0; } @@ -76,12 +77,10 @@ where + acc_dft_add + vmp_res + vmp_xai - + (vmp - | (acc_big - + (module.vec_znx_big_normalize_tmp_bytes(module.n()) | module.vec_znx_dft_to_vec_znx_big_tmp_bytes()))); + + (vmp | (acc_big + (module.vec_znx_big_normalize_tmp_bytes(n) | module.vec_znx_dft_to_vec_znx_big_tmp_bytes(n)))); } else { - GLWECiphertext::bytes_of(module, basek, k_res, rank) - + GLWECiphertext::external_product_scratch_space(module, basek, k_res, k_res, k_brk, 1, rank) + GLWECiphertext::bytes_of(n, basek, k_res, rank) + + GLWECiphertext::external_product_scratch_space(module, n, basek, k_res, k_res, k_brk, 1, rank) } } @@ -97,8 +96,7 @@ pub fn cggi_blind_rotate( DataIn: DataRef, DataBrk: DataRef, Module: CCGIBlindRotationFamily, - Scratch: - TakeVecZnxDftSlice + TakeVecZnxDft + TakeVecZnxBig + TakeVecZnx + ScratchAvailable + TakeVecZnxSlice, + Scratch: TakeVecZnxDftSlice + TakeVecZnxDft + TakeVecZnxBig + TakeVecZnx + ScratchAvailable + TakeVecZnxSlice, { match brk.dist { Distribution::BinaryBlock(_) | Distribution::BinaryFixed(_) | Distribution::BinaryProb(_) | Distribution::ZERO => { @@ -129,18 +127,19 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended: CCGIBlindRotationFamily, - Scratch: TakeVecZnxDftSlice + TakeVecZnxDft + TakeVecZnxBig + TakeVecZnxSlice, + Scratch: TakeVecZnxDftSlice + TakeVecZnxDft + TakeVecZnxBig + TakeVecZnxSlice, { + let n_glwe: usize = brk.n(); let extension_factor: usize = lut.extension_factor(); let basek: usize = res.basek(); let rows: usize = brk.rows(); let cols: usize = res.rank() + 1; - let (mut acc, scratch1) = scratch.take_vec_znx_slice(extension_factor, module, cols, res.size()); - let (mut acc_dft, scratch2) = scratch1.take_vec_znx_dft_slice(extension_factor, module, cols, rows); - let (mut vmp_res, scratch3) = scratch2.take_vec_znx_dft_slice(extension_factor, module, cols, brk.size()); - let (mut acc_add_dft, scratch4) = scratch3.take_vec_znx_dft_slice(extension_factor, module, cols, brk.size()); - let (mut vmp_xai, scratch5) = scratch4.take_vec_znx_dft(module, 1, brk.size()); + let (mut acc, scratch1) = scratch.take_vec_znx_slice(extension_factor, n_glwe, cols, res.size()); + let (mut acc_dft, scratch2) = scratch1.take_vec_znx_dft_slice(extension_factor, n_glwe, cols, rows); + let (mut vmp_res, scratch3) = scratch2.take_vec_znx_dft_slice(extension_factor, n_glwe, cols, brk.size()); + let (mut acc_add_dft, scratch4) = scratch3.take_vec_znx_dft_slice(extension_factor, n_glwe, cols, brk.size()); + let (mut vmp_xai, scratch5) = scratch4.take_vec_znx_dft(n_glwe, 1, brk.size()); (0..extension_factor).for_each(|i| { acc[i].zero(); @@ -156,7 +155,7 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended = vec![0i64; lwe.n() + 1]; // TODO: from scratch space let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref(); - let two_n: usize = 2 * module.n(); + let two_n: usize = 2 * n_glwe; let two_n_ext: usize = 2 * lut.domain_size(); negate_and_mod_switch_2n(two_n_ext, &mut lwe_2n, &lwe_ref); @@ -244,7 +243,7 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended: CCGIBlindRotationFamily, Scratch: TakeVecZnxDft + TakeVecZnxBig, { + let n_glwe: usize = brk.n(); let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut(); let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref(); - let two_n: usize = module.n() << 1; + let two_n: usize = n_glwe << 1; let basek: usize = brk.basek(); - let rows = brk.rows(); + let rows: usize = brk.rows(); let cols: usize = out_mut.rank() + 1; @@ -298,10 +298,10 @@ pub(crate) fn cggi_blind_rotate_block_binary, B>>; if let Some(b) = &brk.x_pow_a { @@ -336,7 +336,7 @@ pub(crate) fn cggi_blind_rotate_block_binary: CCGIBlindRotationFamily, - Scratch: TakeVecZnxDft + TakeVecZnxBig + TakeVecZnx + ScratchAvailable, + Scratch: TakeVecZnxDft + TakeVecZnxBig + TakeVecZnx + ScratchAvailable, { #[cfg(debug_assertions)] { assert_eq!( res.n(), - module.n(), + brk.n(), "res.n(): {} != brk.n(): {}", res.n(), - module.n() + brk.n() ); assert_eq!( lut.domain_size(), - module.n(), + brk.n(), "lut.n(): {} != brk.n(): {}", lut.domain_size(), - module.n() - ); - assert_eq!( - brk.n(), - module.n(), - "brk.n(): {} != brk.n(): {}", - brk.n(), - module.n() + brk.n() ); assert_eq!( res.rank(), @@ -416,7 +409,7 @@ pub(crate) fn cggi_blind_rotate_binary_standard WriterTo for BlindRotationKeyCGGI { } impl BlindRotationKeyCGGI> { - pub fn alloc(module: &Module, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self - where - Module: MatZnxAlloc, - { + pub fn alloc(n_gglwe: usize, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self { let mut data: Vec>> = Vec::with_capacity(n_lwe); - (0..n_lwe).for_each(|_| data.push(GGSWCiphertext::alloc(module, basek, k, rows, 1, rank))); + (0..n_lwe).for_each(|_| data.push(GGSWCiphertext::alloc(n_gglwe, basek, k, rows, 1, rank))); Self { keys: data, dist: Distribution::NONE, } } - pub fn generate_from_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize + pub fn generate_from_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize, rank: usize) -> usize where - Module: GGSWEncryptSkFamily + VecZnxAllocBytes, + Module: GGSWEncryptSkFamily, { - GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k, rank) + GGSWCiphertext::encrypt_sk_scratch_space(module, n, basek, k, rank) } } @@ -141,13 +138,13 @@ impl BlindRotationKeyCGGI { ) where DataSkGLWE: DataRef, DataSkLWE: DataRef, - Module: GGSWEncryptSkFamily + ScalarZnxAlloc + VecZnxAddScalarInplace, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Module: GGSWEncryptSkFamily + VecZnxAddScalarInplace, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { #[cfg(debug_assertions)] { assert_eq!(self.keys.len(), sk_lwe.n()); - assert_eq!(sk_glwe.n(), module.n()); + assert!(sk_glwe.n() <= module.n()); assert_eq!(sk_glwe.rank(), self.keys[0].rank()); match sk_lwe.dist { Distribution::BinaryBlock(_) @@ -162,7 +159,7 @@ impl BlindRotationKeyCGGI { self.dist = sk_lwe.dist; - let mut pt: ScalarZnx> = module.scalar_znx_alloc(1); + let mut pt: ScalarZnx> = ScalarZnx::alloc(sk_glwe.n(), 1); let sk_ref: ScalarZnx<&[u8]> = sk_lwe.data.to_ref(); self.keys.iter_mut().enumerate().for_each(|(i, ggsw)| { @@ -220,12 +217,16 @@ impl BlindRotationKeyCGGIExec { pub trait BlindRotationKeyCGGIExecLayoutFamily = GGSWLayoutFamily + SvpPPolAlloc + SvpPrepare; impl BlindRotationKeyCGGIExec, B> { - pub fn alloc(module: &Module, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self + pub fn alloc(module: &Module, n_glwe: usize, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self where Module: BlindRotationKeyCGGIExecLayoutFamily, { let mut data: Vec, B>> = Vec::with_capacity(n_lwe); - (0..n_lwe).for_each(|_| data.push(GGSWCiphertextExec::alloc(module, basek, k, rows, 1, rank))); + (0..n_lwe).for_each(|_| { + data.push(GGSWCiphertextExec::alloc( + module, n_glwe, basek, k, rows, 1, rank, + )) + }); Self { data, dist: Distribution::NONE, @@ -236,10 +237,11 @@ impl BlindRotationKeyCGGIExec, B> { pub fn from(module: &Module, other: &BlindRotationKeyCGGI, scratch: &mut Scratch) -> Self where DataOther: DataRef, - Module: BlindRotationKeyCGGIExecLayoutFamily + ScalarZnxAlloc, + Module: BlindRotationKeyCGGIExecLayoutFamily, { let mut brk: BlindRotationKeyCGGIExec, B> = Self::alloc( module, + other.n(), other.keys.len(), other.basek(), other.k(), @@ -255,13 +257,15 @@ impl BlindRotationKeyCGGIExec { pub fn prepare(&mut self, module: &Module, other: &BlindRotationKeyCGGI, scratch: &mut Scratch) where DataOther: DataRef, - Module: BlindRotationKeyCGGIExecLayoutFamily + ScalarZnxAlloc, + Module: BlindRotationKeyCGGIExecLayoutFamily, { #[cfg(debug_assertions)] { assert_eq!(self.data.len(), other.keys.len()); } + let n: usize = other.n(); + self.data .iter_mut() .zip(other.keys.iter()) @@ -273,10 +277,10 @@ impl BlindRotationKeyCGGIExec { match other.dist { Distribution::BinaryBlock(_) => { - let mut x_pow_a: Vec, B>> = Vec::with_capacity(module.n() << 1); - let mut buf: ScalarZnx> = module.scalar_znx_alloc(1); - (0..module.n() << 1).for_each(|i| { - let mut res: SvpPPol, B> = module.svp_ppol_alloc(1); + let mut x_pow_a: Vec, B>> = Vec::with_capacity(n << 1); + let mut buf: ScalarZnx> = ScalarZnx::alloc(n, 1); + (0..n << 1).for_each(|i| { + let mut res: SvpPPol, B> = module.svp_ppol_alloc(n, 1); set_xai_plus_y(module, i, 0, &mut res, &mut buf); x_pow_a.push(res); }); @@ -293,7 +297,7 @@ where C: DataMut, Module: SvpPrepare, { - let n: usize = module.n(); + let n: usize = res.n(); { let raw: &mut [i64] = buf.at_mut(0, 0); diff --git a/core/src/blind_rotation/lut.rs b/core/src/blind_rotation/lut.rs index 7df9150..a06aa46 100644 --- a/core/src/blind_rotation/lut.rs +++ b/core/src/blind_rotation/lut.rs @@ -1,7 +1,7 @@ use backend::hal::{ api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAlloc, VecZnxCopy, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, - VecZnxRotateInplace, VecZnxSwithcDegree, ZnxInfos, ZnxViewMut, + ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxCopy, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotateInplace, + VecZnxSwithcDegree, ZnxInfos, ZnxViewMut, }, layouts::{Backend, Module, ScratchOwned, VecZnx}, oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl}, @@ -14,10 +14,7 @@ pub struct LookUpTable { } impl LookUpTable { - pub fn alloc(module: &Module, basek: usize, k: usize, extension_factor: usize) -> Self - where - Module: VecZnxAlloc, - { + pub fn alloc(n: usize, basek: usize, k: usize, extension_factor: usize) -> Self { #[cfg(debug_assertions)] { assert!( @@ -29,7 +26,7 @@ impl LookUpTable { let size: usize = k.div_ceil(basek); let mut data: Vec>> = Vec::with_capacity(extension_factor); (0..extension_factor).for_each(|_| { - data.push(module.vec_znx_alloc(1, size)); + data.push(VecZnx::alloc(n, 1, size)); }); Self { data, basek, k } } @@ -69,13 +66,13 @@ impl LookUpTable { // #elements in lookup table let f_len: usize = f.len(); - // If LUT size > module.n() + // If LUT size > TakeScalarZnx let domain_size: usize = self.domain_size(); let size: usize = self.k.div_ceil(self.basek); // Equivalent to AUTO([f(0), -f(n-1), -f(n-2), ..., -f(1)], -1) - let mut lut_full: VecZnx> = VecZnx::alloc::(domain_size, 1, size); + let mut lut_full: VecZnx> = VecZnx::alloc(domain_size, 1, size); let lut_at: &mut [i64] = lut_full.at_mut(0, limbs - 1); diff --git a/core/src/blind_rotation/mod.rs b/core/src/blind_rotation/mod.rs index 6a454e6..9a31472 100644 --- a/core/src/blind_rotation/mod.rs +++ b/core/src/blind_rotation/mod.rs @@ -1,10 +1,10 @@ -pub mod cggi; -pub mod key; -pub mod lut; +mod cggi; +mod key; +mod lut; -pub use cggi::{CCGIBlindRotationFamily, cggi_blind_rotate, cggi_blind_rotate_scratch_space}; -pub use key::{BlindRotationKeyCGGI, BlindRotationKeyCGGIExec, BlindRotationKeyCGGIExecLayoutFamily}; -pub use lut::LookUpTable; +pub use cggi::*; +pub use key::*; +pub use lut::*; #[cfg(test)] -mod test; +mod tests; diff --git a/core/src/blind_rotation/test/mod.rs b/core/src/blind_rotation/test/mod.rs deleted file mode 100644 index 18ac93c..0000000 --- a/core/src/blind_rotation/test/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod cggi; -pub mod lut; diff --git a/core/src/blind_rotation/tests/cpu_spqlios/fft64.rs b/core/src/blind_rotation/tests/cpu_spqlios/fft64.rs new file mode 100644 index 0000000..f3ad51c --- /dev/null +++ b/core/src/blind_rotation/tests/cpu_spqlios/fft64.rs @@ -0,0 +1,39 @@ +use backend::{ + hal::{api::ModuleNew, layouts::Module}, + implementation::cpu_spqlios::FFT64, +}; + +use crate::blind_rotation::tests::{ + generic_cggi::blind_rotatio_test, + generic_lut::{test_lut_extended, test_lut_standard}, +}; + +#[test] +fn lut_standard() { + let module: Module = Module::::new(32); + test_lut_standard(&module); +} + +#[test] +fn lut_extended() { + let module: Module = Module::::new(32); + test_lut_extended(&module); +} + +#[test] +fn standard() { + let module: Module = Module::::new(512); + blind_rotatio_test(&module, 224, 1, 1); +} + +#[test] +fn block_binary() { + let module: Module = Module::::new(512); + blind_rotatio_test(&module, 224, 7, 1); +} + +#[test] +fn block_binary_extended() { + let module: Module = Module::::new(512); + blind_rotatio_test(&module, 224, 7, 2); +} diff --git a/core/src/blind_rotation/tests/cpu_spqlios/mod.rs b/core/src/blind_rotation/tests/cpu_spqlios/mod.rs new file mode 100644 index 0000000..aebaafb --- /dev/null +++ b/core/src/blind_rotation/tests/cpu_spqlios/mod.rs @@ -0,0 +1 @@ +mod fft64; diff --git a/core/src/blind_rotation/test/cggi.rs b/core/src/blind_rotation/tests/generic_cggi.rs similarity index 65% rename from core/src/blind_rotation/test/cggi.rs rename to core/src/blind_rotation/tests/generic_cggi.rs index 73f7c57..ffec8af 100644 --- a/core/src/blind_rotation/test/cggi.rs +++ b/core/src/blind_rotation/tests/generic_cggi.rs @@ -1,63 +1,33 @@ -use backend::{ - hal::{ - api::{ - MatZnxAlloc, ModuleNew, ScalarZnxAlloc, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddNormal, - VecZnxAddScalarInplace, VecZnxAlloc, VecZnxAllocBytes, VecZnxEncodeCoeffsi64, VecZnxFillUniform, VecZnxRotateInplace, - VecZnxSub, VecZnxSwithcDegree, ZnxView, - }, - layouts::{Backend, Module, ScratchOwned}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl, - TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl, - }, +use backend::hal::{ + api::{ + ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxEncodeCoeffsi64, VecZnxFillUniform, + VecZnxRotateInplace, VecZnxSub, VecZnxSwithcDegree, ZnxView, + }, + layouts::{Backend, Module, ScratchOwned}, + oep::{ + ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl, + TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl, }, - implementation::cpu_spqlios::FFT64, }; use sampling::source::Source; use crate::{ - BlindRotationKeyCGGIExecLayoutFamily, CCGIBlindRotationFamily, GLWECiphertext, GLWEDecryptFamily, GLWEPlaintext, GLWESecret, - GLWESecretExec, GLWESecretFamily, Infos, LWECiphertext, LWESecret, - blind_rotation::{ - cggi::{cggi_blind_rotate, cggi_blind_rotate_scratch_space, negate_and_mod_switch_2n}, - key::{BlindRotationKeyCGGI, BlindRotationKeyCGGIExec}, - lut::LookUpTable, - }, - lwe::{LWEPlaintext, ciphertext::LWECiphertextToRef}, + BlindRotationKeyCGGI, BlindRotationKeyCGGIExec, BlindRotationKeyCGGIExecLayoutFamily, CCGIBlindRotationFamily, + GLWECiphertext, GLWEDecryptFamily, GLWEPlaintext, GLWESecret, GLWESecretExec, GLWESecretFamily, Infos, LWECiphertext, + LWECiphertextToRef, LWEPlaintext, LWESecret, LookUpTable, cggi_blind_rotate, cggi_blind_rotate_scratch_space, + negate_and_mod_switch_2n, }; -#[test] -fn standard() { - let module: Module = Module::::new(512); - blind_rotatio_test(&module, 224, 1, 1); -} - -#[test] -fn block_binary() { - let module: Module = Module::::new(512); - blind_rotatio_test(&module, 224, 7, 1); -} - -#[test] -fn block_binary_extended() { - let module: Module = Module::::new(512); - blind_rotatio_test(&module, 224, 7, 2); -} - pub(crate) trait CGGITestModuleFamily = CCGIBlindRotationFamily + GLWESecretFamily + GLWEDecryptFamily + BlindRotationKeyCGGIExecLayoutFamily - + VecZnxAlloc - + ScalarZnxAlloc + VecZnxFillUniform + VecZnxAddNormal - + VecZnxAllocBytes + VecZnxAddScalarInplace + VecZnxEncodeCoeffsi64 + VecZnxRotateInplace + VecZnxSwithcDegree - + MatZnxAlloc + VecZnxSub; pub(crate) trait CGGITestScratchFamily = VecZnxDftAllocBytesImpl + VecZnxBigAllocBytesImpl @@ -70,13 +40,13 @@ pub(crate) trait CGGITestScratchFamily = VecZnxDftAllocBytesImpl + TakeVecZnxImpl + TakeVecZnxSliceImpl; -fn blind_rotatio_test(module: &Module, n_lwe: usize, block_size: usize, extension_factor: usize) +pub(crate) fn blind_rotatio_test(module: &Module, n_lwe: usize, block_size: usize, extension_factor: usize) where Module: CGGITestModuleFamily, B: CGGITestScratchFamily, { + let n: usize = module.n(); let basek: usize = 19; - let k_lwe: usize = 24; let k_brk: usize = 3 * basek; let rows_brk: usize = 2; // Ensures first limb is noise-free. @@ -90,7 +60,7 @@ where let mut source_xe: Source = Source::new([2u8; 32]); let mut source_xa: Source = Source::new([1u8; 32]); - let mut sk_glwe: GLWESecret> = GLWESecret::alloc(module, rank); + let mut sk_glwe: GLWESecret> = GLWESecret::alloc(n, rank); sk_glwe.fill_ternary_prob(0.5, &mut source_xs); let sk_glwe_dft: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_glwe); @@ -98,11 +68,12 @@ where sk_lwe.fill_binary_block(block_size, &mut source_xs); let mut scratch: ScratchOwned = ScratchOwned::::alloc(BlindRotationKeyCGGI::generate_from_sk_scratch_space( - module, basek, k_brk, rank, + module, n, basek, k_brk, rank, )); let mut scratch_br: ScratchOwned = ScratchOwned::::alloc(cggi_blind_rotate_scratch_space( module, + n, block_size, extension_factor, basek, @@ -112,7 +83,7 @@ where rank, )); - let mut brk: BlindRotationKeyCGGI> = BlindRotationKeyCGGI::alloc(module, n_lwe, basek, k_brk, rows_brk, rank); + let mut brk: BlindRotationKeyCGGI> = BlindRotationKeyCGGI::alloc(n, n_lwe, basek, k_brk, rows_brk, rank); brk.generate_from_sk( module, @@ -147,16 +118,16 @@ where .enumerate() .for_each(|(i, x)| *x = 2 * (i as i64) + 1); - let mut lut: LookUpTable = LookUpTable::alloc(module, basek, k_lut, extension_factor); + let mut lut: LookUpTable = LookUpTable::alloc(n, basek, k_lut, extension_factor); lut.set(module, &f, message_modulus); - let mut res: GLWECiphertext> = GLWECiphertext::alloc(module, basek, k_res, rank); + let mut res: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_res, rank); let brk_exec: BlindRotationKeyCGGIExec, B> = BlindRotationKeyCGGIExec::from(module, &brk, scratch_br.borrow()); cggi_blind_rotate(module, &mut res, &lwe, &lut, &brk_exec, scratch_br.borrow()); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k_res); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_res); res.decrypt(module, &mut pt_have, &sk_glwe_dft, scratch.borrow()); diff --git a/core/src/blind_rotation/test/lut.rs b/core/src/blind_rotation/tests/generic_lut.rs similarity index 58% rename from core/src/blind_rotation/test/lut.rs rename to core/src/blind_rotation/tests/generic_lut.rs index bd893fc..86263b0 100644 --- a/core/src/blind_rotation/test/lut.rs +++ b/core/src/blind_rotation/tests/generic_lut.rs @@ -1,18 +1,19 @@ use std::vec; -use backend::{ - hal::{ - api::{ModuleNew, ZnxView}, - layouts::Module, - }, - implementation::cpu_spqlios::FFT64, +use backend::hal::{ + api::{VecZnxCopy, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotateInplace, VecZnxSwithcDegree, ZnxView}, + layouts::{Backend, Module}, + oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl}, }; -use crate::blind_rotation::lut::{DivRound, LookUpTable}; +use crate::{DivRound, LookUpTable}; -#[test] -fn standard() { - let module: Module = Module::::new(32); +pub(crate) fn test_lut_standard(module: &Module) +where + Module: VecZnxRotateInplace + VecZnxNormalizeInplace + VecZnxNormalizeTmpBytes + VecZnxSwithcDegree + VecZnxCopy, + B: ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, +{ + let n: usize = module.n(); let basek: usize = 20; let k_lut: usize = 40; let message_modulus: usize = 16; @@ -25,11 +26,11 @@ fn standard() { .enumerate() .for_each(|(i, x)| *x = (i as i64) - 8); - let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, extension_factor); - lut.set(&module, &f, log_scale); + let mut lut: LookUpTable = LookUpTable::alloc(n, basek, k_lut, extension_factor); + lut.set(module, &f, log_scale); let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64; - lut.rotate(&module, half_step); + lut.rotate(module, half_step); let step: usize = lut.domain_size().div_round(message_modulus); @@ -39,14 +40,17 @@ fn standard() { f[i / step] % message_modulus as i64, lut.data[0].raw()[0] / (1 << (log_scale % basek)) as i64 ); - lut.rotate(&module, -1); + lut.rotate(module, -1); }); }); } -#[test] -fn extended() { - let module: Module = Module::::new(32); +pub(crate) fn test_lut_extended(module: &Module) +where + Module: VecZnxRotateInplace + VecZnxNormalizeInplace + VecZnxNormalizeTmpBytes + VecZnxSwithcDegree + VecZnxCopy, + B: ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, +{ + let n: usize = module.n(); let basek: usize = 20; let k_lut: usize = 40; let message_modulus: usize = 16; @@ -59,7 +63,7 @@ fn extended() { .enumerate() .for_each(|(i, x)| *x = (i as i64) - 8); - let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, extension_factor); + let mut lut: LookUpTable = LookUpTable::alloc(n, basek, k_lut, extension_factor); lut.set(&module, &f, log_scale); let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64; diff --git a/core/src/blind_rotation/tests/generics_automorphism_key.rs b/core/src/blind_rotation/tests/generics_automorphism_key.rs new file mode 100644 index 0000000..5677da1 --- /dev/null +++ b/core/src/blind_rotation/tests/generics_automorphism_key.rs @@ -0,0 +1,396 @@ +use backend::hal::{ + api::{ + ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxCopy, + VecZnxStd, VecZnxSubScalarInplace, VecZnxSwithcDegree, + }, + layouts::{Backend, Module, ScratchOwned}, + oep::{ + ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, + TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, + }, +}; +use sampling::source::Source; + +use crate::{ + AutomorphismKey, AutomorphismKeyCompressed, AutomorphismKeyEncryptSkFamily, AutomorphismKeyExec, GGLWEExecLayoutFamily, + GLWEDecryptFamily, GLWEKeyswitchFamily, GLWEPlaintext, GLWESecret, GLWESecretExec, Infos, + noise::log2_std_noise_gglwe_product, +}; + +pub(crate) trait AutomorphismTestModuleFamily = AutomorphismKeyEncryptSkFamily + + GLWEKeyswitchFamily + + VecZnxAutomorphism + + GGLWEExecLayoutFamily + + VecZnxSwithcDegree + + VecZnxAddScalarInplace + + VecZnxAutomorphism + + VecZnxAutomorphismInplace + + GLWEDecryptFamily + + VecZnxSubScalarInplace + + VecZnxStd + + VecZnxCopy; +pub(crate) trait AutomorphismTestScratchFamily = ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxDftImpl + + TakeVecZnxImpl + + TakeSvpPPolImpl + + TakeVecZnxBigImpl; + +pub(crate) fn test_automorphisk_key_encrypt_sk( + module: &Module, + basek: usize, + k_ksk: usize, + digits: usize, + rank: usize, + sigma: f64, +) where + Module: AutomorphismTestModuleFamily, + B: AutomorphismTestScratchFamily, +{ + let n: usize = TakeScalarZnx; + let rows: usize = (k_ksk - digits * basek) / (digits * basek); + + let mut atk: AutomorphismKey> = AutomorphismKey::alloc(n, basek, k_ksk, rows, digits, rank); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(AutomorphismKey::encrypt_sk_scratch_space( + module, n, basek, k_ksk, rank, + )); + + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let p = -5; + + atk.encrypt_sk( + module, + p, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut sk_out: GLWESecret> = sk.clone(); + (0..atk.rank()).for_each(|i| { + module.vec_znx_automorphism( + module.galois_element_inv(p), + &mut sk_out.data.as_vec_znx_mut(), + i, + &sk.data.as_vec_znx(), + i, + ); + }); + let sk_out_exec = GLWESecretExec::from(module, &sk_out); + + atk.key + .key + .assert_noise(module, &sk_out_exec, &sk.data, sigma); +} + +pub(crate) fn test_automorphisk_key_encrypt_sk_compressed( + module: &Module, + basek: usize, + k_ksk: usize, + digits: usize, + rank: usize, + sigma: f64, +) where + Module: AutomorphismTestModuleFamily, + B: AutomorphismTestScratchFamily, +{ + let n: usize = TakeScalarZnx; + let rows: usize = (k_ksk - digits * basek) / (digits * basek); + + let mut atk_compressed: AutomorphismKeyCompressed> = + AutomorphismKeyCompressed::alloc(n, basek, k_ksk, rows, digits, rank); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(AutomorphismKey::encrypt_sk_scratch_space( + module, n, basek, k_ksk, rank, + )); + + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let p = -5; + + let seed_xa: [u8; 32] = [1u8; 32]; + + atk_compressed.encrypt_sk( + module, + p, + &sk, + seed_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut sk_out: GLWESecret> = sk.clone(); + (0..atk_compressed.rank()).for_each(|i| { + module.vec_znx_automorphism( + module.galois_element_inv(p), + &mut sk_out.data.as_vec_znx_mut(), + i, + &sk.data.as_vec_znx(), + i, + ); + }); + let sk_out_exec = GLWESecretExec::from(module, &sk_out); + + let mut atk: AutomorphismKey> = AutomorphismKey::alloc(n, basek, k_ksk, rows, digits, rank); + atk.decompress(module, &atk_compressed); + + atk.key + .key + .assert_noise(module, &sk_out_exec, &sk.data, sigma); +} + +pub(crate) fn test_gglwe_automorphism( + module: &Module, + p0: i64, + p1: i64, + basek: usize, + digits: usize, + k_in: usize, + k_out: usize, + k_apply: usize, + sigma: f64, + rank: usize, +) where + Module: AutomorphismTestModuleFamily, + B: AutomorphismTestScratchFamily, +{ + let n: usize = TakeScalarZnx; + let digits_in: usize = 1; + + let rows_in: usize = k_in / (basek * digits); + let rows_apply: usize = k_in.div_ceil(basek * digits); + + let mut auto_key_in: AutomorphismKey> = AutomorphismKey::alloc(n, basek, k_in, rows_in, digits_in, rank); + let mut auto_key_out: AutomorphismKey> = AutomorphismKey::alloc(n, basek, k_out, rows_in, digits_in, rank); + let mut auto_key_apply: AutomorphismKey> = AutomorphismKey::alloc(n, basek, k_apply, rows_apply, digits, rank); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + AutomorphismKey::encrypt_sk_scratch_space(module, n, basek, k_apply, rank) + | AutomorphismKey::automorphism_scratch_space(module, n, basek, k_out, k_in, k_apply, digits, rank), + ); + + let mut sk: GLWESecret> = GLWESecret::alloc(TakeScalarZnx, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + // gglwe_{s1}(s0) = s0 -> s1 + auto_key_in.encrypt_sk( + module, + p0, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + // gglwe_{s2}(s1) -> s1 -> s2 + auto_key_apply.encrypt_sk( + module, + p1, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut auto_key_apply_exec: AutomorphismKeyExec, B> = + AutomorphismKeyExec::alloc(module, n, basek, k_apply, rows_apply, digits, rank); + + auto_key_apply_exec.prepare(module, &auto_key_apply, scratch.borrow()); + + // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) + auto_key_out.automorphism(module, &auto_key_in, &auto_key_apply_exec, scratch.borrow()); + + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(TakeScalarZnx, basek, k_out); + + let mut sk_auto: GLWESecret> = GLWESecret::alloc(TakeScalarZnx, rank); + sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk + (0..rank).for_each(|i| { + module.vec_znx_automorphism( + module.galois_element_inv(p0 * p1), + &mut sk_auto.data.as_vec_znx_mut(), + i, + &sk.data.as_vec_znx(), + i, + ); + }); + + let sk_auto_dft: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_auto); + + (0..auto_key_out.rank_in()).for_each(|col_i| { + (0..auto_key_out.rows()).for_each(|row_i| { + auto_key_out + .at(row_i, col_i) + .decrypt(module, &mut pt, &sk_auto_dft, scratch.borrow()); + + module.vec_znx_sub_scalar_inplace( + &mut pt.data, + 0, + (digits_in - 1) + row_i * digits_in, + &sk.data, + col_i, + ); + + let noise_have: f64 = module.vec_znx_std(basek, &pt.data, 0).log2(); + let noise_want: f64 = log2_std_noise_gglwe_product( + TakeScalarZnx as f64, + basek * digits, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k_out, + k_apply, + ); + + assert!( + noise_have < noise_want + 0.5, + "{} {}", + noise_have, + noise_want + ); + }); + }); +} + +pub(crate) fn test_gglwe_automorphism_inplace( + module: &Module, + p0: i64, + p1: i64, + basek: usize, + digits: usize, + k_in: usize, + k_apply: usize, + sigma: f64, + rank: usize, +) where + Module: AutomorphismTestModuleFamily, + B: AutomorphismTestScratchFamily, +{ + let n: usize = TakeScalarZnx; + let digits_in: usize = 1; + + let rows_in: usize = k_in / (basek * digits); + let rows_apply: usize = k_in.div_ceil(basek * digits); + + let mut auto_key: AutomorphismKey> = AutomorphismKey::alloc(n, basek, k_in, rows_in, digits_in, rank); + let mut auto_key_apply: AutomorphismKey> = AutomorphismKey::alloc(n, basek, k_apply, rows_apply, digits, rank); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + AutomorphismKey::encrypt_sk_scratch_space(module, n, basek, k_apply, rank) + | AutomorphismKey::automorphism_inplace_scratch_space(module, n, basek, k_in, k_apply, digits, rank), + ); + + let mut sk: GLWESecret> = GLWESecret::alloc(TakeScalarZnx, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + // gglwe_{s1}(s0) = s0 -> s1 + auto_key.encrypt_sk( + module, + p0, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + // gglwe_{s2}(s1) -> s1 -> s2 + auto_key_apply.encrypt_sk( + module, + p1, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut auto_key_apply_exec: AutomorphismKeyExec, B> = + AutomorphismKeyExec::alloc(module, n, basek, k_apply, rows_apply, digits, rank); + + auto_key_apply_exec.prepare(module, &auto_key_apply, scratch.borrow()); + + // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) + auto_key.automorphism_inplace(module, &auto_key_apply_exec, scratch.borrow()); + + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(TakeScalarZnx, basek, k_in); + + let mut sk_auto: GLWESecret> = GLWESecret::alloc(TakeScalarZnx, rank); + sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk + + (0..rank).for_each(|i| { + module.vec_znx_automorphism( + module.galois_element_inv(p0 * p1), + &mut sk_auto.data.as_vec_znx_mut(), + i, + &sk.data.as_vec_znx(), + i, + ); + }); + + let sk_auto_dft: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_auto); + + (0..auto_key.rank_in()).for_each(|col_i| { + (0..auto_key.rows()).for_each(|row_i| { + auto_key + .at(row_i, col_i) + .decrypt(module, &mut pt, &sk_auto_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace( + &mut pt.data, + 0, + (digits_in - 1) + row_i * digits_in, + &sk.data, + col_i, + ); + + let noise_have: f64 = module.vec_znx_std(basek, &pt.data, 0).log2(); + let noise_want: f64 = log2_std_noise_gglwe_product( + TakeScalarZnx as f64, + basek * digits, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k_in, + k_apply, + ); + + assert!( + noise_have < noise_want + 0.5, + "{} {}", + noise_have, + noise_want + ); + }); + }); +} diff --git a/core/src/blind_rotation/tests/key.rs b/core/src/blind_rotation/tests/key.rs new file mode 100644 index 0000000..7ae3f4c --- /dev/null +++ b/core/src/blind_rotation/tests/key.rs @@ -0,0 +1,321 @@ +use backend::hal::{ + api::{ScratchAvailable, SvpPPolAlloc, SvpPrepare, TakeVecZnx, TakeVecZnxDft, VecZnxAddScalarInplace, ZnxView, ZnxViewMut}, + layouts::{Backend, Data, DataMut, DataRef, Module, ReaderFrom, ScalarZnx, ScalarZnxToRef, Scratch, SvpPPol, WriterTo}, +}; +use sampling::source::Source; + +use crate::{ + Distribution, GGSWCiphertext, GGSWCiphertextExec, GGSWEncryptSkFamily, GGSWLayoutFamily, GLWESecretExec, Infos, LWESecret, +}; + +pub struct BlindRotationKeyCGGI { + pub(crate) keys: Vec>, + pub(crate) dist: Distribution, +} + +impl PartialEq for BlindRotationKeyCGGI { + fn eq(&self, other: &Self) -> bool { + if self.keys.len() != other.keys.len() { + return false; + } + for (a, b) in self.keys.iter().zip(other.keys.iter()) { + if a != b { + return false; + } + } + self.dist == other.dist + } +} + +impl Eq for BlindRotationKeyCGGI {} + +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; + +impl ReaderFrom for BlindRotationKeyCGGI { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + match Distribution::read_from(reader) { + Ok(dist) => self.dist = dist, + Err(e) => return Err(e), + } + let len: usize = reader.read_u64::()? as usize; + if self.keys.len() != len { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("self.keys.len()={} != read len={}", self.keys.len(), len), + )); + } + for key in &mut self.keys { + key.read_from(reader)?; + } + Ok(()) + } +} + +impl WriterTo for BlindRotationKeyCGGI { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + match self.dist.write_to(writer) { + Ok(()) => {} + Err(e) => return Err(e), + } + writer.write_u64::(self.keys.len() as u64)?; + for key in &self.keys { + key.write_to(writer)?; + } + Ok(()) + } +} + +impl BlindRotationKeyCGGI> { + pub fn alloc(n_gglwe: usize, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self { + let mut data: Vec>> = Vec::with_capacity(n_lwe); + (0..n_lwe).for_each(|_| data.push(GGSWCiphertext::alloc(n_gglwe, basek, k, rows, 1, rank))); + Self { + keys: data, + dist: Distribution::NONE, + } + } + + pub fn generate_from_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize, rank: usize) -> usize + where + Module: GGSWEncryptSkFamily, + { + GGSWCiphertext::encrypt_sk_scratch_space(module, n, basek, k, rank) + } +} + +impl BlindRotationKeyCGGI { + #[allow(dead_code)] + pub(crate) fn n(&self) -> usize { + self.keys[0].n() + } + + #[allow(dead_code)] + pub(crate) fn rows(&self) -> usize { + self.keys[0].rows() + } + + #[allow(dead_code)] + pub(crate) fn k(&self) -> usize { + self.keys[0].k() + } + + #[allow(dead_code)] + pub(crate) fn size(&self) -> usize { + self.keys[0].size() + } + + #[allow(dead_code)] + pub(crate) fn rank(&self) -> usize { + self.keys[0].rank() + } + + pub(crate) fn basek(&self) -> usize { + self.keys[0].basek() + } + + #[allow(dead_code)] + pub(crate) fn block_size(&self) -> usize { + match self.dist { + Distribution::BinaryBlock(value) => value, + _ => 1, + } + } +} + +impl BlindRotationKeyCGGI { + pub fn generate_from_sk( + &mut self, + module: &Module, + sk_glwe: &GLWESecretExec, + sk_lwe: &LWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + DataSkGLWE: DataRef, + DataSkLWE: DataRef, + Module: GGSWEncryptSkFamily + VecZnxAddScalarInplace, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.keys.len(), sk_lwe.n()); + assert!(sk_glwe.n() <= TakeScalarZnx); + assert_eq!(sk_glwe.rank(), self.keys[0].rank()); + match sk_lwe.dist { + Distribution::BinaryBlock(_) + | Distribution::BinaryFixed(_) + | Distribution::BinaryProb(_) + | Distribution::ZERO => {} + _ => panic!( + "invalid GLWESecret distribution: must be BinaryBlock, BinaryFixed or BinaryProb (or ZERO for debugging)" + ), + } + } + + self.dist = sk_lwe.dist; + + let mut pt: ScalarZnx> = ScalarZnx::alloc(sk_glwe.n(), 1); + let sk_ref: ScalarZnx<&[u8]> = sk_lwe.data.to_ref(); + + self.keys.iter_mut().enumerate().for_each(|(i, ggsw)| { + pt.at_mut(0, 0)[0] = sk_ref.at(0, 0)[i]; + ggsw.encrypt_sk(module, &pt, sk_glwe, source_xa, source_xe, sigma, scratch); + }); + } +} + +#[derive(PartialEq, Eq)] +pub struct BlindRotationKeyCGGIExec { + pub(crate) data: Vec>, + pub(crate) dist: Distribution, + pub(crate) x_pow_a: Option, B>>>, +} + +impl BlindRotationKeyCGGIExec { + #[allow(dead_code)] + pub(crate) fn n(&self) -> usize { + self.data[0].n() + } + + #[allow(dead_code)] + pub(crate) fn rows(&self) -> usize { + self.data[0].rows() + } + + #[allow(dead_code)] + pub(crate) fn k(&self) -> usize { + self.data[0].k() + } + + #[allow(dead_code)] + pub(crate) fn size(&self) -> usize { + self.data[0].size() + } + + #[allow(dead_code)] + pub(crate) fn rank(&self) -> usize { + self.data[0].rank() + } + + pub(crate) fn basek(&self) -> usize { + self.data[0].basek() + } + + pub(crate) fn block_size(&self) -> usize { + match self.dist { + Distribution::BinaryBlock(value) => value, + _ => 1, + } + } +} + +pub trait BlindRotationKeyCGGIExecLayoutFamily = GGSWLayoutFamily + SvpPPolAlloc + SvpPrepare; + +impl BlindRotationKeyCGGIExec, B> { + pub fn alloc(module: &Module, n_glwe: usize, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self + where + Module: BlindRotationKeyCGGIExecLayoutFamily, + { + let mut data: Vec, B>> = Vec::with_capacity(n_lwe); + (0..n_lwe).for_each(|_| { + data.push(GGSWCiphertextExec::alloc( + module, n_glwe, basek, k, rows, 1, rank, + )) + }); + Self { + data, + dist: Distribution::NONE, + x_pow_a: None, + } + } + + pub fn from(module: &Module, other: &BlindRotationKeyCGGI, scratch: &mut Scratch) -> Self + where + DataOther: DataRef, + Module: BlindRotationKeyCGGIExecLayoutFamily, + { + let mut brk: BlindRotationKeyCGGIExec, B> = Self::alloc( + module, + other.n(), + other.keys.len(), + other.basek(), + other.k(), + other.rows(), + other.rank(), + ); + brk.prepare(module, other, scratch); + brk + } +} + +impl BlindRotationKeyCGGIExec { + pub fn prepare(&mut self, module: &Module, other: &BlindRotationKeyCGGI, scratch: &mut Scratch) + where + DataOther: DataRef, + Module: BlindRotationKeyCGGIExecLayoutFamily, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.data.len(), other.keys.len()); + } + + let n: usize = other.n(); + + self.data + .iter_mut() + .zip(other.keys.iter()) + .for_each(|(ggsw_exec, other)| { + ggsw_exec.prepare(module, other, scratch); + }); + + self.dist = other.dist; + + match other.dist { + Distribution::BinaryBlock(_) => { + let mut x_pow_a: Vec, B>> = Vec::with_capacity(n << 1); + let mut buf: ScalarZnx> = ScalarZnx::alloc(n, 1); + (0..n << 1).for_each(|i| { + let mut res: SvpPPol, B> = module.svp_ppol_alloc(n, 1); + set_xai_plus_y(module, i, 0, &mut res, &mut buf); + x_pow_a.push(res); + }); + self.x_pow_a = Some(x_pow_a); + } + _ => {} + } + } +} + +pub fn set_xai_plus_y(module: &Module, ai: usize, y: i64, res: &mut SvpPPol, buf: &mut ScalarZnx) +where + A: DataMut, + C: DataMut, + Module: SvpPrepare, +{ + let n: usize = res.n(); + + { + let raw: &mut [i64] = buf.at_mut(0, 0); + if ai < n { + raw[ai] = 1; + } else { + raw[(ai - n) & (n - 1)] = -1; + } + raw[0] += y; + } + + module.svp_prepare(res, 0, buf, 0); + + { + let raw: &mut [i64] = buf.at_mut(0, 0); + + if ai < n { + raw[ai] = 0; + } else { + raw[(ai - n) & (n - 1)] = 0; + } + raw[0] = 0; + } +} diff --git a/core/src/blind_rotation/tests/mod.rs b/core/src/blind_rotation/tests/mod.rs new file mode 100644 index 0000000..ed5e334 --- /dev/null +++ b/core/src/blind_rotation/tests/mod.rs @@ -0,0 +1,3 @@ +mod cpu_spqlios; +mod generic_cggi; +mod generic_lut; diff --git a/core/src/gglwe/automorphism.rs b/core/src/gglwe/automorphism.rs index 06c4f63..19b7e88 100644 --- a/core/src/gglwe/automorphism.rs +++ b/core/src/gglwe/automorphism.rs @@ -8,6 +8,7 @@ use crate::{AutomorphismKey, AutomorphismKeyExec, GLWECiphertext, GLWEKeyswitchF impl AutomorphismKey> { pub fn automorphism_scratch_space( module: &Module, + n: usize, basek: usize, k_out: usize, k_in: usize, @@ -18,11 +19,12 @@ impl AutomorphismKey> { where Module: GLWEKeyswitchFamily, { - GLWECiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank) + GLWECiphertext::keyswitch_scratch_space(module, n, basek, k_out, k_in, k_ksk, digits, rank, rank) } pub fn automorphism_inplace_scratch_space( module: &Module, + n: usize, basek: usize, k_out: usize, k_ksk: usize, @@ -32,7 +34,7 @@ impl AutomorphismKey> { where Module: GLWEKeyswitchFamily, { - AutomorphismKey::automorphism_scratch_space(module, basek, k_out, k_out, k_ksk, digits, rank) + AutomorphismKey::automorphism_scratch_space(module, n, basek, k_out, k_out, k_ksk, digits, rank) } } diff --git a/core/src/gglwe/encryption.rs b/core/src/gglwe/encryption.rs index 9754002..f4cc8de 100644 --- a/core/src/gglwe/encryption.rs +++ b/core/src/gglwe/encryption.rs @@ -1,8 +1,8 @@ use backend::hal::{ api::{ - ScalarZnxAllocBytes, ScratchAvailable, SvpApply, TakeScalarZnx, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, - VecZnxAddScalarInplace, VecZnxAllocBytes, VecZnxAutomorphism, VecZnxBigAllocBytes, VecZnxDftToVecZnxBigTmpA, - VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSwithcDegree, ZnxZero, + ScratchAvailable, SvpApply, TakeScalarZnx, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddScalarInplace, + VecZnxAutomorphism, VecZnxBigAllocBytes, VecZnxDftToVecZnxBigTmpA, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, + VecZnxSwithcDegree, ZnxZero, }, layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch}, }; @@ -18,15 +18,15 @@ use crate::{ pub trait GGLWEEncryptSkFamily = GLWEEncryptSkFamily + GLWESecretFamily; impl GGLWECiphertext> { - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize) -> usize where - Module: GGLWEEncryptSkFamily + VecZnxAllocBytes, + Module: GGLWEEncryptSkFamily, { - GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) - + (GLWEPlaintext::byte_of(module, basek, k) | module.vec_znx_normalize_tmp_bytes(module.n())) + GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, k) + + (GLWEPlaintext::byte_of(n, basek, k) | module.vec_znx_normalize_tmp_bytes(n)) } - pub fn encrypt_pk_scratch_space(_module: &Module, _basek: usize, _k: usize, _rank: usize) -> usize { + pub fn encrypt_pk_scratch_space(_module: &Module, _n: usize, _basek: usize, _k: usize, _rank: usize) -> usize { unimplemented!() } } @@ -42,8 +42,8 @@ impl GGLWECiphertext { sigma: f64, scratch: &mut Scratch, ) where - Module: GGLWEEncryptSkFamily + VecZnxAllocBytes + VecZnxAddScalarInplace, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Module: GGLWEEncryptSkFamily + VecZnxAddScalarInplace, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { #[cfg(debug_assertions)] { @@ -63,16 +63,15 @@ impl GGLWECiphertext { self.rank_out(), sk.rank() ); - assert_eq!(self.n(), module.n()); - assert_eq!(sk.n(), module.n()); - assert_eq!(pt.n(), module.n()); + assert_eq!(self.n(), sk.n()); + assert_eq!(pt.n(), sk.n()); assert!( - scratch.available() >= GGLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()), + scratch.available() >= GGLWECiphertext::encrypt_sk_scratch_space(module, sk.n(), self.basek(), self.k()), "scratch.available: {} < GGLWECiphertext::encrypt_sk_scratch_space(module, self.rank()={}, self.size()={}): {}", scratch.available(), self.rank(), self.size(), - GGLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()) + GGLWECiphertext::encrypt_sk_scratch_space(module, sk.n(), self.basek(), self.k()) ); assert!( self.rows() * self.digits() * self.basek() <= self.k(), @@ -91,7 +90,7 @@ impl GGLWECiphertext { let k: usize = self.k(); let rank_in: usize = self.rank_in(); - let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(module, basek, k); + let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(sk.n(), basek, k); // For each input column (i.e. rank) produces a GGLWE ciphertext of rank_out+1 columns // // Example for ksk rank 2 to rank 3: @@ -125,11 +124,11 @@ impl GGLWECiphertext { } impl GGLWECiphertextCompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize) -> usize where - Module: GLWESwitchingKeyEncryptSkFamily + VecZnxAllocBytes, + Module: GLWESwitchingKeyEncryptSkFamily, { - GGLWECiphertext::encrypt_sk_scratch_space(module, basek, k) + GGLWECiphertext::encrypt_sk_scratch_space(module, n, basek, k) } } @@ -144,8 +143,8 @@ impl GGLWECiphertextCompressed { sigma: f64, scratch: &mut Scratch, ) where - Module: GGLWEEncryptSkFamily + VecZnxAllocBytes + VecZnxAddScalarInplace, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Module: GGLWEEncryptSkFamily + VecZnxAddScalarInplace, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { #[cfg(debug_assertions)] { @@ -165,16 +164,16 @@ impl GGLWECiphertextCompressed { self.rank_out(), sk.rank() ); - assert_eq!(self.n(), module.n()); - assert_eq!(sk.n(), module.n()); - assert_eq!(pt.n(), module.n()); + assert_eq!(self.n(), sk.n()); + assert_eq!(pt.n(), sk.n()); assert!( - scratch.available() >= GGLWECiphertextCompressed::encrypt_sk_scratch_space(module, self.basek(), self.k()), + scratch.available() + >= GGLWECiphertextCompressed::encrypt_sk_scratch_space(module, sk.n(), self.basek(), self.k()), "scratch.available: {} < GGLWECiphertext::encrypt_sk_scratch_space(module, self.rank()={}, self.size()={}): {}", scratch.available(), self.rank(), self.size(), - GGLWECiphertextCompressed::encrypt_sk_scratch_space(module, self.basek(), self.k()) + GGLWECiphertextCompressed::encrypt_sk_scratch_space(module, sk.n(), self.basek(), self.k()) ); assert!( self.rows() * self.digits() * self.basek() <= self.k(), @@ -196,7 +195,7 @@ impl GGLWECiphertextCompressed { let mut source_xa = Source::new(seed); - let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(module, basek, k); + let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(sk.n(), basek, k); (0..rank_in).for_each(|col_i| { (0..rows).for_each(|row_i| { // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt @@ -237,27 +236,29 @@ pub trait GLWESwitchingKeyEncryptSkFamily = GGLWEEncryptSkFamily; impl GLWESwitchingKey> { pub fn encrypt_sk_scratch_space( module: &Module, + n: usize, basek: usize, k: usize, rank_in: usize, rank_out: usize, ) -> usize where - Module: GLWESwitchingKeyEncryptSkFamily + ScalarZnxAllocBytes + VecZnxAllocBytes, + Module: GLWESwitchingKeyEncryptSkFamily, { - (GGLWECiphertext::encrypt_sk_scratch_space(module, basek, k) | module.scalar_znx_alloc_bytes(1)) - + module.scalar_znx_alloc_bytes(rank_in) - + GLWESecretExec::bytes_of(module, rank_out) + (GGLWECiphertext::encrypt_sk_scratch_space(module, n, basek, k) | ScalarZnx::alloc_bytes(n, 1)) + + ScalarZnx::alloc_bytes(n, rank_in) + + GLWESecretExec::bytes_of(module, n, rank_out) } pub fn encrypt_pk_scratch_space( module: &Module, + _n: usize, _basek: usize, _k: usize, _rank_in: usize, _rank_out: usize, ) -> usize { - GGLWECiphertext::encrypt_pk_scratch_space(module, _basek, _k, _rank_out) + GGLWECiphertext::encrypt_pk_scratch_space(module, _n, _basek, _k, _rank_out) } } @@ -272,13 +273,8 @@ impl GLWESwitchingKey { sigma: f64, scratch: &mut Scratch, ) where - Module: GLWESwitchingKeyEncryptSkFamily - + ScalarZnxAllocBytes - + VecZnxSwithcDegree - + VecZnxAllocBytes - + VecZnxAddScalarInplace, - Scratch: - ScratchAvailable + TakeScalarZnx + TakeVecZnxDft + TakeGLWESecretExec + ScratchAvailable + TakeVecZnx, + Module: GLWESwitchingKeyEncryptSkFamily + VecZnxSwithcDegree + VecZnxAddScalarInplace, + Scratch: ScratchAvailable + TakeScalarZnx + TakeVecZnxDft + TakeGLWESecretExec + ScratchAvailable + TakeVecZnx, { #[cfg(debug_assertions)] { @@ -288,6 +284,7 @@ impl GLWESwitchingKey { scratch.available() >= GLWESwitchingKey::encrypt_sk_scratch_space( module, + sk_out.n(), self.basek(), self.k(), self.rank_in(), @@ -297,6 +294,7 @@ impl GLWESwitchingKey { scratch.available(), GLWESwitchingKey::encrypt_sk_scratch_space( module, + sk_out.n(), self.basek(), self.k(), self.rank_in(), @@ -305,7 +303,9 @@ impl GLWESwitchingKey { ) } - let (mut sk_in_tmp, scratch1) = scratch.take_scalar_znx(module, sk_in.rank()); + let n: usize = sk_in.n().max(sk_out.n()); + + let (mut sk_in_tmp, scratch1) = scratch.take_scalar_znx(n, sk_in.rank()); (0..sk_in.rank()).for_each(|i| { module.vec_znx_switch_degree( &mut sk_in_tmp.as_vec_znx_mut(), @@ -315,9 +315,9 @@ impl GLWESwitchingKey { ); }); - let (mut sk_out_tmp, scratch2) = scratch1.take_glwe_secret_exec(module, sk_out.rank()); + let (mut sk_out_tmp, scratch2) = scratch1.take_glwe_secret_exec(n, sk_out.rank()); { - let (mut tmp, _) = scratch2.take_scalar_znx(module, 1); + let (mut tmp, _) = scratch2.take_scalar_znx(n, 1); (0..sk_out.rank()).for_each(|i| { module.vec_znx_switch_degree(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i); module.svp_prepare(&mut sk_out_tmp.data, i, &tmp, 0); @@ -341,17 +341,18 @@ impl GLWESwitchingKey { impl GLWESwitchingKeyCompressed> { pub fn encrypt_sk_scratch_space( module: &Module, + n: usize, basek: usize, k: usize, rank_in: usize, rank_out: usize, ) -> usize where - Module: GLWESwitchingKeyEncryptSkFamily + ScalarZnxAllocBytes + VecZnxAllocBytes, + Module: GLWESwitchingKeyEncryptSkFamily, { - (GGLWECiphertext::encrypt_sk_scratch_space(module, basek, k) | module.scalar_znx_alloc_bytes(1)) - + module.scalar_znx_alloc_bytes(rank_in) - + GLWESecretExec::bytes_of(module, rank_out) + (GGLWECiphertext::encrypt_sk_scratch_space(module, n, basek, k) | ScalarZnx::alloc_bytes(n, 1)) + + ScalarZnx::alloc_bytes(n, rank_in) + + GLWESecretExec::bytes_of(module, n, rank_out) } } @@ -366,13 +367,8 @@ impl GLWESwitchingKeyCompressed { sigma: f64, scratch: &mut Scratch, ) where - Module: GLWESwitchingKeyEncryptSkFamily - + ScalarZnxAllocBytes - + VecZnxSwithcDegree - + VecZnxAllocBytes - + VecZnxAddScalarInplace, - Scratch: - ScratchAvailable + TakeScalarZnx + TakeVecZnxDft + TakeGLWESecretExec + ScratchAvailable + TakeVecZnx, + Module: GLWESwitchingKeyEncryptSkFamily + VecZnxSwithcDegree + VecZnxAddScalarInplace, + Scratch: ScratchAvailable + TakeScalarZnx + TakeVecZnxDft + TakeGLWESecretExec + ScratchAvailable + TakeVecZnx, { #[cfg(debug_assertions)] { @@ -382,6 +378,7 @@ impl GLWESwitchingKeyCompressed { scratch.available() >= GLWESwitchingKey::encrypt_sk_scratch_space( module, + sk_out.n(), self.basek(), self.k(), self.rank_in(), @@ -391,6 +388,7 @@ impl GLWESwitchingKeyCompressed { scratch.available(), GLWESwitchingKey::encrypt_sk_scratch_space( module, + sk_out.n(), self.basek(), self.k(), self.rank_in(), @@ -399,7 +397,9 @@ impl GLWESwitchingKeyCompressed { ) } - let (mut sk_in_tmp, scratch1) = scratch.take_scalar_znx(module, sk_in.rank()); + let n: usize = sk_in.n().max(sk_out.n()); + + let (mut sk_in_tmp, scratch1) = scratch.take_scalar_znx(n, sk_in.rank()); (0..sk_in.rank()).for_each(|i| { module.vec_znx_switch_degree( &mut sk_in_tmp.as_vec_znx_mut(), @@ -409,9 +409,9 @@ impl GLWESwitchingKeyCompressed { ); }); - let (mut sk_out_tmp, scratch2) = scratch1.take_glwe_secret_exec(module, sk_out.rank()); + let (mut sk_out_tmp, scratch2) = scratch1.take_glwe_secret_exec(n, sk_out.rank()); { - let (mut tmp, _) = scratch2.take_scalar_znx(module, 1); + let (mut tmp, _) = scratch2.take_scalar_znx(n, 1); (0..sk_out.rank()).for_each(|i| { module.vec_znx_switch_degree(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i); module.svp_prepare(&mut sk_out_tmp.data, i, &tmp, 0); @@ -435,15 +435,15 @@ impl GLWESwitchingKeyCompressed { pub trait AutomorphismKeyEncryptSkFamily = GGLWEEncryptSkFamily; impl AutomorphismKey> { - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize, rank: usize) -> usize where - Module: AutomorphismKeyEncryptSkFamily + ScalarZnxAllocBytes + VecZnxAllocBytes, + Module: AutomorphismKeyEncryptSkFamily, { - GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, rank, rank) + GLWESecret::bytes_of(module, rank) + GLWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, k, rank, rank) + GLWESecret::bytes_of(n, rank) } - pub fn encrypt_pk_scratch_space(module: &Module, _basek: usize, _k: usize, _rank: usize) -> usize { - GLWESwitchingKey::encrypt_pk_scratch_space(module, _basek, _k, _rank, _rank) + pub fn encrypt_pk_scratch_space(module: &Module, _n: usize, _basek: usize, _k: usize, _rank: usize) -> usize { + GLWESwitchingKey::encrypt_pk_scratch_space(module, _n, _basek, _k, _rank, _rank) } } @@ -458,31 +458,26 @@ impl AutomorphismKey { sigma: f64, scratch: &mut Scratch, ) where - Module: AutomorphismKeyEncryptSkFamily - + ScalarZnxAllocBytes - + VecZnxAllocBytes - + VecZnxAutomorphism - + VecZnxSwithcDegree - + VecZnxAddScalarInplace, - Scratch: ScratchAvailable + TakeScalarZnx + TakeVecZnxDft + TakeGLWESecretExec + TakeVecZnx, + Module: AutomorphismKeyEncryptSkFamily + VecZnxAutomorphism + VecZnxSwithcDegree + VecZnxAddScalarInplace, + Scratch: ScratchAvailable + TakeScalarZnx + TakeVecZnxDft + TakeGLWESecretExec + TakeVecZnx, { #[cfg(debug_assertions)] { - assert_eq!(self.n(), module.n()); - assert_eq!(sk.n(), module.n()); + assert_eq!(self.n(), sk.n()); assert_eq!(self.rank_out(), self.rank_in()); assert_eq!(sk.rank(), self.rank()); assert!( - scratch.available() >= AutomorphismKey::encrypt_sk_scratch_space(module, self.basek(), self.k(), self.rank()), + scratch.available() + >= AutomorphismKey::encrypt_sk_scratch_space(module, sk.n(), self.basek(), self.k(), self.rank()), "scratch.available(): {} < AutomorphismKey::encrypt_sk_scratch_space(module, self.rank()={}, self.size()={}): {}", scratch.available(), self.rank(), self.size(), - AutomorphismKey::encrypt_sk_scratch_space(module, self.basek(), self.k(), self.rank()) + AutomorphismKey::encrypt_sk_scratch_space(module, sk.n(), self.basek(), self.k(), self.rank()) ) } - let (mut sk_out, scratch_1) = scratch.take_glwe_secret(module, sk.rank()); + let (mut sk_out, scratch_1) = scratch.take_glwe_secret(sk.n(), sk.rank()); { (0..self.rank()).for_each(|i| { @@ -504,11 +499,11 @@ impl AutomorphismKey { } impl AutomorphismKeyCompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize, rank: usize) -> usize where - Module: AutomorphismKeyEncryptSkFamily + ScalarZnxAllocBytes + VecZnxAllocBytes, + Module: AutomorphismKeyEncryptSkFamily, { - GLWESwitchingKeyCompressed::encrypt_sk_scratch_space(module, basek, k, rank, rank) + GLWESecret::bytes_of(module, rank) + GLWESwitchingKeyCompressed::encrypt_sk_scratch_space(module, n, basek, k, rank, rank) + GLWESecret::bytes_of(n, rank) } } @@ -523,32 +518,26 @@ impl AutomorphismKeyCompressed { sigma: f64, scratch: &mut Scratch, ) where - Module: AutomorphismKeyEncryptSkFamily - + ScalarZnxAllocBytes - + VecZnxAllocBytes - + VecZnxSwithcDegree - + VecZnxAutomorphism - + VecZnxAddScalarInplace, - Scratch: ScratchAvailable + TakeScalarZnx + TakeVecZnxDft + TakeGLWESecretExec + TakeVecZnx, + Module: AutomorphismKeyEncryptSkFamily + VecZnxSwithcDegree + VecZnxAutomorphism + VecZnxAddScalarInplace, + Scratch: ScratchAvailable + TakeScalarZnx + TakeVecZnxDft + TakeGLWESecretExec + TakeVecZnx, { #[cfg(debug_assertions)] { - assert_eq!(self.n(), module.n()); - assert_eq!(sk.n(), module.n()); + assert_eq!(self.n(), sk.n()); assert_eq!(self.rank_out(), self.rank_in()); assert_eq!(sk.rank(), self.rank()); assert!( scratch.available() - >= AutomorphismKeyCompressed::encrypt_sk_scratch_space(module, self.basek(), self.k(), self.rank()), + >= AutomorphismKeyCompressed::encrypt_sk_scratch_space(module, sk.n(), self.basek(), self.k(), self.rank()), "scratch.available(): {} < AutomorphismKey::encrypt_sk_scratch_space(module, self.rank()={}, self.size()={}): {}", scratch.available(), self.rank(), self.size(), - AutomorphismKeyCompressed::encrypt_sk_scratch_space(module, self.basek(), self.k(), self.rank()) + AutomorphismKeyCompressed::encrypt_sk_scratch_space(module, sk.n(), self.basek(), self.k(), self.rank()) ) } - let (mut sk_out, scratch_1) = scratch.take_glwe_secret(module, sk.rank()); + let (mut sk_out, scratch_1) = scratch.take_glwe_secret(sk.n(), sk.rank()); { (0..self.rank()).for_each(|i| { @@ -573,16 +562,16 @@ pub trait GLWETensorKeyEncryptSkFamily = GGLWEEncryptSkFamily + VecZnxBigAllocBytes + VecZnxDftToVecZnxBigTmpA + SvpApply; impl GLWETensorKey> { - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize, rank: usize) -> usize where - Module: GLWETensorKeyEncryptSkFamily + ScalarZnxAllocBytes + VecZnxAllocBytes, + Module: GLWETensorKeyEncryptSkFamily, { - GLWESecretExec::bytes_of(module, rank) - + module.vec_znx_dft_alloc_bytes(rank, 1) - + module.vec_znx_big_alloc_bytes(1, 1) - + module.vec_znx_dft_alloc_bytes(1, 1) - + GLWESecret::bytes_of(module, 1) - + GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, rank, rank) + GLWESecretExec::bytes_of(module, n, rank) + + module.vec_znx_dft_alloc_bytes(n, rank, 1) + + module.vec_znx_big_alloc_bytes(n, 1, 1) + + module.vec_znx_dft_alloc_bytes(n, 1, 1) + + GLWESecret::bytes_of(n, 1) + + GLWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, k, rank, rank) } } @@ -596,35 +585,31 @@ impl GLWETensorKey { sigma: f64, scratch: &mut Scratch, ) where - Module: GLWETensorKeyEncryptSkFamily - + ScalarZnxAllocBytes - + VecZnxSwithcDegree - + VecZnxAllocBytes - + VecZnxAddScalarInplace, - Scratch: - ScratchAvailable + TakeVecZnxDft + TakeVecZnxBig + TakeGLWESecretExec + TakeScalarZnx + TakeVecZnx, + Module: GLWETensorKeyEncryptSkFamily + VecZnxSwithcDegree + VecZnxAddScalarInplace, + Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnxBig + TakeGLWESecretExec + TakeScalarZnx + TakeVecZnx, { #[cfg(debug_assertions)] { assert_eq!(self.rank(), sk.rank()); - assert_eq!(self.n(), module.n()); - assert_eq!(sk.n(), module.n()); + assert_eq!(self.n(), sk.n()); } + let n: usize = sk.n(); + let rank: usize = self.rank(); - let (mut sk_dft_prep, scratch1) = scratch.take_glwe_secret_exec(module, rank); + let (mut sk_dft_prep, scratch1) = scratch.take_glwe_secret_exec(n, rank); sk_dft_prep.prepare(module, &sk); - let (mut sk_dft, scratch2) = scratch1.take_vec_znx_dft(module, rank, 1); + let (mut sk_dft, scratch2) = scratch1.take_vec_znx_dft(n, rank, 1); (0..rank).for_each(|i| { module.vec_znx_dft_from_vec_znx(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); }); - let (mut sk_ij_big, scratch3) = scratch2.take_vec_znx_big(module, 1, 1); - let (mut sk_ij, scratch4) = scratch3.take_glwe_secret(module, 1); - let (mut sk_ij_dft, scratch5) = scratch4.take_vec_znx_dft(module, 1, 1); + let (mut sk_ij_big, scratch3) = scratch2.take_vec_znx_big(n, 1, 1); + let (mut sk_ij, scratch4) = scratch3.take_glwe_secret(n, 1); + let (mut sk_ij_dft, scratch5) = scratch4.take_vec_znx_dft(n, 1, 1); (0..rank).for_each(|i| { (i..rank).for_each(|j| { @@ -648,11 +633,11 @@ impl GLWETensorKey { } impl GLWETensorKeyCompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize, rank: usize) -> usize where - Module: GLWETensorKeyEncryptSkFamily + ScalarZnxAllocBytes + VecZnxAllocBytes, + Module: GLWETensorKeyEncryptSkFamily, { - GLWETensorKey::encrypt_sk_scratch_space(module, basek, k, rank) + GLWETensorKey::encrypt_sk_scratch_space(module, n, basek, k, rank) } } @@ -666,35 +651,30 @@ impl GLWETensorKeyCompressed { sigma: f64, scratch: &mut Scratch, ) where - Module: GLWETensorKeyEncryptSkFamily - + ScalarZnxAllocBytes - + VecZnxSwithcDegree - + VecZnxAllocBytes - + VecZnxAddScalarInplace, - Scratch: - ScratchAvailable + TakeVecZnxDft + TakeVecZnxBig + TakeGLWESecretExec + TakeScalarZnx + TakeVecZnx, + Module: GLWETensorKeyEncryptSkFamily + VecZnxSwithcDegree + VecZnxAddScalarInplace, + Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnxBig + TakeGLWESecretExec + TakeScalarZnx + TakeVecZnx, { #[cfg(debug_assertions)] { assert_eq!(self.rank(), sk.rank()); - assert_eq!(self.n(), module.n()); - assert_eq!(sk.n(), module.n()); + assert_eq!(self.n(), sk.n()); } + let n: usize = sk.n(); let rank: usize = self.rank(); - let (mut sk_dft_prep, scratch1) = scratch.take_glwe_secret_exec(module, rank); + let (mut sk_dft_prep, scratch1) = scratch.take_glwe_secret_exec(n, rank); sk_dft_prep.prepare(module, &sk); - let (mut sk_dft, scratch2) = scratch1.take_vec_znx_dft(module, rank, 1); + let (mut sk_dft, scratch2) = scratch1.take_vec_znx_dft(n, rank, 1); (0..rank).for_each(|i| { module.vec_znx_dft_from_vec_znx(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); }); - let (mut sk_ij_big, scratch3) = scratch2.take_vec_znx_big(module, 1, 1); - let (mut sk_ij, scratch4) = scratch3.take_glwe_secret(module, 1); - let (mut sk_ij_dft, scratch5) = scratch4.take_vec_znx_dft(module, 1, 1); + let (mut sk_ij_big, scratch3) = scratch2.take_vec_znx_big(n, 1, 1); + let (mut sk_ij, scratch4) = scratch3.take_glwe_secret(n, 1); + let (mut sk_ij_dft, scratch5) = scratch4.take_vec_znx_dft(n, 1, 1); let mut source_xa: Source = Source::new(seed_xa); diff --git a/core/src/gglwe/external_product.rs b/core/src/gglwe/external_product.rs index b067a58..4b47789 100644 --- a/core/src/gglwe/external_product.rs +++ b/core/src/gglwe/external_product.rs @@ -8,6 +8,7 @@ use crate::{AutomorphismKey, GGSWCiphertextExec, GLWECiphertext, GLWEExternalPro impl GLWESwitchingKey> { pub fn external_product_scratch_space( module: &Module, + n: usize, basek: usize, k_out: usize, k_in: usize, @@ -18,11 +19,12 @@ impl GLWESwitchingKey> { where Module: GLWEExternalProductFamily, { - GLWECiphertext::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank) + GLWECiphertext::external_product_scratch_space(module, n, basek, k_out, k_in, k_ggsw, digits, rank) } pub fn external_product_inplace_scratch_space( module: &Module, + n: usize, basek: usize, k_out: usize, k_ggsw: usize, @@ -32,7 +34,7 @@ impl GLWESwitchingKey> { where Module: GLWEExternalProductFamily, { - GLWECiphertext::external_product_inplace_scratch_space(module, basek, k_out, k_ggsw, digits, rank) + GLWECiphertext::external_product_inplace_scratch_space(module, n, basek, k_out, k_ggsw, digits, rank) } } @@ -118,6 +120,7 @@ impl GLWESwitchingKey { impl AutomorphismKey> { pub fn external_product_scratch_space( module: &Module, + n: usize, basek: usize, k_out: usize, k_in: usize, @@ -128,11 +131,12 @@ impl AutomorphismKey> { where Module: GLWEExternalProductFamily, { - GLWESwitchingKey::external_product_scratch_space(module, basek, k_out, k_in, ggsw_k, digits, rank) + GLWESwitchingKey::external_product_scratch_space(module, n, basek, k_out, k_in, ggsw_k, digits, rank) } pub fn external_product_inplace_scratch_space( module: &Module, + n: usize, basek: usize, k_out: usize, ggsw_k: usize, @@ -142,7 +146,7 @@ impl AutomorphismKey> { where Module: GLWEExternalProductFamily, { - GLWESwitchingKey::external_product_inplace_scratch_space(module, basek, k_out, ggsw_k, digits, rank) + GLWESwitchingKey::external_product_inplace_scratch_space(module, n, basek, k_out, ggsw_k, digits, rank) } } diff --git a/core/src/gglwe/keyswitch.rs b/core/src/gglwe/keyswitch.rs index 0ddbb64..e210419 100644 --- a/core/src/gglwe/keyswitch.rs +++ b/core/src/gglwe/keyswitch.rs @@ -10,6 +10,7 @@ use crate::{ impl AutomorphismKey> { pub fn keyswitch_scratch_space( module: &Module, + n: usize, basek: usize, k_out: usize, k_in: usize, @@ -20,11 +21,12 @@ impl AutomorphismKey> { where Module: GLWEKeyswitchFamily, { - GLWESwitchingKey::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank) + GLWESwitchingKey::keyswitch_scratch_space(module, n, basek, k_out, k_in, k_ksk, digits, rank, rank) } pub fn keyswitch_inplace_scratch_space( module: &Module, + n: usize, basek: usize, k_out: usize, k_ksk: usize, @@ -34,7 +36,7 @@ impl AutomorphismKey> { where Module: GLWEKeyswitchFamily, { - GLWESwitchingKey::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank) + GLWESwitchingKey::keyswitch_inplace_scratch_space(module, n, basek, k_out, k_ksk, digits, rank) } } @@ -68,6 +70,7 @@ impl AutomorphismKey { impl GLWESwitchingKey> { pub fn keyswitch_scratch_space( module: &Module, + n: usize, basek: usize, k_out: usize, k_in: usize, @@ -79,11 +82,14 @@ impl GLWESwitchingKey> { where Module: GLWEKeyswitchFamily, { - GLWECiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out) + GLWECiphertext::keyswitch_scratch_space( + module, n, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out, + ) } pub fn keyswitch_inplace_scratch_space( module: &Module, + n: usize, basek: usize, k_out: usize, k_ksk: usize, @@ -93,7 +99,7 @@ impl GLWESwitchingKey> { where Module: GLWEKeyswitchFamily, { - GLWECiphertext::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank) + GLWECiphertext::keyswitch_inplace_scratch_space(module, n, basek, k_out, k_ksk, digits, rank) } } diff --git a/core/src/gglwe/layout.rs b/core/src/gglwe/layout.rs index b0d9f2d..6659641 100644 --- a/core/src/gglwe/layout.rs +++ b/core/src/gglwe/layout.rs @@ -1,14 +1,15 @@ use backend::hal::{ - api::{MatZnxAlloc, MatZnxAllocBytes, VmpPMatAlloc, VmpPMatAllocBytes, VmpPMatPrepare}, - layouts::{Backend, Data, DataMut, DataRef, MatZnx, Module, ReaderFrom, WriterTo}, + api::{FillUniform, Reset, VmpPMatAlloc, VmpPMatAllocBytes, VmpPMatPrepare}, + layouts::{Backend, Data, DataMut, DataRef, MatZnx, ReaderFrom, WriterTo}, }; use crate::{GLWECiphertext, Infos}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; pub trait GGLWEExecLayoutFamily = VmpPMatAlloc + VmpPMatAllocBytes + VmpPMatPrepare; +use std::fmt; -#[derive(PartialEq, Eq)] +#[derive(PartialEq, Eq, Clone)] pub struct GGLWECiphertext { pub(crate) data: MatZnx, pub(crate) basek: usize, @@ -16,6 +17,40 @@ pub struct GGLWECiphertext { pub(crate) digits: usize, } +impl fmt::Debug for GGLWECiphertext { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self) + } +} + +impl FillUniform for GGLWECiphertext { + fn fill_uniform(&mut self, source: &mut sampling::source::Source) { + self.data.fill_uniform(source); + } +} + +impl Reset for GGLWECiphertext +where + MatZnx: Reset, +{ + fn reset(&mut self) { + self.data.reset(); + self.basek = 0; + self.k = 0; + self.digits = 0; + } +} + +impl fmt::Display for GGLWECiphertext { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "(GGLWECiphertext: basek={} k={} digits={}) {}", + self.basek, self.k, self.digits, self.data + ) + } +} + impl GGLWECiphertext { pub fn at(&self, row: usize, col: usize) -> GLWECiphertext<&[u8]> { GLWECiphertext { @@ -37,18 +72,7 @@ impl GGLWECiphertext { } impl GGLWECiphertext> { - pub fn alloc( - module: &Module, - basek: usize, - k: usize, - rows: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - ) -> Self - where - Module: MatZnxAlloc, - { + pub fn alloc(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> Self { let size: usize = k.div_ceil(basek); debug_assert!( size > digits, @@ -66,25 +90,14 @@ impl GGLWECiphertext> { ); Self { - data: module.mat_znx_alloc(rows, rank_in, rank_out + 1, size), + data: MatZnx::alloc(n, rows, rank_in, rank_out + 1, size), basek: basek, k, digits, } } - pub fn bytes_of( - module: &Module, - basek: usize, - k: usize, - rows: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - ) -> usize - where - Module: MatZnxAllocBytes, - { + pub fn bytes_of(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> usize { let size: usize = k.div_ceil(basek); debug_assert!( size > digits, @@ -101,7 +114,7 @@ impl GGLWECiphertext> { size ); - module.mat_znx_alloc_bytes(rows, rank_in, rank_out + 1, rows) + MatZnx::alloc_bytes(n, rows, rank_in, rank_out + 1, rows) } } @@ -157,46 +170,57 @@ impl WriterTo for GGLWECiphertext { } } -#[derive(PartialEq, Eq)] +#[derive(PartialEq, Eq, Clone)] pub struct GLWESwitchingKey { pub(crate) key: GGLWECiphertext, pub(crate) sk_in_n: usize, // Degree of sk_in pub(crate) sk_out_n: usize, // Degree of sk_out } +impl fmt::Debug for GLWESwitchingKey { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self) + } +} + +impl fmt::Display for GLWESwitchingKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "(GLWESwitchingKey: sk_in_n={} sk_out_n={}) {}", + self.sk_in_n, self.sk_out_n, self.key.data + ) + } +} + +impl FillUniform for GLWESwitchingKey { + fn fill_uniform(&mut self, source: &mut sampling::source::Source) { + self.key.fill_uniform(source); + } +} + +impl Reset for GLWESwitchingKey +where + MatZnx: Reset, +{ + fn reset(&mut self) { + self.key.reset(); + self.sk_in_n = 0; + self.sk_out_n = 0; + } +} + impl GLWESwitchingKey> { - pub fn alloc( - module: &Module, - basek: usize, - k: usize, - rows: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - ) -> Self - where - Module: MatZnxAlloc, - { + pub fn alloc(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> Self { GLWESwitchingKey { - key: GGLWECiphertext::alloc(module, basek, k, rows, digits, rank_in, rank_out), + key: GGLWECiphertext::alloc(n, basek, k, rows, digits, rank_in, rank_out), sk_in_n: 0, sk_out_n: 0, } } - pub fn bytes_of( - module: &Module, - basek: usize, - k: usize, - rows: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - ) -> usize - where - Module: MatZnxAllocBytes, - { - GGLWECiphertext::>::bytes_of(module, basek, k, rows, digits, rank_in, rank_out) + pub fn bytes_of(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> usize { + GGLWECiphertext::>::bytes_of(n, basek, k, rows, digits, rank_in, rank_out) } } @@ -270,28 +294,50 @@ impl WriterTo for GLWESwitchingKey { } } -#[derive(PartialEq, Eq)] +#[derive(PartialEq, Eq, Clone)] pub struct AutomorphismKey { pub(crate) key: GLWESwitchingKey, pub(crate) p: i64, } +impl fmt::Debug for AutomorphismKey { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self) + } +} + +impl FillUniform for AutomorphismKey { + fn fill_uniform(&mut self, source: &mut sampling::source::Source) { + self.key.fill_uniform(source); + } +} + +impl Reset for AutomorphismKey +where + MatZnx: Reset, +{ + fn reset(&mut self) { + self.key.reset(); + self.p = 0; + } +} + +impl fmt::Display for AutomorphismKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "(AutomorphismKey: p={}) {}", self.p, self.key) + } +} + impl AutomorphismKey> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self - where - Module: MatZnxAlloc, - { + pub fn alloc(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self { AutomorphismKey { - key: GLWESwitchingKey::alloc(module, basek, k, rows, digits, rank, rank), + key: GLWESwitchingKey::alloc(n, basek, k, rows, digits, rank, rank), p: 0, } } - pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize - where - Module: MatZnxAllocBytes, - { - GLWESwitchingKey::>::bytes_of(module, basek, k, rows, digits, rank, rank) + pub fn bytes_of(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize { + GLWESwitchingKey::bytes_of(n, basek, k, rows, digits, rank, rank) } } @@ -359,32 +405,59 @@ impl WriterTo for AutomorphismKey { } } -#[derive(PartialEq, Eq)] +#[derive(PartialEq, Eq, Clone)] pub struct GLWETensorKey { pub(crate) keys: Vec>, } +impl fmt::Debug for GLWETensorKey { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self) + } +} + +impl FillUniform for GLWETensorKey { + fn fill_uniform(&mut self, source: &mut sampling::source::Source) { + self.keys + .iter_mut() + .for_each(|key: &mut GLWESwitchingKey| key.fill_uniform(source)) + } +} + +impl Reset for GLWETensorKey +where + MatZnx: Reset, +{ + fn reset(&mut self) { + self.keys + .iter_mut() + .for_each(|key: &mut GLWESwitchingKey| key.reset()) + } +} + +impl fmt::Display for GLWETensorKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "(GLWETensorKey)",)?; + for (i, key) in self.keys.iter().enumerate() { + write!(f, "{}: {}", i, key)?; + } + Ok(()) + } +} + impl GLWETensorKey> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self - where - Module: MatZnxAlloc, - { + pub fn alloc(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self { let mut keys: Vec>> = Vec::new(); let pairs: usize = (((rank + 1) * rank) >> 1).max(1); (0..pairs).for_each(|_| { - keys.push(GLWESwitchingKey::alloc( - module, basek, k, rows, digits, 1, rank, - )); + keys.push(GLWESwitchingKey::alloc(n, basek, k, rows, digits, 1, rank)); }); Self { keys: keys } } - pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize - where - Module: MatZnxAllocBytes, - { + pub fn bytes_of(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize { let pairs: usize = (((rank + 1) * rank) >> 1).max(1); - pairs * GLWESwitchingKey::>::bytes_of(module, basek, k, rows, digits, 1, rank) + pairs * GLWESwitchingKey::>::bytes_of(n, basek, k, rows, digits, 1, rank) } } diff --git a/core/src/gglwe/layouts_compressed.rs b/core/src/gglwe/layouts_compressed.rs index 305d812..8b85629 100644 --- a/core/src/gglwe/layouts_compressed.rs +++ b/core/src/gglwe/layouts_compressed.rs @@ -1,12 +1,13 @@ use backend::hal::{ - api::{MatZnxAlloc, MatZnxAllocBytes, VecZnxCopy, VecZnxFillUniform}, + api::{FillUniform, Reset, VecZnxCopy, VecZnxFillUniform}, layouts::{Backend, Data, DataMut, DataRef, MatZnx, Module, ReaderFrom, WriterTo}, }; use crate::{AutomorphismKey, Decompress, GGLWECiphertext, GLWECiphertextCompressed, GLWESwitchingKey, GLWETensorKey, Infos}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; +use std::fmt; -#[derive(PartialEq, Eq)] +#[derive(PartialEq, Eq, Clone)] pub struct GGLWECiphertextCompressed { pub(crate) data: MatZnx, pub(crate) basek: usize, @@ -16,19 +17,44 @@ pub struct GGLWECiphertextCompressed { pub(crate) seed: Vec<[u8; 32]>, } +impl fmt::Debug for GGLWECiphertextCompressed { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self) + } +} + +impl FillUniform for GGLWECiphertextCompressed { + fn fill_uniform(&mut self, source: &mut sampling::source::Source) { + self.data.fill_uniform(source); + } +} + +impl Reset for GGLWECiphertextCompressed +where + MatZnx: Reset, +{ + fn reset(&mut self) { + self.data.reset(); + self.basek = 0; + self.k = 0; + self.digits = 0; + self.rank_out = 0; + self.seed = Vec::new(); + } +} + +impl fmt::Display for GGLWECiphertextCompressed { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "(GGLWECiphertextCompressed: basek={} k={} digits={}) {}", + self.basek, self.k, self.digits, self.data + ) + } +} + impl GGLWECiphertextCompressed> { - pub fn alloc( - module: &Module, - basek: usize, - k: usize, - rows: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - ) -> Self - where - Module: MatZnxAlloc, - { + pub fn alloc(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> Self { let size: usize = k.div_ceil(basek); debug_assert!( size > digits, @@ -46,7 +72,7 @@ impl GGLWECiphertextCompressed> { ); Self { - data: module.mat_znx_alloc(rows, rank_in, 1, size), + data: MatZnx::alloc(n, rows, rank_in, 1, size), basek: basek, k, rank_out, @@ -55,10 +81,7 @@ impl GGLWECiphertextCompressed> { } } - pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize) -> usize - where - Module: MatZnxAllocBytes, - { + pub fn bytes_of(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize) -> usize { let size: usize = k.div_ceil(basek); debug_assert!( size > digits, @@ -75,7 +98,7 @@ impl GGLWECiphertextCompressed> { size ); - module.mat_znx_alloc_bytes(rows, rank_in, 1, rows) + MatZnx::alloc_bytes(n, rows, rank_in, 1, rows) } } @@ -145,11 +168,9 @@ impl ReaderFrom for GGLWECiphertextCompressed { self.digits = reader.read_u64::()? as usize; self.rank_out = reader.read_u64::()? as usize; let seed_len = reader.read_u64::()? as usize; - if seed_len != self.seed.len() { - } else { - for s in &mut self.seed { - reader.read_exact(s)?; - } + self.seed = vec![[0u8; 32]; seed_len]; + for s in &mut self.seed { + reader.read_exact(s)?; } self.data.read_from(reader) } @@ -228,13 +249,46 @@ impl Decompress { pub(crate) key: GGLWECiphertextCompressed, pub(crate) sk_in_n: usize, // Degree of sk_in pub(crate) sk_out_n: usize, // Degree of sk_out } +impl fmt::Debug for GLWESwitchingKeyCompressed { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self) + } +} + +impl FillUniform for GLWESwitchingKeyCompressed { + fn fill_uniform(&mut self, source: &mut sampling::source::Source) { + self.key.fill_uniform(source); + } +} + +impl Reset for GLWESwitchingKeyCompressed +where + MatZnx: Reset, +{ + fn reset(&mut self) { + self.key.reset(); + self.sk_in_n = 0; + self.sk_out_n = 0; + } +} + +impl fmt::Display for GLWESwitchingKeyCompressed { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "(GLWESwitchingKeyCompressed: sk_in_n={} sk_out_n={}) {}", + self.sk_in_n, self.sk_out_n, self.key.data + ) + } +} + impl Infos for GLWESwitchingKeyCompressed { type Inner = MatZnx; @@ -270,30 +324,16 @@ impl GLWESwitchingKeyCompressed { } impl GLWESwitchingKeyCompressed> { - pub fn alloc( - module: &Module, - basek: usize, - k: usize, - rows: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - ) -> Self - where - Module: MatZnxAlloc, - { + pub fn alloc(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> Self { GLWESwitchingKeyCompressed { - key: GGLWECiphertextCompressed::alloc(module, basek, k, rows, digits, rank_in, rank_out), + key: GGLWECiphertextCompressed::alloc(n, basek, k, rows, digits, rank_in, rank_out), sk_in_n: 0, sk_out_n: 0, } } - pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize) -> usize - where - Module: MatZnxAllocBytes, - { - GGLWECiphertextCompressed::>::bytes_of(module, basek, k, rows, digits, rank_in) + pub fn bytes_of(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize) -> usize { + GGLWECiphertextCompressed::bytes_of(n, basek, k, rows, digits, rank_in) } } @@ -327,28 +367,50 @@ impl GLWESwitchingKey { } } -#[derive(PartialEq, Eq)] +#[derive(PartialEq, Eq, Clone)] pub struct AutomorphismKeyCompressed { pub(crate) key: GLWESwitchingKeyCompressed, pub(crate) p: i64, } +impl fmt::Debug for AutomorphismKeyCompressed { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self) + } +} + +impl FillUniform for AutomorphismKeyCompressed { + fn fill_uniform(&mut self, source: &mut sampling::source::Source) { + self.key.fill_uniform(source); + } +} + +impl Reset for AutomorphismKeyCompressed +where + MatZnx: Reset, +{ + fn reset(&mut self) { + self.key.reset(); + self.p = 0; + } +} + +impl fmt::Display for AutomorphismKeyCompressed { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "(AutomorphismKeyCompressed: p={}) {}", self.p, self.key) + } +} + impl AutomorphismKeyCompressed> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self - where - Module: MatZnxAlloc, - { + pub fn alloc(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self { AutomorphismKeyCompressed { - key: GLWESwitchingKeyCompressed::alloc(module, basek, k, rows, digits, rank, rank), + key: GLWESwitchingKeyCompressed::alloc(n, basek, k, rows, digits, rank, rank), p: 0, } } - pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize - where - Module: MatZnxAllocBytes, - { - GLWESwitchingKeyCompressed::>::bytes_of(module, basek, k, rows, digits, rank) + pub fn bytes_of(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize { + GLWESwitchingKeyCompressed::>::bytes_of(n, basek, k, rows, digits, rank) } } @@ -410,32 +472,61 @@ impl AutomorphismKey { } } -#[derive(PartialEq, Eq)] +#[derive(PartialEq, Eq, Clone)] pub struct GLWETensorKeyCompressed { pub(crate) keys: Vec>, } +impl fmt::Debug for GLWETensorKeyCompressed { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self) + } +} + +impl FillUniform for GLWETensorKeyCompressed { + fn fill_uniform(&mut self, source: &mut sampling::source::Source) { + self.keys + .iter_mut() + .for_each(|key: &mut GLWESwitchingKeyCompressed| key.fill_uniform(source)) + } +} + +impl Reset for GLWETensorKeyCompressed +where + MatZnx: Reset, +{ + fn reset(&mut self) { + self.keys + .iter_mut() + .for_each(|key: &mut GLWESwitchingKeyCompressed| key.reset()) + } +} + +impl fmt::Display for GLWETensorKeyCompressed { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "(GLWETensorKeyCompressed)",)?; + for (i, key) in self.keys.iter().enumerate() { + write!(f, "{}: {}", i, key)?; + } + Ok(()) + } +} + impl GLWETensorKeyCompressed> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self - where - Module: MatZnxAlloc, - { + pub fn alloc(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self { let mut keys: Vec>> = Vec::new(); let pairs: usize = (((rank + 1) * rank) >> 1).max(1); (0..pairs).for_each(|_| { keys.push(GLWESwitchingKeyCompressed::alloc( - module, basek, k, rows, digits, 1, rank, + n, basek, k, rows, digits, 1, rank, )); }); Self { keys: keys } } - pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize - where - Module: MatZnxAllocBytes, - { + pub fn bytes_of(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize { let pairs: usize = (((rank + 1) * rank) >> 1).max(1); - pairs * GLWESwitchingKeyCompressed::>::bytes_of(module, basek, k, rows, digits, 1) + pairs * GLWESwitchingKeyCompressed::bytes_of(n, basek, k, rows, digits, 1) } } diff --git a/core/src/gglwe/layouts_exec.rs b/core/src/gglwe/layouts_exec.rs index ba10fe0..4ba0d35 100644 --- a/core/src/gglwe/layouts_exec.rs +++ b/core/src/gglwe/layouts_exec.rs @@ -14,7 +14,16 @@ pub struct GGLWECiphertextExec { } impl GGLWECiphertextExec, B> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> Self + pub fn alloc( + module: &Module, + n: usize, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> Self where Module: GGLWEExecLayoutFamily, { @@ -35,7 +44,7 @@ impl GGLWECiphertextExec, B> { ); Self { - data: module.vmp_pmat_alloc(rows, rank_in, rank_out + 1, size), + data: module.vmp_pmat_alloc(n, rows, rank_in, rank_out + 1, size), basek: basek, k, digits, @@ -44,6 +53,7 @@ impl GGLWECiphertextExec, B> { pub fn bytes_of( module: &Module, + n: usize, basek: usize, k: usize, rows: usize, @@ -70,7 +80,7 @@ impl GGLWECiphertextExec, B> { size ); - module.vmp_pmat_alloc_bytes(rows, rank_in, rank_out + 1, rows) + module.vmp_pmat_alloc_bytes(n, rows, rank_in, rank_out + 1, rows) } } @@ -129,12 +139,21 @@ pub struct GLWESwitchingKeyExec { } impl GLWESwitchingKeyExec, B> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> Self + pub fn alloc( + module: &Module, + n: usize, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> Self where Module: GGLWEExecLayoutFamily, { GLWESwitchingKeyExec::, B> { - key: GGLWECiphertextExec::alloc(module, basek, k, rows, digits, rank_in, rank_out), + key: GGLWECiphertextExec::alloc(module, n, basek, k, rows, digits, rank_in, rank_out), sk_in_n: 0, sk_out_n: 0, } @@ -142,6 +161,7 @@ impl GLWESwitchingKeyExec, B> { pub fn bytes_of( module: &Module, + n: usize, basek: usize, k: usize, rows: usize, @@ -152,7 +172,7 @@ impl GLWESwitchingKeyExec, B> { where Module: GGLWEExecLayoutFamily, { - GGLWECiphertextExec::bytes_of(module, basek, k, rows, digits, rank_in, rank_out) + GGLWECiphertextExec::bytes_of(module, n, basek, k, rows, digits, rank_in, rank_out) } pub fn from(module: &Module, other: &GLWESwitchingKey, scratch: &mut Scratch) -> Self @@ -161,6 +181,7 @@ impl GLWESwitchingKeyExec, B> { { let mut ksk_exec: GLWESwitchingKeyExec, B> = Self::alloc( module, + other.n(), other.basek(), other.k(), other.rows(), @@ -234,21 +255,21 @@ pub struct AutomorphismKeyExec { } impl AutomorphismKeyExec, B> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self + pub fn alloc(module: &Module, n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self where Module: GGLWEExecLayoutFamily, { AutomorphismKeyExec::, B> { - key: GLWESwitchingKeyExec::alloc(module, basek, k, rows, digits, rank, rank), + key: GLWESwitchingKeyExec::alloc(module, n, basek, k, rows, digits, rank, rank), p: 0, } } - pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize + pub fn bytes_of(module: &Module, n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize where Module: GGLWEExecLayoutFamily, { - GLWESwitchingKeyExec::, B>::bytes_of(module, basek, k, rows, digits, rank, rank) + GLWESwitchingKeyExec::bytes_of(module, n, basek, k, rows, digits, rank, rank) } pub fn from(module: &Module, other: &AutomorphismKey, scratch: &mut Scratch) -> Self @@ -257,6 +278,7 @@ impl AutomorphismKeyExec, B> { { let mut atk_exec: AutomorphismKeyExec, B> = Self::alloc( module, + other.n(), other.basek(), other.k(), other.rows(), @@ -323,7 +345,7 @@ pub struct GLWETensorKeyExec { } impl GLWETensorKeyExec, B> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self + pub fn alloc(module: &Module, n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self where Module: GGLWEExecLayoutFamily, { @@ -331,18 +353,18 @@ impl GLWETensorKeyExec, B> { let pairs: usize = (((rank + 1) * rank) >> 1).max(1); (0..pairs).for_each(|_| { keys.push(GLWESwitchingKeyExec::alloc( - module, basek, k, rows, digits, 1, rank, + module, n, basek, k, rows, digits, 1, rank, )); }); Self { keys } } - pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize + pub fn bytes_of(module: &Module, n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize where Module: GGLWEExecLayoutFamily, { let pairs: usize = (((rank + 1) * rank) >> 1).max(1); - pairs * GLWESwitchingKeyExec::, B>::bytes_of(module, basek, k, rows, digits, 1, rank) + pairs * GLWESwitchingKeyExec::bytes_of(module, n, basek, k, rows, digits, 1, rank) } } diff --git a/core/src/gglwe/noise.rs b/core/src/gglwe/noise.rs index def18c2..b46cca7 100644 --- a/core/src/gglwe/noise.rs +++ b/core/src/gglwe/noise.rs @@ -1,5 +1,5 @@ use backend::hal::{ - api::{ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAlloc, VecZnxStd, VecZnxSubScalarInplace, ZnxZero}, + api::{ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxStd, VecZnxSubScalarInplace, ZnxZero}, layouts::{Backend, DataRef, Module, ScalarZnx, ScratchOwned}, oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl}, }; @@ -16,15 +16,20 @@ impl GGLWECiphertext { ) where DataSk: DataRef, DataWant: DataRef, - Module: GLWEDecryptFamily + VecZnxStd + VecZnxAlloc + VecZnxSubScalarInplace, + Module: GLWEDecryptFamily + VecZnxStd + VecZnxSubScalarInplace, B: TakeVecZnxDftImpl + TakeVecZnxBigImpl + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, { let digits: usize = self.digits(); let basek: usize = self.basek(); let k: usize = self.k(); - let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, basek, k)); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k); + let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space( + module, + self.n(), + basek, + k, + )); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(self.n(), basek, k); (0..self.rank_in()).for_each(|col_i| { (0..self.rows()).for_each(|row_i| { diff --git a/core/src/gglwe/tests/generic_serialization.rs b/core/src/gglwe/tests/generic_serialization.rs new file mode 100644 index 0000000..a89c269 --- /dev/null +++ b/core/src/gglwe/tests/generic_serialization.rs @@ -0,0 +1,54 @@ +use backend::hal::tests::serialization::test_reader_writer_interface; + +use crate::{ + AutomorphismKey, AutomorphismKeyCompressed, GGLWECiphertext, GGLWECiphertextCompressed, GLWESwitchingKey, + GLWESwitchingKeyCompressed, GLWETensorKey, GLWETensorKeyCompressed, +}; + +#[test] +fn test_gglwe_serialization() { + let original: GGLWECiphertext> = GGLWECiphertext::alloc(1024, 12, 54, 3, 1, 2, 2); + test_reader_writer_interface(original); +} + +#[test] +fn test_gglwe_serialization_compressed() { + let original: GGLWECiphertextCompressed> = GGLWECiphertextCompressed::alloc(1024, 12, 54, 3, 1, 2, 2); + test_reader_writer_interface(original); +} + +#[test] +fn test_glwe_switching_key_serialization() { + let original: GLWESwitchingKey> = GLWESwitchingKey::alloc(1024, 12, 54, 3, 1, 2, 2); + test_reader_writer_interface(original); +} + +#[test] +fn test_glwe_switching_key_serialization_compressed() { + let original: GLWESwitchingKeyCompressed> = GLWESwitchingKeyCompressed::alloc(1024, 12, 54, 3, 1, 2, 2); + test_reader_writer_interface(original); +} + +#[test] +fn test_automorphism_key_serialization() { + let original: AutomorphismKey> = AutomorphismKey::alloc(1024, 12, 54, 3, 1, 2); + test_reader_writer_interface(original); +} + +#[test] +fn test_automorphism_key_serialization_compressed() { + let original: AutomorphismKeyCompressed> = AutomorphismKeyCompressed::alloc(1024, 12, 54, 3, 1, 2); + test_reader_writer_interface(original); +} + +#[test] +fn test_tensor_key_serialization() { + let original: GLWETensorKey> = GLWETensorKey::alloc(1024, 12, 54, 3, 1, 2); + test_reader_writer_interface(original); +} + +#[test] +fn test_tensor_key_serialization_compressed() { + let original: GLWETensorKeyCompressed> = GLWETensorKeyCompressed::alloc(1024, 12, 54, 3, 1, 2); + test_reader_writer_interface(original); +} diff --git a/core/src/gglwe/tests/generics_automorphism_key.rs b/core/src/gglwe/tests/generics_automorphism_key.rs index 93e004c..208d79b 100644 --- a/core/src/gglwe/tests/generics_automorphism_key.rs +++ b/core/src/gglwe/tests/generics_automorphism_key.rs @@ -1,8 +1,7 @@ use backend::hal::{ api::{ - MatZnxAlloc, ScalarZnxAlloc, ScalarZnxAllocBytes, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddScalarInplace, - VecZnxAlloc, VecZnxAllocBytes, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxCopy, VecZnxStd, - VecZnxSubScalarInplace, VecZnxSwithcDegree, + ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxCopy, + VecZnxStd, VecZnxSubScalarInplace, VecZnxSwithcDegree, }, layouts::{Backend, Module, ScratchOwned}, oep::{ @@ -18,19 +17,14 @@ use crate::{ noise::log2_std_noise_gglwe_product, }; -pub(crate) trait AutomorphismTestModuleFamily = MatZnxAlloc - + AutomorphismKeyEncryptSkFamily - + ScalarZnxAllocBytes - + VecZnxAllocBytes +pub(crate) trait AutomorphismTestModuleFamily = AutomorphismKeyEncryptSkFamily + GLWEKeyswitchFamily - + ScalarZnxAlloc + VecZnxAutomorphism + GGLWEExecLayoutFamily + VecZnxSwithcDegree + VecZnxAddScalarInplace + VecZnxAutomorphism + VecZnxAutomorphismInplace - + VecZnxAlloc + GLWEDecryptFamily + VecZnxSubScalarInplace + VecZnxStd @@ -55,19 +49,20 @@ pub(crate) fn test_automorphisk_key_encrypt_sk( Module: AutomorphismTestModuleFamily, B: AutomorphismTestScratchFamily, { + let n: usize = module.n(); let rows: usize = (k_ksk - digits * basek) / (digits * basek); - let mut atk: AutomorphismKey> = AutomorphismKey::alloc(module, basek, k_ksk, rows, digits, rank); + let mut atk: AutomorphismKey> = AutomorphismKey::alloc(n, basek, k_ksk, rows, digits, rank); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc(AutomorphismKey::encrypt_sk_scratch_space( - module, basek, k_ksk, rank, + module, n, basek, k_ksk, rank, )); - let mut sk: GLWESecret> = GLWESecret::alloc(module, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); sk.fill_ternary_prob(0.5, &mut source_xs); let p = -5; @@ -110,19 +105,20 @@ pub(crate) fn test_automorphisk_key_encrypt_sk_compressed( Module: AutomorphismTestModuleFamily, B: AutomorphismTestScratchFamily, { + let n: usize = module.n(); let rows: usize = (k_ksk - digits * basek) / (digits * basek); let mut atk_compressed: AutomorphismKeyCompressed> = - AutomorphismKeyCompressed::alloc(module, basek, k_ksk, rows, digits, rank); + AutomorphismKeyCompressed::alloc(n, basek, k_ksk, rows, digits, rank); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc(AutomorphismKey::encrypt_sk_scratch_space( - module, basek, k_ksk, rank, + module, n, basek, k_ksk, rank, )); - let mut sk: GLWESecret> = GLWESecret::alloc(module, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); sk.fill_ternary_prob(0.5, &mut source_xs); let p = -5; @@ -151,7 +147,7 @@ pub(crate) fn test_automorphisk_key_encrypt_sk_compressed( }); let sk_out_exec = GLWESecretExec::from(module, &sk_out); - let mut atk: AutomorphismKey> = AutomorphismKey::alloc(module, basek, k_ksk, rows, digits, rank); + let mut atk: AutomorphismKey> = AutomorphismKey::alloc(n, basek, k_ksk, rows, digits, rank); atk.decompress(module, &atk_compressed); atk.key @@ -174,30 +170,31 @@ pub(crate) fn test_gglwe_automorphism( Module: AutomorphismTestModuleFamily, B: AutomorphismTestScratchFamily, { + let n: usize = module.n(); let digits_in: usize = 1; let rows_in: usize = k_in / (basek * digits); let rows_apply: usize = k_in.div_ceil(basek * digits); - let mut auto_key_in: AutomorphismKey> = AutomorphismKey::alloc(&module, basek, k_in, rows_in, digits_in, rank); - let mut auto_key_out: AutomorphismKey> = AutomorphismKey::alloc(&module, basek, k_out, rows_in, digits_in, rank); - let mut auto_key_apply: AutomorphismKey> = AutomorphismKey::alloc(&module, basek, k_apply, rows_apply, digits, rank); + let mut auto_key_in: AutomorphismKey> = AutomorphismKey::alloc(n, basek, k_in, rows_in, digits_in, rank); + let mut auto_key_out: AutomorphismKey> = AutomorphismKey::alloc(n, basek, k_out, rows_in, digits_in, rank); + let mut auto_key_apply: AutomorphismKey> = AutomorphismKey::alloc(n, basek, k_apply, rows_apply, digits, rank); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - AutomorphismKey::encrypt_sk_scratch_space(&module, basek, k_apply, rank) - | AutomorphismKey::automorphism_scratch_space(&module, basek, k_out, k_in, k_apply, digits, rank), + AutomorphismKey::encrypt_sk_scratch_space(module, n, basek, k_apply, rank) + | AutomorphismKey::automorphism_scratch_space(module, n, basek, k_out, k_in, k_apply, digits, rank), ); - let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); sk.fill_ternary_prob(0.5, &mut source_xs); // gglwe_{s1}(s0) = s0 -> s1 auto_key_in.encrypt_sk( - &module, + module, p0, &sk, &mut source_xa, @@ -208,7 +205,7 @@ pub(crate) fn test_gglwe_automorphism( // gglwe_{s2}(s1) -> s1 -> s2 auto_key_apply.encrypt_sk( - &module, + module, p1, &sk, &mut source_xa, @@ -218,21 +215,16 @@ pub(crate) fn test_gglwe_automorphism( ); let mut auto_key_apply_exec: AutomorphismKeyExec, B> = - AutomorphismKeyExec::alloc(&module, basek, k_apply, rows_apply, digits, rank); + AutomorphismKeyExec::alloc(module, n, basek, k_apply, rows_apply, digits, rank); - auto_key_apply_exec.prepare(&module, &auto_key_apply, scratch.borrow()); + auto_key_apply_exec.prepare(module, &auto_key_apply, scratch.borrow()); // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) - auto_key_out.automorphism( - &module, - &auto_key_in, - &auto_key_apply_exec, - scratch.borrow(), - ); + auto_key_out.automorphism(module, &auto_key_in, &auto_key_apply_exec, scratch.borrow()); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_out); - let mut sk_auto: GLWESecret> = GLWESecret::alloc(&module, rank); + let mut sk_auto: GLWESecret> = GLWESecret::alloc(n, rank); sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk (0..rank).for_each(|i| { module.vec_znx_automorphism( @@ -244,13 +236,13 @@ pub(crate) fn test_gglwe_automorphism( ); }); - let sk_auto_dft: GLWESecretExec, B> = GLWESecretExec::from(&module, &sk_auto); + let sk_auto_dft: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_auto); (0..auto_key_out.rank_in()).for_each(|col_i| { (0..auto_key_out.rows()).for_each(|row_i| { auto_key_out .at(row_i, col_i) - .decrypt(&module, &mut pt, &sk_auto_dft, scratch.borrow()); + .decrypt(module, &mut pt, &sk_auto_dft, scratch.borrow()); module.vec_znx_sub_scalar_inplace( &mut pt.data, @@ -262,7 +254,7 @@ pub(crate) fn test_gglwe_automorphism( let noise_have: f64 = module.vec_znx_std(basek, &pt.data, 0).log2(); let noise_want: f64 = log2_std_noise_gglwe_product( - module.n() as f64, + n as f64, basek * digits, 0.5, 0.5, @@ -298,29 +290,30 @@ pub(crate) fn test_gglwe_automorphism_inplace( Module: AutomorphismTestModuleFamily, B: AutomorphismTestScratchFamily, { + let n: usize = module.n(); let digits_in: usize = 1; let rows_in: usize = k_in / (basek * digits); let rows_apply: usize = k_in.div_ceil(basek * digits); - let mut auto_key: AutomorphismKey> = AutomorphismKey::alloc(&module, basek, k_in, rows_in, digits_in, rank); - let mut auto_key_apply: AutomorphismKey> = AutomorphismKey::alloc(&module, basek, k_apply, rows_apply, digits, rank); + let mut auto_key: AutomorphismKey> = AutomorphismKey::alloc(n, basek, k_in, rows_in, digits_in, rank); + let mut auto_key_apply: AutomorphismKey> = AutomorphismKey::alloc(n, basek, k_apply, rows_apply, digits, rank); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - AutomorphismKey::encrypt_sk_scratch_space(&module, basek, k_apply, rank) - | AutomorphismKey::automorphism_inplace_scratch_space(&module, basek, k_in, k_apply, digits, rank), + AutomorphismKey::encrypt_sk_scratch_space(module, n, basek, k_apply, rank) + | AutomorphismKey::automorphism_inplace_scratch_space(module, n, basek, k_in, k_apply, digits, rank), ); - let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); sk.fill_ternary_prob(0.5, &mut source_xs); // gglwe_{s1}(s0) = s0 -> s1 auto_key.encrypt_sk( - &module, + module, p0, &sk, &mut source_xa, @@ -331,7 +324,7 @@ pub(crate) fn test_gglwe_automorphism_inplace( // gglwe_{s2}(s1) -> s1 -> s2 auto_key_apply.encrypt_sk( - &module, + module, p1, &sk, &mut source_xa, @@ -341,16 +334,16 @@ pub(crate) fn test_gglwe_automorphism_inplace( ); let mut auto_key_apply_exec: AutomorphismKeyExec, B> = - AutomorphismKeyExec::alloc(&module, basek, k_apply, rows_apply, digits, rank); + AutomorphismKeyExec::alloc(module, n, basek, k_apply, rows_apply, digits, rank); - auto_key_apply_exec.prepare(&module, &auto_key_apply, scratch.borrow()); + auto_key_apply_exec.prepare(module, &auto_key_apply, scratch.borrow()); // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) - auto_key.automorphism_inplace(&module, &auto_key_apply_exec, scratch.borrow()); + auto_key.automorphism_inplace(module, &auto_key_apply_exec, scratch.borrow()); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_in); - let mut sk_auto: GLWESecret> = GLWESecret::alloc(&module, rank); + let mut sk_auto: GLWESecret> = GLWESecret::alloc(n, rank); sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk (0..rank).for_each(|i| { @@ -363,13 +356,13 @@ pub(crate) fn test_gglwe_automorphism_inplace( ); }); - let sk_auto_dft: GLWESecretExec, B> = GLWESecretExec::from(&module, &sk_auto); + let sk_auto_dft: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_auto); (0..auto_key.rank_in()).for_each(|col_i| { (0..auto_key.rows()).for_each(|row_i| { auto_key .at(row_i, col_i) - .decrypt(&module, &mut pt, &sk_auto_dft, scratch.borrow()); + .decrypt(module, &mut pt, &sk_auto_dft, scratch.borrow()); module.vec_znx_sub_scalar_inplace( &mut pt.data, 0, @@ -380,7 +373,7 @@ pub(crate) fn test_gglwe_automorphism_inplace( let noise_have: f64 = module.vec_znx_std(basek, &pt.data, 0).log2(); let noise_want: f64 = log2_std_noise_gglwe_product( - module.n() as f64, + n as f64, basek * digits, 0.5, 0.5, diff --git a/core/src/gglwe/tests/generics_gglwe.rs b/core/src/gglwe/tests/generics_gglwe.rs index 3feaa4e..ef909e7 100644 --- a/core/src/gglwe/tests/generics_gglwe.rs +++ b/core/src/gglwe/tests/generics_gglwe.rs @@ -1,8 +1,7 @@ use backend::hal::{ api::{ - MatZnxAlloc, ScalarZnxAlloc, ScalarZnxAllocBytes, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddScalarInplace, - VecZnxAlloc, VecZnxAllocBytes, VecZnxCopy, VecZnxRotateInplace, VecZnxStd, VecZnxSubScalarInplace, VecZnxSwithcDegree, - ZnxViewMut, + ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddScalarInplace, VecZnxCopy, VecZnxRotateInplace, VecZnxStd, + VecZnxSubScalarInplace, VecZnxSwithcDegree, ZnxViewMut, }, layouts::{Backend, Module, ScalarZnx, ScalarZnxToMut, ScratchOwned}, oep::{ @@ -21,14 +20,9 @@ use crate::{ pub(crate) trait TestModuleFamily = GGLWEEncryptSkFamily + GLWEDecryptFamily - + MatZnxAlloc - + ScalarZnxAlloc - + ScalarZnxAllocBytes - + VecZnxAllocBytes + VecZnxSwithcDegree + VecZnxAddScalarInplace + VecZnxStd - + VecZnxAlloc + VecZnxSubScalarInplace + VecZnxCopy; @@ -56,22 +50,23 @@ pub(crate) fn test_gglwe_encrypt_sk( Module: TestModuleFamily, B: TestScratchFamily, { + let n: usize = module.n(); let rows: usize = (k_ksk - digits * basek) / (digits * basek); - let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc(module, basek, k_ksk, rows, digits, rank_in, rank_out); + let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc(n, basek, k_ksk, rows, digits, rank_in, rank_out); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWESwitchingKey::encrypt_sk_scratch_space( - module, basek, k_ksk, rank_in, rank_out, + module, n, basek, k_ksk, rank_in, rank_out, )); - let mut sk_in: GLWESecret> = GLWESecret::alloc(module, rank_in); + let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank_in); sk_in.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_out: GLWESecret> = GLWESecret::alloc(module, rank_out); + let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank_out); sk_out.fill_ternary_prob(0.5, &mut source_xs); let sk_out_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_out); @@ -101,22 +96,23 @@ pub(crate) fn test_gglwe_encrypt_sk_compressed( Module: TestModuleFamily, B: TestScratchFamily, { + let n: usize = module.n(); let rows: usize = (k_ksk - digits * basek) / (digits * basek); let mut ksk_compressed: GLWESwitchingKeyCompressed> = - GLWESwitchingKeyCompressed::alloc(module, basek, k_ksk, rows, digits, rank_in, rank_out); + GLWESwitchingKeyCompressed::alloc(n, basek, k_ksk, rows, digits, rank_in, rank_out); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWESwitchingKeyCompressed::encrypt_sk_scratch_space( - module, basek, k_ksk, rank_in, rank_out, + module, n, basek, k_ksk, rank_in, rank_out, )); - let mut sk_in: GLWESecret> = GLWESecret::alloc(module, rank_in); + let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank_in); sk_in.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_out: GLWESecret> = GLWESecret::alloc(module, rank_out); + let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank_out); sk_out.fill_ternary_prob(0.5, &mut source_xs); let sk_out_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_out); @@ -132,7 +128,7 @@ pub(crate) fn test_gglwe_encrypt_sk_compressed( scratch.borrow(), ); - let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc(module, basek, k_ksk, rows, digits, rank_in, rank_out); + let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc(n, basek, k_ksk, rows, digits, rank_in, rank_out); ksk.decompress(module, &ksk_compressed); ksk.key @@ -155,29 +151,16 @@ pub(crate) fn test_gglwe_keyswitch( TestModuleFamily + GGLWEEncryptSkFamily + GLWEDecryptFamily + GLWEKeyswitchFamily + GGLWEExecLayoutFamily, B: TestScratchFamily, { + let n: usize = module.n(); let rows: usize = k_in.div_ceil(basek * digits); let digits_in: usize = 1; - let mut ct_gglwe_s0s1: GLWESwitchingKey> = GLWESwitchingKey::alloc( - module, - basek, - k_in, - rows, - digits_in, - rank_in_s0s1, - rank_out_s0s1, - ); - let mut ct_gglwe_s1s2: GLWESwitchingKey> = GLWESwitchingKey::alloc( - module, - basek, - k_ksk, - rows, - digits, - rank_out_s0s1, - rank_out_s1s2, - ); + let mut ct_gglwe_s0s1: GLWESwitchingKey> = + GLWESwitchingKey::alloc(n, basek, k_in, rows, digits_in, rank_in_s0s1, rank_out_s0s1); + let mut ct_gglwe_s1s2: GLWESwitchingKey> = + GLWESwitchingKey::alloc(n, basek, k_ksk, rows, digits, rank_out_s0s1, rank_out_s1s2); let mut ct_gglwe_s0s2: GLWESwitchingKey> = GLWESwitchingKey::alloc( - module, + n, basek, k_out, rows, @@ -192,6 +175,7 @@ pub(crate) fn test_gglwe_keyswitch( let mut scratch_enc: ScratchOwned = ScratchOwned::alloc(GLWESwitchingKey::encrypt_sk_scratch_space( module, + n, basek, k_ksk, rank_in_s0s1 | rank_out_s0s1, @@ -199,6 +183,7 @@ pub(crate) fn test_gglwe_keyswitch( )); let mut scratch_apply: ScratchOwned = ScratchOwned::alloc(GLWESwitchingKey::keyswitch_scratch_space( module, + n, basek, k_out, k_in, @@ -208,13 +193,13 @@ pub(crate) fn test_gglwe_keyswitch( ct_gglwe_s1s2.rank_out(), )); - let mut sk0: GLWESecret> = GLWESecret::alloc(module, rank_in_s0s1); + let mut sk0: GLWESecret> = GLWESecret::alloc(n, rank_in_s0s1); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1: GLWESecret> = GLWESecret::alloc(module, rank_out_s0s1); + let mut sk1: GLWESecret> = GLWESecret::alloc(n, rank_out_s0s1); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk2: GLWESecret> = GLWESecret::alloc(module, rank_out_s1s2); + let mut sk2: GLWESecret> = GLWESecret::alloc(n, rank_out_s1s2); sk2.fill_ternary_prob(0.5, &mut source_xs); let sk2_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk2); @@ -252,7 +237,7 @@ pub(crate) fn test_gglwe_keyswitch( ); let max_noise: f64 = log2_std_noise_gglwe_product( - module.n() as f64, + n as f64, basek * digits, 0.5, 0.5, @@ -286,13 +271,13 @@ pub(crate) fn test_gglwe_keyswitch_inplace( + GLWEDecryptFamily, B: TestScratchFamily, { + let n: usize = module.n(); let rows: usize = k_ct.div_ceil(basek * digits); let digits_in: usize = 1; let mut ct_gglwe_s0s1: GLWESwitchingKey> = - GLWESwitchingKey::alloc(module, basek, k_ct, rows, digits_in, rank_in, rank_out); - let mut ct_gglwe_s1s2: GLWESwitchingKey> = - GLWESwitchingKey::alloc(module, basek, k_ksk, rows, digits, rank_out, rank_out); + GLWESwitchingKey::alloc(n, basek, k_ct, rows, digits_in, rank_in, rank_out); + let mut ct_gglwe_s1s2: GLWESwitchingKey> = GLWESwitchingKey::alloc(n, basek, k_ksk, rows, digits, rank_out, rank_out); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -300,24 +285,25 @@ pub(crate) fn test_gglwe_keyswitch_inplace( let mut scratch_enc: ScratchOwned = ScratchOwned::alloc(GLWESwitchingKey::encrypt_sk_scratch_space( module, + n, basek, k_ksk, rank_in | rank_out, rank_out, )); let mut scratch_apply: ScratchOwned = ScratchOwned::alloc(GLWESwitchingKey::keyswitch_inplace_scratch_space( - module, basek, k_ct, k_ksk, digits, rank_out, + module, n, basek, k_ct, k_ksk, digits, rank_out, )); let var_xs: f64 = 0.5; - let mut sk0: GLWESecret> = GLWESecret::alloc(module, rank_in); + let mut sk0: GLWESecret> = GLWESecret::alloc(n, rank_in); sk0.fill_ternary_prob(var_xs, &mut source_xs); - let mut sk1: GLWESecret> = GLWESecret::alloc(module, rank_out); + let mut sk1: GLWESecret> = GLWESecret::alloc(n, rank_out); sk1.fill_ternary_prob(var_xs, &mut source_xs); - let mut sk2: GLWESecret> = GLWESecret::alloc(module, rank_out); + let mut sk2: GLWESecret> = GLWESecret::alloc(n, rank_out); sk2.fill_ternary_prob(var_xs, &mut source_xs); let sk2_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk2); @@ -352,7 +338,7 @@ pub(crate) fn test_gglwe_keyswitch_inplace( let ct_gglwe_s0s2: GLWESwitchingKey> = ct_gglwe_s0s1; let max_noise: f64 = log2_std_noise_gglwe_product( - module.n() as f64, + n as f64, basek * digits, var_xs, var_xs, @@ -388,25 +374,25 @@ pub(crate) fn test_gglwe_external_product( + VecZnxRotateInplace, B: TestScratchFamily, { + let n: usize = module.n(); let rows: usize = k_in.div_ceil(basek * digits); let digits_in: usize = 1; - let mut ct_gglwe_in: GLWESwitchingKey> = - GLWESwitchingKey::alloc(module, basek, k_in, rows, digits_in, rank_in, rank_out); + let mut ct_gglwe_in: GLWESwitchingKey> = GLWESwitchingKey::alloc(n, basek, k_in, rows, digits_in, rank_in, rank_out); let mut ct_gglwe_out: GLWESwitchingKey> = - GLWESwitchingKey::alloc(module, basek, k_out, rows, digits_in, rank_in, rank_out); - let mut ct_rgsw: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k_ggsw, rows, digits, rank_out); + GLWESwitchingKey::alloc(n, basek, k_out, rows, digits_in, rank_in, rank_out); + let mut ct_rgsw: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ggsw, rows, digits, rank_out); - let mut pt_rgsw: ScalarZnx> = module.scalar_znx_alloc(1); + let mut pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n, 1); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_in, rank_in, rank_out) - | GLWESwitchingKey::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank_out) - | GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ggsw, rank_out), + GLWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, k_in, rank_in, rank_out) + | GLWESwitchingKey::external_product_scratch_space(module, n, basek, k_out, k_in, k_ggsw, digits, rank_out) + | GGSWCiphertext::encrypt_sk_scratch_space(module, n, basek, k_ggsw, rank_out), ); let r: usize = 1; @@ -415,10 +401,10 @@ pub(crate) fn test_gglwe_external_product( let var_xs: f64 = 0.5; - let mut sk_in: GLWESecret> = GLWESecret::alloc(module, rank_in); + let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank_in); sk_in.fill_ternary_prob(var_xs, &mut source_xs); - let mut sk_out: GLWESecret> = GLWESecret::alloc(module, rank_out); + let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank_out); sk_out.fill_ternary_prob(var_xs, &mut source_xs); let sk_out_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_out); @@ -444,7 +430,7 @@ pub(crate) fn test_gglwe_external_product( ); let mut ct_rgsw_exec: GGSWCiphertextExec, B> = - GGSWCiphertextExec::alloc(module, basek, k_ggsw, rows, digits, rank_out); + GGSWCiphertextExec::alloc(module, n, basek, k_ggsw, rows, digits, rank_out); ct_rgsw_exec.prepare(module, &ct_rgsw, scratch.borrow()); @@ -458,12 +444,12 @@ pub(crate) fn test_gglwe_external_product( let var_gct_err_lhs: f64 = sigma * sigma; let var_gct_err_rhs: f64 = 0f64; - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_msg: f64 = 1f64 / n as f64; // X^{k} let var_a0_err: f64 = sigma * sigma; let var_a1_err: f64 = 1f64 / 12f64; let max_noise: f64 = noise_ggsw_product( - module.n() as f64, + n as f64, basek * digits, var_xs, var_msg, @@ -499,24 +485,24 @@ pub(crate) fn test_gglwe_external_product_inplace( + VecZnxRotateInplace, B: TestScratchFamily, { + let n: usize = module.n(); let rows: usize = k_ct.div_ceil(basek * digits); let digits_in: usize = 1; - let mut ct_gglwe: GLWESwitchingKey> = - GLWESwitchingKey::alloc(module, basek, k_ct, rows, digits_in, rank_in, rank_out); - let mut ct_rgsw: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k_ggsw, rows, digits, rank_out); + let mut ct_gglwe: GLWESwitchingKey> = GLWESwitchingKey::alloc(n, basek, k_ct, rows, digits_in, rank_in, rank_out); + let mut ct_rgsw: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ggsw, rows, digits, rank_out); - let mut pt_rgsw: ScalarZnx> = module.scalar_znx_alloc(1); + let mut pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n, 1); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_ct, rank_in, rank_out) - | GLWESwitchingKey::external_product_inplace_scratch_space(module, basek, k_ct, k_ggsw, digits, rank_out) - | GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ggsw, rank_out), + GLWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, k_ct, rank_in, rank_out) + | GLWESwitchingKey::external_product_inplace_scratch_space(module, n, basek, k_ct, k_ggsw, digits, rank_out) + | GGSWCiphertext::encrypt_sk_scratch_space(module, n, basek, k_ggsw, rank_out), ); let r: usize = 1; @@ -525,10 +511,10 @@ pub(crate) fn test_gglwe_external_product_inplace( let var_xs: f64 = 0.5; - let mut sk_in: GLWESecret> = GLWESecret::alloc(module, rank_in); + let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank_in); sk_in.fill_ternary_prob(var_xs, &mut source_xs); - let mut sk_out: GLWESecret> = GLWESecret::alloc(module, rank_out); + let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank_out); sk_out.fill_ternary_prob(var_xs, &mut source_xs); let sk_out_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_out); @@ -554,7 +540,7 @@ pub(crate) fn test_gglwe_external_product_inplace( ); let mut ct_rgsw_exec: GGSWCiphertextExec, B> = - GGSWCiphertextExec::alloc(module, basek, k_ggsw, rows, digits, rank_out); + GGSWCiphertextExec::alloc(module, n, basek, k_ggsw, rows, digits, rank_out); ct_rgsw_exec.prepare(module, &ct_rgsw, scratch.borrow()); @@ -568,12 +554,12 @@ pub(crate) fn test_gglwe_external_product_inplace( let var_gct_err_lhs: f64 = sigma * sigma; let var_gct_err_rhs: f64 = 0f64; - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_msg: f64 = 1f64 / n as f64; // X^{k} let var_a0_err: f64 = sigma * sigma; let var_a1_err: f64 = 1f64 / 12f64; let max_noise: f64 = noise_ggsw_product( - module.n() as f64, + n as f64, basek * digits, var_xs, var_msg, diff --git a/core/src/gglwe/tests/generics_tensor_key.rs b/core/src/gglwe/tests/generics_tensor_key.rs index 7078c2b..962abec 100644 --- a/core/src/gglwe/tests/generics_tensor_key.rs +++ b/core/src/gglwe/tests/generics_tensor_key.rs @@ -1,8 +1,7 @@ use backend::hal::{ api::{ - MatZnxAlloc, ScalarZnxAlloc, ScalarZnxAllocBytes, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddScalarInplace, - VecZnxAlloc, VecZnxAllocBytes, VecZnxBigAlloc, VecZnxCopy, VecZnxDftAlloc, VecZnxStd, VecZnxSubScalarInplace, - VecZnxSwithcDegree, + ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddScalarInplace, VecZnxBigAlloc, VecZnxCopy, VecZnxDftAlloc, VecZnxStd, + VecZnxSubScalarInplace, VecZnxSwithcDegree, }, layouts::{Backend, Module, ScratchOwned, VecZnxDft}, oep::{ @@ -19,14 +18,9 @@ use crate::{ pub(crate) trait TestModuleFamily = GGLWEEncryptSkFamily + GLWEDecryptFamily - + MatZnxAlloc - + ScalarZnxAlloc - + ScalarZnxAllocBytes - + VecZnxAllocBytes + VecZnxSwithcDegree + VecZnxAddScalarInplace + VecZnxStd - + VecZnxAlloc + VecZnxSubScalarInplace; pub(crate) trait TestScratchFamily = TakeVecZnxDftImpl @@ -51,9 +45,10 @@ where + VecZnxBigAlloc, B: TestScratchFamily, { + let n: usize = module.n(); let rows: usize = k / basek; - let mut tensor_key: GLWETensorKey> = GLWETensorKey::alloc(&module, basek, k, rows, 1, rank); + let mut tensor_key: GLWETensorKey> = GLWETensorKey::alloc(n, basek, k, rows, 1, rank); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -61,14 +56,15 @@ where let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWETensorKey::encrypt_sk_scratch_space( module, + n, basek, tensor_key.k(), rank, )); - let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_exec: GLWESecretExec, B> = GLWESecretExec::from(&module, &sk); + let mut sk_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk); sk_exec.prepare(module, &sk); tensor_key.encrypt_sk( @@ -80,12 +76,12 @@ where scratch.borrow(), ); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k); - let mut sk_ij_dft = module.vec_znx_dft_alloc(1, 1); - let mut sk_ij_big = module.vec_znx_big_alloc(1, 1); - let mut sk_ij: GLWESecret> = GLWESecret::alloc(&module, 1); - let mut sk_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(rank, 1); + let mut sk_ij_dft = module.vec_znx_dft_alloc(n, 1, 1); + let mut sk_ij_big = module.vec_znx_big_alloc(n, 1, 1); + let mut sk_ij: GLWESecret> = GLWESecret::alloc(n, 1); + let mut sk_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(n, rank, 1); (0..rank).for_each(|i| { module.vec_znx_dft_from_vec_znx(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); @@ -108,7 +104,7 @@ where tensor_key .at(i, j) .at(row_i, col_i) - .decrypt(&module, &mut pt, &sk_exec, scratch.borrow()); + .decrypt(module, &mut pt, &sk_exec, scratch.borrow()); module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk_ij.data, col_i); @@ -136,24 +132,25 @@ pub(crate) fn test_tensor_key_encrypt_sk_compressed( + VecZnxCopy, B: TestScratchFamily, { + let n: usize = module.n(); let rows: usize = k / basek; - let mut tensor_key_compressed: GLWETensorKeyCompressed> = - GLWETensorKeyCompressed::alloc(&module, basek, k, rows, 1, rank); + let mut tensor_key_compressed: GLWETensorKeyCompressed> = GLWETensorKeyCompressed::alloc(n, basek, k, rows, 1, rank); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWETensorKeyCompressed::encrypt_sk_scratch_space( module, + n, basek, tensor_key_compressed.k(), rank, )); - let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_exec: GLWESecretExec, B> = GLWESecretExec::from(&module, &sk); + let mut sk_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk); sk_exec.prepare(module, &sk); let seed_xa: [u8; 32] = [1u8; 32]; @@ -167,15 +164,15 @@ pub(crate) fn test_tensor_key_encrypt_sk_compressed( scratch.borrow(), ); - let mut tensor_key: GLWETensorKey> = GLWETensorKey::alloc(&module, basek, k, rows, 1, rank); + let mut tensor_key: GLWETensorKey> = GLWETensorKey::alloc(n, basek, k, rows, 1, rank); tensor_key.decompress(module, &tensor_key_compressed); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k); - let mut sk_ij_dft = module.vec_znx_dft_alloc(1, 1); - let mut sk_ij_big = module.vec_znx_big_alloc(1, 1); - let mut sk_ij: GLWESecret> = GLWESecret::alloc(&module, 1); - let mut sk_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(rank, 1); + let mut sk_ij_dft = module.vec_znx_dft_alloc(n, 1, 1); + let mut sk_ij_big = module.vec_znx_big_alloc(n, 1, 1); + let mut sk_ij: GLWESecret> = GLWESecret::alloc(n, 1); + let mut sk_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(n, rank, 1); (0..rank).for_each(|i| { module.vec_znx_dft_from_vec_znx(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); @@ -198,7 +195,7 @@ pub(crate) fn test_tensor_key_encrypt_sk_compressed( tensor_key .at(i, j) .at(row_i, col_i) - .decrypt(&module, &mut pt, &sk_exec, scratch.borrow()); + .decrypt(module, &mut pt, &sk_exec, scratch.borrow()); module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk_ij.data, col_i); diff --git a/core/src/gglwe/tests/mod.rs b/core/src/gglwe/tests/mod.rs index 50e4637..586f4c5 100644 --- a/core/src/gglwe/tests/mod.rs +++ b/core/src/gglwe/tests/mod.rs @@ -1,4 +1,5 @@ mod cpu_spqlios; +mod generic_serialization; mod generics_automorphism_key; mod generics_gglwe; mod generics_tensor_key; diff --git a/core/src/ggsw/automorphism.rs b/core/src/ggsw/automorphism.rs index c983976..9e25c37 100644 --- a/core/src/ggsw/automorphism.rs +++ b/core/src/ggsw/automorphism.rs @@ -10,6 +10,7 @@ use crate::{ impl GGSWCiphertext> { pub fn automorphism_scratch_space( module: &Module, + n: usize, basek: usize, k_out: usize, k_in: usize, @@ -23,15 +24,16 @@ impl GGSWCiphertext> { Module: GLWEKeyswitchFamily + GGSWKeySwitchFamily + VecZnxNormalizeTmpBytes, { let out_size: usize = k_out.div_ceil(basek); - let ci_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, out_size); + let ci_dft: usize = module.vec_znx_dft_alloc_bytes(n, rank + 1, out_size); let ks_internal: usize = - GLWECiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits_ksk, rank, rank); - let expand: usize = GGSWCiphertext::expand_row_scratch_space(module, basek, k_out, k_tsk, digits_tsk, rank); + GLWECiphertext::keyswitch_scratch_space(module, n, basek, k_out, k_in, k_ksk, digits_ksk, rank, rank); + let expand: usize = GGSWCiphertext::expand_row_scratch_space(module, n, basek, k_out, k_tsk, digits_tsk, rank); ci_dft + (ks_internal | expand) } pub fn automorphism_inplace_scratch_space( module: &Module, + n: usize, basek: usize, k_out: usize, k_ksk: usize, @@ -44,7 +46,7 @@ impl GGSWCiphertext> { Module: GLWEKeyswitchFamily + GGSWKeySwitchFamily + VecZnxNormalizeTmpBytes, { GGSWCiphertext::automorphism_scratch_space( - module, basek, k_out, k_out, k_ksk, digits_ksk, k_tsk, digits_tsk, rank, + module, n, basek, k_out, k_out, k_ksk, digits_ksk, k_tsk, digits_tsk, rank, ) } } @@ -65,6 +67,9 @@ impl GGSWCiphertext { { use crate::Infos; + assert_eq!(self.n(), auto_key.n()); + assert_eq!(lhs.n(), auto_key.n()); + assert_eq!( self.rank(), lhs.rank(), @@ -90,6 +95,7 @@ impl GGSWCiphertext { scratch.available() >= GGSWCiphertext::automorphism_scratch_space( module, + self.n(), self.basek(), self.k(), lhs.k(), @@ -102,6 +108,7 @@ impl GGSWCiphertext { ) }; + let n: usize = auto_key.n(); let rank: usize = self.rank(); let cols: usize = rank + 1; @@ -113,7 +120,7 @@ impl GGSWCiphertext { .automorphism(module, &lhs.at(row_i, 0), auto_key, scratch); // Isolates DFT(AUTO(a[i])) - let (mut ci_dft, scratch1) = scratch.take_vec_znx_dft(module, cols, self.size()); + let (mut ci_dft, scratch1) = scratch.take_vec_znx_dft(n, cols, self.size()); (0..cols).for_each(|i| { module.vec_znx_dft_from_vec_znx(1, 0, &mut ci_dft, i, &self.at(row_i, 0).data, i); }); diff --git a/core/src/ggsw/encryption.rs b/core/src/ggsw/encryption.rs index 073f04c..c309c41 100644 --- a/core/src/ggsw/encryption.rs +++ b/core/src/ggsw/encryption.rs @@ -1,8 +1,6 @@ use backend::hal::{ - api::{ - ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxAddScalarInplace, VecZnxAllocBytes, VecZnxNormalizeInplace, ZnxZero, - }, - layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch}, + api::{ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxAddScalarInplace, VecZnxNormalizeInplace, ZnxZero}, + layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, VecZnx}, }; use sampling::source::Source; @@ -14,15 +12,15 @@ use crate::{ pub trait GGSWEncryptSkFamily = GLWEEncryptSkFamily; impl GGSWCiphertext> { - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize, rank: usize) -> usize where - Module: GGSWEncryptSkFamily + VecZnxAllocBytes, + Module: GGSWEncryptSkFamily, { let size = k.div_ceil(basek); - GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) - + module.vec_znx_alloc_bytes(rank + 1, size) - + module.vec_znx_alloc_bytes(1, size) - + module.vec_znx_dft_alloc_bytes(rank + 1, size) + GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, k) + + VecZnx::alloc_bytes(n, rank + 1, size) + + VecZnx::alloc_bytes(n, 1, size) + + module.vec_znx_dft_alloc_bytes(n, rank + 1, size) } } @@ -38,16 +36,15 @@ impl GGSWCiphertext { scratch: &mut Scratch, ) where Module: GGSWEncryptSkFamily + VecZnxAddScalarInplace, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { #[cfg(debug_assertions)] { use backend::hal::api::ZnxInfos; assert_eq!(self.rank(), sk.rank()); - assert_eq!(self.n(), module.n()); - assert_eq!(pt.n(), module.n()); - assert_eq!(sk.n(), module.n()); + assert_eq!(self.n(), sk.n()); + assert_eq!(pt.n(), sk.n()); } let basek: usize = self.basek(); @@ -55,7 +52,7 @@ impl GGSWCiphertext { let rank: usize = self.rank(); let digits: usize = self.digits(); - let (mut tmp_pt, scratch1) = scratch.take_glwe_pt(module, basek, k); + let (mut tmp_pt, scratch1) = scratch.take_glwe_pt(self.n(), basek, k); (0..self.rows()).for_each(|row_i| { tmp_pt.data.zero(); @@ -82,11 +79,11 @@ impl GGSWCiphertext { } impl GGSWCiphertextCompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize, rank: usize) -> usize where - Module: GGSWEncryptSkFamily + VecZnxAllocBytes, + Module: GGSWEncryptSkFamily, { - GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k, rank) + GGSWCiphertext::encrypt_sk_scratch_space(module, n, basek, k, rank) } } @@ -102,16 +99,15 @@ impl GGSWCiphertextCompressed { scratch: &mut Scratch, ) where Module: GGSWEncryptSkFamily + VecZnxAddScalarInplace, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { #[cfg(debug_assertions)] { use backend::hal::api::ZnxInfos; assert_eq!(self.rank(), sk.rank()); - assert_eq!(self.n(), module.n()); - assert_eq!(pt.n(), module.n()); - assert_eq!(sk.n(), module.n()); + assert_eq!(self.n(), sk.n()); + assert_eq!(pt.n(), sk.n()); } let basek: usize = self.basek(); @@ -120,10 +116,12 @@ impl GGSWCiphertextCompressed { let cols: usize = rank + 1; let digits: usize = self.digits(); - let (mut tmp_pt, scratch_1) = scratch.take_glwe_pt(module, basek, k); + let (mut tmp_pt, scratch_1) = scratch.take_glwe_pt(self.n(), basek, k); let mut source = Source::new(seed_xa); + self.seed = vec![[0u8; 32]; self.rows() * cols]; + (0..self.rows()).for_each(|row_i| { tmp_pt.data.zero(); @@ -137,7 +135,7 @@ impl GGSWCiphertextCompressed { let (seed, mut source_xa_tmp) = source.branch(); self.seed[row_i * cols + col_j] = seed; - + encrypt_sk_internal( module, self.basek(), diff --git a/core/src/ggsw/external_product.rs b/core/src/ggsw/external_product.rs index c279ac3..26d2bab 100644 --- a/core/src/ggsw/external_product.rs +++ b/core/src/ggsw/external_product.rs @@ -8,6 +8,7 @@ use crate::{GGSWCiphertext, GGSWCiphertextExec, GLWECiphertext, GLWEExternalProd impl GGSWCiphertext> { pub fn external_product_scratch_space( module: &Module, + n: usize, basek: usize, k_out: usize, k_in: usize, @@ -18,11 +19,12 @@ impl GGSWCiphertext> { where Module: GLWEExternalProductFamily, { - GLWECiphertext::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank) + GLWECiphertext::external_product_scratch_space(module, n, basek, k_out, k_in, k_ggsw, digits, rank) } pub fn external_product_inplace_scratch_space( module: &Module, + n: usize, basek: usize, k_out: usize, k_ggsw: usize, @@ -32,7 +34,7 @@ impl GGSWCiphertext> { where Module: GLWEExternalProductFamily, { - GLWECiphertext::external_product_inplace_scratch_space(module, basek, k_out, k_ggsw, digits, rank) + GLWECiphertext::external_product_inplace_scratch_space(module, n, basek, k_out, k_ggsw, digits, rank) } } @@ -51,6 +53,9 @@ impl GGSWCiphertext { { use crate::{GGSWCiphertext, Infos}; + assert_eq!(lhs.n(), self.n()); + assert_eq!(rhs.n(), self.n()); + assert_eq!( self.rank(), lhs.rank(), @@ -70,6 +75,7 @@ impl GGSWCiphertext { scratch.available() >= GGSWCiphertext::external_product_scratch_space( module, + self.n(), self.basek(), self.k(), lhs.k(), @@ -104,6 +110,7 @@ impl GGSWCiphertext { { #[cfg(debug_assertions)] { + assert_eq!(rhs.n(), self.n()); assert_eq!( self.rank(), rhs.rank(), diff --git a/core/src/ggsw/keyswitch.rs b/core/src/ggsw/keyswitch.rs index ac84e28..f549e19 100644 --- a/core/src/ggsw/keyswitch.rs +++ b/core/src/ggsw/keyswitch.rs @@ -1,9 +1,9 @@ use backend::hal::{ api::{ - ScratchAvailable, TakeVecZnxBig, TakeVecZnxDft, VecZnxAllocBytes, VecZnxBigAllocBytes, VecZnxDftAddInplace, - VecZnxDftCopy, VecZnxDftToVecZnxBigTmpA, VecZnxNormalizeTmpBytes, ZnxInfos, + ScratchAvailable, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAllocBytes, VecZnxDftAddInplace, VecZnxDftCopy, + VecZnxDftToVecZnxBigTmpA, VecZnxNormalizeTmpBytes, ZnxInfos, }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnxDft, VmpPMat}, + layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx, VecZnxDft, VmpPMat}, }; use crate::{GGSWCiphertext, GLWECiphertext, GLWEKeyswitchFamily, GLWESwitchingKeyExec, GLWETensorKeyExec, Infos}; @@ -14,6 +14,7 @@ pub trait GGSWKeySwitchFamily = impl GGSWCiphertext> { pub(crate) fn expand_row_scratch_space( module: &Module, + n: usize, basek: usize, self_k: usize, k_tsk: usize, @@ -27,9 +28,10 @@ impl GGSWCiphertext> { let self_size_out: usize = self_k.div_ceil(basek); let self_size_in: usize = self_size_out.div_ceil(digits); - let tmp_dft_i: usize = module.vec_znx_dft_alloc_bytes(rank + 1, tsk_size); - let tmp_a: usize = module.vec_znx_dft_alloc_bytes(1, self_size_in); + let tmp_dft_i: usize = module.vec_znx_dft_alloc_bytes(n, rank + 1, tsk_size); + let tmp_a: usize = module.vec_znx_dft_alloc_bytes(n, 1, self_size_in); let vmp: usize = module.vmp_apply_tmp_bytes( + n, self_size_out, self_size_in, self_size_in, @@ -37,13 +39,14 @@ impl GGSWCiphertext> { rank, tsk_size, ); - let tmp_idft: usize = module.vec_znx_big_alloc_bytes(1, tsk_size); + let tmp_idft: usize = module.vec_znx_big_alloc_bytes(n, 1, tsk_size); let norm: usize = module.vec_znx_normalize_tmp_bytes(module.n()); tmp_dft_i + ((tmp_a + vmp) | (tmp_idft + norm)) } pub fn keyswitch_scratch_space( module: &Module, + n: usize, basek: usize, k_out: usize, k_in: usize, @@ -54,19 +57,20 @@ impl GGSWCiphertext> { rank: usize, ) -> usize where - Module: GLWEKeyswitchFamily + GGSWKeySwitchFamily + VecZnxAllocBytes + VecZnxNormalizeTmpBytes, + Module: GLWEKeyswitchFamily + GGSWKeySwitchFamily + VecZnxNormalizeTmpBytes, { let out_size: usize = k_out.div_ceil(basek); - let res_znx: usize = module.vec_znx_alloc_bytes(rank + 1, out_size); - let ci_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, out_size); - let ks: usize = GLWECiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits_ksk, rank, rank); - let expand_rows: usize = GGSWCiphertext::expand_row_scratch_space(module, basek, k_out, k_tsk, digits_tsk, rank); - let res_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, out_size); + let res_znx: usize = VecZnx::alloc_bytes(n, rank + 1, out_size); + let ci_dft: usize = module.vec_znx_dft_alloc_bytes(n, rank + 1, out_size); + let ks: usize = GLWECiphertext::keyswitch_scratch_space(module, n, basek, k_out, k_in, k_ksk, digits_ksk, rank, rank); + let expand_rows: usize = GGSWCiphertext::expand_row_scratch_space(module, n, basek, k_out, k_tsk, digits_tsk, rank); + let res_dft: usize = module.vec_znx_dft_alloc_bytes(n, rank + 1, out_size); res_znx + ci_dft + (ks | expand_rows | res_dft) } pub fn keyswitch_inplace_scratch_space( module: &Module, + n: usize, basek: usize, k_out: usize, k_ksk: usize, @@ -76,10 +80,10 @@ impl GGSWCiphertext> { rank: usize, ) -> usize where - Module: GLWEKeyswitchFamily + GGSWKeySwitchFamily + VecZnxAllocBytes + VecZnxNormalizeTmpBytes, + Module: GLWEKeyswitchFamily + GGSWKeySwitchFamily + VecZnxNormalizeTmpBytes, { GGSWCiphertext::keyswitch_scratch_space( - module, basek, k_out, k_out, k_ksk, digits_ksk, k_tsk, digits_tsk, rank, + module, n, basek, k_out, k_out, k_ksk, digits_ksk, k_tsk, digits_tsk, rank, ) } } @@ -99,10 +103,16 @@ impl GGSWCiphertext { { let cols: usize = self.rank() + 1; + #[cfg(debug_assertions)] + { + assert_eq!(self.n(), tsk.n()); + } + assert!( scratch.available() >= GGSWCiphertext::expand_row_scratch_space( module, + self.n(), self.basek(), self.k(), tsk.k(), @@ -131,10 +141,11 @@ impl GGSWCiphertext { // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M[i], c2 ) // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M[i]) + let n: usize = self.n(); let digits: usize = tsk.digits(); - let (mut tmp_dft_i, scratch1) = scratch.take_vec_znx_dft(module, cols, tsk.size()); - let (mut tmp_a, scratch2) = scratch1.take_vec_znx_dft(module, 1, ci_dft.size().div_ceil(digits)); + let (mut tmp_dft_i, scratch1) = scratch.take_vec_znx_dft(n, cols, tsk.size()); + let (mut tmp_a, scratch2) = scratch1.take_vec_znx_dft(n, 1, ci_dft.size().div_ceil(digits)); { // Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2 @@ -184,7 +195,7 @@ impl GGSWCiphertext { // = // (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2) module.vec_znx_dft_add_inplace(&mut tmp_dft_i, col_j, ci_dft, 0); - let (mut tmp_idft, scratch2) = scratch1.take_vec_znx_big(module, 1, tsk.size()); + let (mut tmp_idft, scratch2) = scratch1.take_vec_znx_big(n, 1, tsk.size()); (0..cols).for_each(|i| { module.vec_znx_dft_to_vec_znx_big_tmp_a(&mut tmp_idft, 0, &mut tmp_dft_i, i); module.vec_znx_big_normalize( @@ -209,6 +220,7 @@ impl GGSWCiphertext { Module: GLWEKeyswitchFamily + GGSWKeySwitchFamily + VecZnxNormalizeTmpBytes, Scratch: TakeVecZnxDft + TakeVecZnxBig + ScratchAvailable, { + let n: usize = self.n(); let rank: usize = self.rank(); let cols: usize = rank + 1; @@ -220,7 +232,7 @@ impl GGSWCiphertext { .keyswitch(module, &lhs.at(row_i, 0), ksk, scratch); // Pre-compute DFT of (a0, a1, a2) - let (mut ci_dft, scratch1) = scratch.take_vec_znx_dft(module, cols, self.size()); + let (mut ci_dft, scratch1) = scratch.take_vec_znx_dft(n, cols, self.size()); (0..cols).for_each(|i| { module.vec_znx_dft_from_vec_znx(1, 0, &mut ci_dft, i, &self.at(row_i, 0).data, i); }); diff --git a/core/src/ggsw/layout.rs b/core/src/ggsw/layout.rs index 397f84f..9d209aa 100644 --- a/core/src/ggsw/layout.rs +++ b/core/src/ggsw/layout.rs @@ -1,13 +1,14 @@ use backend::hal::{ - api::{MatZnxAlloc, MatZnxAllocBytes, VmpPMatAlloc, VmpPMatAllocBytes, VmpPMatPrepare}, - layouts::{Backend, Data, DataMut, DataRef, MatZnx, Module, ReaderFrom, WriterTo}, + api::{FillUniform, Reset, VmpPMatAlloc, VmpPMatAllocBytes, VmpPMatPrepare}, + layouts::{Backend, Data, DataMut, DataRef, MatZnx, ReaderFrom, WriterTo}, }; +use std::fmt; use crate::{GLWECiphertext, Infos}; pub trait GGSWLayoutFamily = VmpPMatAlloc + VmpPMatAllocBytes + VmpPMatPrepare; -#[derive(PartialEq, Eq)] +#[derive(PartialEq, Eq, Clone)] pub struct GGSWCiphertext { pub(crate) data: MatZnx, pub(crate) basek: usize, @@ -15,6 +16,37 @@ pub struct GGSWCiphertext { pub(crate) digits: usize, } +impl fmt::Debug for GGSWCiphertext { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "(GGSWCiphertext: basek={} k={} digits={}) {}", + self.basek, self.k, self.digits, self.data + ) + } +} + +impl Reset for GGSWCiphertext +where + MatZnx: Reset, +{ + fn reset(&mut self) { + self.data.reset(); + self.basek = 0; + self.k = 0; + self.digits = 0; + } +} + +impl FillUniform for GGSWCiphertext +where + MatZnx: FillUniform, +{ + fn fill_uniform(&mut self, source: &mut sampling::source::Source) { + self.data.fill_uniform(source); + } +} + impl GGSWCiphertext { pub fn at(&self, row: usize, col: usize) -> GLWECiphertext<&[u8]> { GLWECiphertext { @@ -36,10 +68,7 @@ impl GGSWCiphertext { } impl GGSWCiphertext> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self - where - Module: MatZnxAlloc, - { + pub fn alloc(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self { let size: usize = k.div_ceil(basek); debug_assert!(digits > 0, "invalid ggsw: `digits` == 0"); @@ -59,17 +88,14 @@ impl GGSWCiphertext> { ); Self { - data: module.mat_znx_alloc(rows, rank + 1, rank + 1, k.div_ceil(basek)), + data: MatZnx::alloc(n, rows, rank + 1, rank + 1, k.div_ceil(basek)), basek, k: k, digits, } } - pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize - where - Module: MatZnxAllocBytes, - { + pub fn bytes_of(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize { let size: usize = k.div_ceil(basek); debug_assert!( size > digits, @@ -86,7 +112,7 @@ impl GGSWCiphertext> { size ); - module.mat_znx_alloc_bytes(rows, rank + 1, rank + 1, size) + MatZnx::alloc_bytes(n, rows, rank + 1, rank + 1, size) } } diff --git a/core/src/ggsw/layout_compressed.rs b/core/src/ggsw/layout_compressed.rs index 2d139fd..f2a88b2 100644 --- a/core/src/ggsw/layout_compressed.rs +++ b/core/src/ggsw/layout_compressed.rs @@ -1,11 +1,13 @@ use backend::hal::{ - api::{MatZnxAlloc, MatZnxAllocBytes, VecZnxCopy, VecZnxFillUniform}, + api::{FillUniform, Reset, VecZnxCopy, VecZnxFillUniform}, layouts::{Backend, Data, DataMut, DataRef, MatZnx, Module, ReaderFrom, WriterTo}, }; use crate::{Decompress, GGSWCiphertext, GLWECiphertextCompressed, Infos}; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; +use std::fmt; -#[derive(PartialEq, Eq)] +#[derive(PartialEq, Eq, Clone)] pub struct GGSWCiphertextCompressed { pub(crate) data: MatZnx, pub(crate) basek: usize, @@ -15,11 +17,41 @@ pub struct GGSWCiphertextCompressed { pub(crate) seed: Vec<[u8; 32]>, } +impl fmt::Debug for GGSWCiphertextCompressed { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "(GGSWCiphertextCompressed: basek={} k={} digits={}) {}", + self.basek, self.k, self.digits, self.data + ) + } +} + +impl Reset for GGSWCiphertextCompressed +where + MatZnx: Reset, +{ + fn reset(&mut self) { + self.data.reset(); + self.basek = 0; + self.k = 0; + self.digits = 0; + self.rank = 0; + self.seed = Vec::new(); + } +} + +impl FillUniform for GGSWCiphertextCompressed +where + MatZnx: FillUniform, +{ + fn fill_uniform(&mut self, source: &mut sampling::source::Source) { + self.data.fill_uniform(source); + } +} + impl GGSWCiphertextCompressed> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self - where - Module: MatZnxAlloc, - { + pub fn alloc(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self { let size: usize = k.div_ceil(basek); debug_assert!(digits > 0, "invalid ggsw: `digits` == 0"); @@ -39,19 +71,16 @@ impl GGSWCiphertextCompressed> { ); Self { - data: module.mat_znx_alloc(rows, rank + 1, 1, k.div_ceil(basek)), + data: MatZnx::alloc(n, rows, rank + 1, 1, k.div_ceil(basek)), basek, k: k, digits, rank, - seed: vec![[0u8; 32]; rows * (rank + 1)], + seed: Vec::new(), } } - pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize - where - Module: MatZnxAllocBytes, - { + pub fn bytes_of(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize { let size: usize = k.div_ceil(basek); debug_assert!( size > digits, @@ -68,7 +97,7 @@ impl GGSWCiphertextCompressed> { size ); - module.mat_znx_alloc_bytes(rows, rank + 1, 1, size) + MatZnx::alloc_bytes(n, rows, rank + 1, 1, size) } } @@ -125,12 +154,29 @@ impl GGSWCiphertextCompressed { impl ReaderFrom for GGSWCiphertextCompressed { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + self.k = reader.read_u64::()? as usize; + self.basek = reader.read_u64::()? as usize; + self.digits = reader.read_u64::()? as usize; + self.rank = reader.read_u64::()? as usize; + let seed_len = reader.read_u64::()? as usize; + self.seed = vec![[0u8; 32]; seed_len]; + for s in &mut self.seed { + reader.read_exact(s)?; + } self.data.read_from(reader) } } impl WriterTo for GGSWCiphertextCompressed { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + writer.write_u64::(self.k as u64)?; + writer.write_u64::(self.basek as u64)?; + writer.write_u64::(self.digits as u64)?; + writer.write_u64::(self.rank as u64)?; + writer.write_u64::(self.seed.len() as u64)?; + for s in &self.seed { + writer.write_all(s)?; + } self.data.write_to(writer) } } diff --git a/core/src/ggsw/layout_exec.rs b/core/src/ggsw/layout_exec.rs index dd11c86..d349f9c 100644 --- a/core/src/ggsw/layout_exec.rs +++ b/core/src/ggsw/layout_exec.rs @@ -14,7 +14,7 @@ pub struct GGSWCiphertextExec { } impl GGSWCiphertextExec, B> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self + pub fn alloc(module: &Module, n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self where Module: GGSWLayoutFamily, { @@ -37,14 +37,14 @@ impl GGSWCiphertextExec, B> { ); Self { - data: module.vmp_pmat_alloc(rows, rank + 1, rank + 1, k.div_ceil(basek)), + data: module.vmp_pmat_alloc(n, rows, rank + 1, rank + 1, k.div_ceil(basek)), basek, k: k, digits, } } - pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize + pub fn bytes_of(module: &Module, n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize where Module: GGSWLayoutFamily, { @@ -64,7 +64,7 @@ impl GGSWCiphertextExec, B> { size ); - module.vmp_pmat_alloc_bytes(rows, rank + 1, rank + 1, size) + module.vmp_pmat_alloc_bytes(n, rows, rank + 1, rank + 1, size) } pub fn from( @@ -77,6 +77,7 @@ impl GGSWCiphertextExec, B> { { let mut ggsw_exec: GGSWCiphertextExec, B> = Self::alloc( module, + other.n(), other.basek(), other.k(), other.rows(), diff --git a/core/src/ggsw/noise.rs b/core/src/ggsw/noise.rs index e98636c..83590e2 100644 --- a/core/src/ggsw/noise.rs +++ b/core/src/ggsw/noise.rs @@ -1,6 +1,6 @@ use backend::hal::{ api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddScalarInplace, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigNormalize, + ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddScalarInplace, VecZnxBigAlloc, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAlloc, VecZnxDftToVecZnxBigTmpA, VecZnxNormalizeTmpBytes, VecZnxStd, VecZnxSubABInplace, ZnxZero, }, @@ -27,7 +27,7 @@ impl GGSWCiphertext { ) where DataSk: DataRef, DataScalar: DataRef, - Module: GGSWAssertNoiseFamily + VecZnxAlloc + VecZnxAddScalarInplace + VecZnxSubABInplace + VecZnxStd, + Module: GGSWAssertNoiseFamily + VecZnxAddScalarInplace + VecZnxSubABInplace + VecZnxStd, B: TakeVecZnxDftImpl + TakeVecZnxBigImpl + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, F: Fn(usize) -> f64, { @@ -35,13 +35,13 @@ impl GGSWCiphertext { let k: usize = self.k(); let digits: usize = self.digits(); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k); - let mut pt_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(1, self.size()); - let mut pt_big: VecZnxBig, B> = module.vec_znx_big_alloc(1, self.size()); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(self.n(), basek, k); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(self.n(), basek, k); + let mut pt_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(self.n(), 1, self.size()); + let mut pt_big: VecZnxBig, B> = module.vec_znx_big_alloc(self.n(), 1, self.size()); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::decrypt_scratch_space(module, basek, k) | module.vec_znx_normalize_tmp_bytes(module.n()), + GLWECiphertext::decrypt_scratch_space(module, self.n(), basek, k) | module.vec_znx_normalize_tmp_bytes(self.n()), ); (0..self.rank() + 1).for_each(|col_j| { diff --git a/core/src/ggsw/test/generic_serialization.rs b/core/src/ggsw/test/generic_serialization.rs new file mode 100644 index 0000000..ee39727 --- /dev/null +++ b/core/src/ggsw/test/generic_serialization.rs @@ -0,0 +1,15 @@ +use backend::hal::tests::serialization::test_reader_writer_interface; + +use crate::{GGSWCiphertext, GGSWCiphertextCompressed}; + +#[test] +fn ggsw_test_serialization() { + let original: GGSWCiphertext> = GGSWCiphertext::alloc(1024, 12, 54, 3, 1, 2); + test_reader_writer_interface(original); +} + +#[test] +fn ggsw_test_serialization_compressed() { + let original: GGSWCiphertextCompressed> = GGSWCiphertextCompressed::alloc(1024, 12, 54, 3, 1, 2); + test_reader_writer_interface(original); +} diff --git a/core/src/ggsw/test/generic_tests.rs b/core/src/ggsw/test/generic_tests.rs index 182ab16..dfb1c6d 100644 --- a/core/src/ggsw/test/generic_tests.rs +++ b/core/src/ggsw/test/generic_tests.rs @@ -1,8 +1,7 @@ use backend::hal::{ api::{ - MatZnxAlloc, ScalarZnxAlloc, ScalarZnxAllocBytes, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddScalarInplace, - VecZnxAlloc, VecZnxAllocBytes, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxCopy, VecZnxRotateInplace, VecZnxStd, - VecZnxSubABInplace, VecZnxSwithcDegree, ZnxViewMut, + ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxCopy, + VecZnxRotateInplace, VecZnxStd, VecZnxSubABInplace, VecZnxSwithcDegree, ZnxViewMut, }, layouts::{Backend, Module, ScalarZnx, ScalarZnxToMut, ScratchOwned}, oep::{ @@ -23,14 +22,9 @@ use crate::{ pub(crate) trait TestModuleFamily = GLWESecretFamily + GGSWEncryptSkFamily + GGSWAssertNoiseFamily - + VecZnxAlloc - + ScalarZnxAlloc - + VecZnxAllocBytes - + MatZnxAlloc + VecZnxAddScalarInplace + VecZnxSubABInplace + VecZnxStd - + ScalarZnxAllocBytes + VecZnxCopy; pub(crate) trait TestScratchFamily = TakeVecZnxDftImpl + TakeVecZnxBigImpl @@ -49,23 +43,24 @@ where Module: TestModuleFamily, B: TestScratchFamily, { + let n: usize = module.n(); let rows: usize = (k - digits * basek) / (digits * basek); - let mut ct: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k, rows, digits, rank); + let mut ct: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k, rows, digits, rank); - let mut pt_scalar: ScalarZnx> = module.scalar_znx_alloc(1); + let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + pt_scalar.fill_ternary_hw(0, n, &mut source_xs); let mut scratch: ScratchOwned = ScratchOwned::alloc(GGSWCiphertext::encrypt_sk_scratch_space( - module, basek, k, rank, + module, n, basek, k, rank, )); - let mut sk: GLWESecret> = GLWESecret::alloc(module, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); sk.fill_ternary_prob(0.5, &mut source_xs); let mut sk_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk); sk_exec.prepare(module, &sk); @@ -96,23 +91,23 @@ pub(crate) fn test_encrypt_sk_compressed( Module: TestModuleFamily, B: TestScratchFamily, { + let n: usize = module.n(); let rows: usize = (k - digits * basek) / (digits * basek); - let mut ct_compressed: GGSWCiphertextCompressed> = - GGSWCiphertextCompressed::alloc(module, basek, k, rows, digits, rank); + let mut ct_compressed: GGSWCiphertextCompressed> = GGSWCiphertextCompressed::alloc(n, basek, k, rows, digits, rank); - let mut pt_scalar: ScalarZnx> = module.scalar_znx_alloc(1); + let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); - pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + pt_scalar.fill_ternary_hw(0, n, &mut source_xs); let mut scratch: ScratchOwned = ScratchOwned::alloc(GGSWCiphertextCompressed::encrypt_sk_scratch_space( - module, basek, k, rank, + module, n, basek, k, rank, )); - let mut sk: GLWESecret> = GLWESecret::alloc(module, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); sk.fill_ternary_prob(0.5, &mut source_xs); let mut sk_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk); sk_exec.prepare(module, &sk); @@ -131,7 +126,7 @@ pub(crate) fn test_encrypt_sk_compressed( let noise_f = |_col_i: usize| -(k as f64) + sigma.log2() + 0.5; - let mut ct: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k, rows, digits, rank); + let mut ct: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k, rows, digits, rank); ct.decompress(module, &ct_compressed); ct.assert_noise(module, &sk_exec, &pt_scalar, &noise_f); @@ -157,36 +152,37 @@ pub(crate) fn test_keyswitch( + VecZnxSwithcDegree, B: TestScratchFamily + VecZnxDftAllocBytesImpl + VecZnxBigAllocBytesImpl + TakeSvpPPolImpl, { + let n: usize = module.n(); let rows: usize = k_in.div_ceil(digits * basek); let digits_in: usize = 1; - let mut ct_in: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k_in, rows, digits_in, rank); - let mut ct_out: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k_out, rows, digits_in, rank); - let mut tsk: GLWETensorKey> = GLWETensorKey::alloc(module, basek, k_ksk, rows, digits, rank); - let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc(module, basek, k_ksk, rows, digits, rank, rank); - let mut pt_scalar: ScalarZnx> = module.scalar_znx_alloc(1); + let mut ct_in: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_in, rows, digits_in, rank); + let mut ct_out: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_out, rows, digits_in, rank); + let mut tsk: GLWETensorKey> = GLWETensorKey::alloc(n, basek, k_ksk, rows, digits, rank); + let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc(n, basek, k_ksk, rows, digits, rank, rank); + let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_in, rank) - | GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank, rank) - | GLWETensorKey::encrypt_sk_scratch_space(module, basek, k_tsk, rank) + GGSWCiphertext::encrypt_sk_scratch_space(module, n, basek, k_in, rank) + | GLWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, k_ksk, rank, rank) + | GLWETensorKey::encrypt_sk_scratch_space(module, n, basek, k_tsk, rank) | GGSWCiphertext::keyswitch_scratch_space( - module, basek, k_out, k_in, k_ksk, digits, k_tsk, digits, rank, + module, n, basek, k_out, k_in, k_ksk, digits, k_tsk, digits, rank, ), ); let var_xs: f64 = 0.5; - let mut sk_in: GLWESecret> = GLWESecret::alloc(module, rank); + let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank); sk_in.fill_ternary_prob(var_xs, &mut source_xs); let sk_in_dft: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_in); - let mut sk_out: GLWESecret> = GLWESecret::alloc(module, rank); + let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank); sk_out.fill_ternary_prob(var_xs, &mut source_xs); let sk_out_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_out); @@ -208,7 +204,7 @@ pub(crate) fn test_keyswitch( scratch.borrow(), ); - pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + pt_scalar.fill_ternary_hw(0, n, &mut source_xs); ct_in.encrypt_sk( module, @@ -221,8 +217,8 @@ pub(crate) fn test_keyswitch( ); let mut ksk_exec: GLWESwitchingKeyExec, B> = - GLWESwitchingKeyExec::alloc(module, basek, k_ksk, rows, digits, rank, rank); - let mut tsk_exec: GLWETensorKeyExec, B> = GLWETensorKeyExec::alloc(module, basek, k_ksk, rows, digits, rank); + GLWESwitchingKeyExec::alloc(module, n, basek, k_ksk, rows, digits, rank, rank); + let mut tsk_exec: GLWETensorKeyExec, B> = GLWETensorKeyExec::alloc(module, n, basek, k_ksk, rows, digits, rank); ksk_exec.prepare(module, &ksk, scratch.borrow()); tsk_exec.prepare(module, &tsk, scratch.borrow()); @@ -231,7 +227,7 @@ pub(crate) fn test_keyswitch( let max_noise = |col_j: usize| -> f64 { noise_ggsw_keyswitch( - module.n() as f64, + n as f64, basek * digits, col_j, var_xs, @@ -267,33 +263,34 @@ pub(crate) fn test_keyswitch_inplace( + VecZnxSwithcDegree, B: TestScratchFamily, { + let n: usize = module.n(); let rows: usize = k_ct.div_ceil(digits * basek); let digits_in: usize = 1; - let mut ct: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k_ct, rows, digits_in, rank); - let mut tsk: GLWETensorKey> = GLWETensorKey::alloc(module, basek, k_tsk, rows, digits, rank); - let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc(module, basek, k_ksk, rows, digits, rank, rank); - let mut pt_scalar: ScalarZnx> = module.scalar_znx_alloc(1); + let mut ct: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ct, rows, digits_in, rank); + let mut tsk: GLWETensorKey> = GLWETensorKey::alloc(n, basek, k_tsk, rows, digits, rank); + let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc(n, basek, k_ksk, rows, digits, rank, rank); + let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ct, rank) - | GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank, rank) - | GLWETensorKey::encrypt_sk_scratch_space(module, basek, k_tsk, rank) - | GGSWCiphertext::keyswitch_inplace_scratch_space(module, basek, k_ct, k_ksk, digits, k_tsk, digits, rank), + GGSWCiphertext::encrypt_sk_scratch_space(module, n, basek, k_ct, rank) + | GLWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, k_ksk, rank, rank) + | GLWETensorKey::encrypt_sk_scratch_space(module, n, basek, k_tsk, rank) + | GGSWCiphertext::keyswitch_inplace_scratch_space(module, n, basek, k_ct, k_ksk, digits, k_tsk, digits, rank), ); let var_xs: f64 = 0.5; - let mut sk_in: GLWESecret> = GLWESecret::alloc(module, rank); + let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank); sk_in.fill_ternary_prob(var_xs, &mut source_xs); let sk_in_dft: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_in); - let mut sk_out: GLWESecret> = GLWESecret::alloc(module, rank); + let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank); sk_out.fill_ternary_prob(var_xs, &mut source_xs); let sk_out_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_out); @@ -315,7 +312,7 @@ pub(crate) fn test_keyswitch_inplace( scratch.borrow(), ); - pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + pt_scalar.fill_ternary_hw(0, n, &mut source_xs); ct.encrypt_sk( module, @@ -328,8 +325,8 @@ pub(crate) fn test_keyswitch_inplace( ); let mut ksk_exec: GLWESwitchingKeyExec, B> = - GLWESwitchingKeyExec::alloc(module, basek, k_ksk, rows, digits, rank, rank); - let mut tsk_exec: GLWETensorKeyExec, B> = GLWETensorKeyExec::alloc(module, basek, k_ksk, rows, digits, rank); + GLWESwitchingKeyExec::alloc(module, n, basek, k_ksk, rows, digits, rank, rank); + let mut tsk_exec: GLWETensorKeyExec, B> = GLWETensorKeyExec::alloc(module, n, basek, k_ksk, rows, digits, rank); ksk_exec.prepare(module, &ksk, scratch.borrow()); tsk_exec.prepare(module, &tsk, scratch.borrow()); @@ -338,7 +335,7 @@ pub(crate) fn test_keyswitch_inplace( let max_noise = |col_j: usize| -> f64 { noise_ggsw_keyswitch( - module.n() as f64, + n as f64, basek * digits, col_j, var_xs, @@ -379,33 +376,34 @@ pub(crate) fn test_automorphism( + VecZnxAutomorphism, B: TestScratchFamily, { + let n: usize = module.n(); let rows: usize = k_in.div_ceil(basek * digits); let rows_in: usize = k_in.div_euclid(basek * digits); let digits_in: usize = 1; - let mut ct_in: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k_in, rows_in, digits_in, rank); - let mut ct_out: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k_out, rows_in, digits_in, rank); - let mut tensor_key: GLWETensorKey> = GLWETensorKey::alloc(module, basek, k_tsk, rows, digits, rank); - let mut auto_key: AutomorphismKey> = AutomorphismKey::alloc(module, basek, k_ksk, rows, digits, rank); - let mut pt_scalar: ScalarZnx> = module.scalar_znx_alloc(1); + let mut ct_in: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_in, rows_in, digits_in, rank); + let mut ct_out: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_out, rows_in, digits_in, rank); + let mut tensor_key: GLWETensorKey> = GLWETensorKey::alloc(n, basek, k_tsk, rows, digits, rank); + let mut auto_key: AutomorphismKey> = AutomorphismKey::alloc(n, basek, k_ksk, rows, digits, rank); + let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_in, rank) - | AutomorphismKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank) - | GLWETensorKey::encrypt_sk_scratch_space(module, basek, k_tsk, rank) + GGSWCiphertext::encrypt_sk_scratch_space(module, n, basek, k_in, rank) + | AutomorphismKey::encrypt_sk_scratch_space(module, n, basek, k_ksk, rank) + | GLWETensorKey::encrypt_sk_scratch_space(module, n, basek, k_tsk, rank) | GGSWCiphertext::automorphism_scratch_space( - module, basek, k_out, k_in, k_ksk, digits, k_tsk, digits, rank, + module, n, basek, k_out, k_in, k_ksk, digits, k_tsk, digits, rank, ), ); let var_xs: f64 = 0.5; - let mut sk: GLWESecret> = GLWESecret::alloc(module, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); sk.fill_ternary_prob(var_xs, &mut source_xs); let sk_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk); @@ -427,7 +425,7 @@ pub(crate) fn test_automorphism( scratch.borrow(), ); - pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + pt_scalar.fill_ternary_hw(0, n, &mut source_xs); ct_in.encrypt_sk( module, @@ -439,10 +437,11 @@ pub(crate) fn test_automorphism( scratch.borrow(), ); - let mut auto_key_exec: AutomorphismKeyExec, B> = AutomorphismKeyExec::alloc(module, basek, k_ksk, rows, digits, rank); + let mut auto_key_exec: AutomorphismKeyExec, B> = + AutomorphismKeyExec::alloc(module, n, basek, k_ksk, rows, digits, rank); auto_key_exec.prepare(module, &auto_key, scratch.borrow()); - let mut tsk_exec: GLWETensorKeyExec, B> = GLWETensorKeyExec::alloc(module, basek, k_tsk, rows, digits, rank); + let mut tsk_exec: GLWETensorKeyExec, B> = GLWETensorKeyExec::alloc(module, n, basek, k_tsk, rows, digits, rank); tsk_exec.prepare(module, &tensor_key, scratch.borrow()); ct_out.automorphism(module, &ct_in, &auto_key_exec, &tsk_exec, scratch.borrow()); @@ -451,7 +450,7 @@ pub(crate) fn test_automorphism( let max_noise = |col_j: usize| -> f64 { noise_ggsw_keyswitch( - module.n() as f64, + n as f64, basek * digits, col_j, var_xs, @@ -491,29 +490,30 @@ pub(crate) fn test_automorphism_inplace( + VecZnxAutomorphismInplace, B: TestScratchFamily, { + let n: usize = module.n(); let rows: usize = k_ct.div_ceil(digits * basek); let rows_in: usize = k_ct.div_euclid(basek * digits); let digits_in: usize = 1; - let mut ct: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k_ct, rows_in, digits_in, rank); - let mut tensor_key: GLWETensorKey> = GLWETensorKey::alloc(module, basek, k_tsk, rows, digits, rank); - let mut auto_key: AutomorphismKey> = AutomorphismKey::alloc(module, basek, k_ksk, rows, digits, rank); - let mut pt_scalar: ScalarZnx> = module.scalar_znx_alloc(1); + let mut ct: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ct, rows_in, digits_in, rank); + let mut tensor_key: GLWETensorKey> = GLWETensorKey::alloc(n, basek, k_tsk, rows, digits, rank); + let mut auto_key: AutomorphismKey> = AutomorphismKey::alloc(n, basek, k_ksk, rows, digits, rank); + let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ct, rank) - | AutomorphismKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank) - | GLWETensorKey::encrypt_sk_scratch_space(module, basek, k_tsk, rank) - | GGSWCiphertext::automorphism_inplace_scratch_space(module, basek, k_ct, k_ksk, digits, k_tsk, digits, rank), + GGSWCiphertext::encrypt_sk_scratch_space(module, n, basek, k_ct, rank) + | AutomorphismKey::encrypt_sk_scratch_space(module, n, basek, k_ksk, rank) + | GLWETensorKey::encrypt_sk_scratch_space(module, n, basek, k_tsk, rank) + | GGSWCiphertext::automorphism_inplace_scratch_space(module, n, basek, k_ct, k_ksk, digits, k_tsk, digits, rank), ); let var_xs: f64 = 0.5; - let mut sk: GLWESecret> = GLWESecret::alloc(module, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); sk.fill_ternary_prob(var_xs, &mut source_xs); let sk_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk); @@ -535,7 +535,7 @@ pub(crate) fn test_automorphism_inplace( scratch.borrow(), ); - pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + pt_scalar.fill_ternary_hw(0, n, &mut source_xs); ct.encrypt_sk( module, @@ -547,10 +547,11 @@ pub(crate) fn test_automorphism_inplace( scratch.borrow(), ); - let mut auto_key_exec: AutomorphismKeyExec, B> = AutomorphismKeyExec::alloc(module, basek, k_ksk, rows, digits, rank); + let mut auto_key_exec: AutomorphismKeyExec, B> = + AutomorphismKeyExec::alloc(module, n, basek, k_ksk, rows, digits, rank); auto_key_exec.prepare(module, &auto_key, scratch.borrow()); - let mut tsk_exec: GLWETensorKeyExec, B> = GLWETensorKeyExec::alloc(module, basek, k_tsk, rows, digits, rank); + let mut tsk_exec: GLWETensorKeyExec, B> = GLWETensorKeyExec::alloc(module, n, basek, k_tsk, rows, digits, rank); tsk_exec.prepare(module, &tensor_key, scratch.borrow()); ct.automorphism_inplace(module, &auto_key_exec, &tsk_exec, scratch.borrow()); @@ -559,7 +560,7 @@ pub(crate) fn test_automorphism_inplace( let max_noise = |col_j: usize| -> f64 { noise_ggsw_keyswitch( - module.n() as f64, + n as f64, basek * digits, col_j, var_xs, @@ -595,15 +596,16 @@ pub(crate) fn test_external_product( + VecZnxRotateInplace, B: TestScratchFamily, { + let n: usize = module.n(); let rows: usize = k_in.div_ceil(basek * digits); let rows_in: usize = k_in.div_euclid(basek * digits); let digits_in: usize = 1; - let mut ct_ggsw_lhs_in: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k_in, rows_in, digits_in, rank); - let mut ct_ggsw_lhs_out: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k_out, rows_in, digits_in, rank); - let mut ct_ggsw_rhs: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k_ggsw, rows, digits, rank); - let mut pt_ggsw_lhs: ScalarZnx> = module.scalar_znx_alloc(1); - let mut pt_ggsw_rhs: ScalarZnx> = module.scalar_znx_alloc(1); + let mut ct_ggsw_lhs_in: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_in, rows_in, digits_in, rank); + let mut ct_ggsw_lhs_out: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_out, rows_in, digits_in, rank); + let mut ct_ggsw_rhs: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ggsw, rows, digits, rank); + let mut pt_ggsw_lhs: ScalarZnx> = ScalarZnx::alloc(n, 1); + let mut pt_ggsw_rhs: ScalarZnx> = ScalarZnx::alloc(n, 1); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -616,11 +618,11 @@ pub(crate) fn test_external_product( pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ggsw, rank) - | GGSWCiphertext::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank), + GGSWCiphertext::encrypt_sk_scratch_space(module, n, basek, k_ggsw, rank) + | GGSWCiphertext::external_product_scratch_space(module, n, basek, k_out, k_in, k_ggsw, digits, rank), ); - let mut sk: GLWESecret> = GLWESecret::alloc(module, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk); @@ -644,7 +646,7 @@ pub(crate) fn test_external_product( scratch.borrow(), ); - let mut ct_rhs_exec: GGSWCiphertextExec, B> = GGSWCiphertextExec::alloc(module, basek, k_ggsw, rows, digits, rank); + let mut ct_rhs_exec: GGSWCiphertextExec, B> = GGSWCiphertextExec::alloc(module, n, basek, k_ggsw, rows, digits, rank); ct_rhs_exec.prepare(module, &ct_ggsw_rhs, scratch.borrow()); ct_ggsw_lhs_out.external_product(module, &ct_ggsw_lhs_in, &ct_rhs_exec, scratch.borrow()); @@ -654,13 +656,13 @@ pub(crate) fn test_external_product( let var_gct_err_lhs: f64 = sigma * sigma; let var_gct_err_rhs: f64 = 0f64; - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_msg: f64 = 1f64 / n as f64; // X^{k} let var_a0_err: f64 = sigma * sigma; let var_a1_err: f64 = 1f64 / 12f64; let max_noise = |_col_j: usize| -> f64 { noise_ggsw_product( - module.n() as f64, + n as f64, basek * digits, 0.5, var_msg, @@ -695,15 +697,16 @@ pub(crate) fn test_external_product_inplace( + VecZnxRotateInplace, B: TestScratchFamily, { + let n: usize = module.n(); let rows: usize = k_ct.div_ceil(digits * basek); let rows_in: usize = k_ct.div_euclid(basek * digits); let digits_in: usize = 1; - let mut ct_ggsw_lhs: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k_ct, rows_in, digits_in, rank); - let mut ct_ggsw_rhs: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k_ggsw, rows, digits, rank); + let mut ct_ggsw_lhs: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ct, rows_in, digits_in, rank); + let mut ct_ggsw_rhs: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ggsw, rows, digits, rank); - let mut pt_ggsw_lhs: ScalarZnx> = module.scalar_znx_alloc(1); - let mut pt_ggsw_rhs: ScalarZnx> = module.scalar_znx_alloc(1); + let mut pt_ggsw_lhs: ScalarZnx> = ScalarZnx::alloc(n, 1); + let mut pt_ggsw_rhs: ScalarZnx> = ScalarZnx::alloc(n, 1); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -716,11 +719,11 @@ pub(crate) fn test_external_product_inplace( pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ggsw, rank) - | GGSWCiphertext::external_product_inplace_scratch_space(module, basek, k_ct, k_ggsw, digits, rank), + GGSWCiphertext::encrypt_sk_scratch_space(module, n, basek, k_ggsw, rank) + | GGSWCiphertext::external_product_inplace_scratch_space(module, n, basek, k_ct, k_ggsw, digits, rank), ); - let mut sk: GLWESecret> = GLWESecret::alloc(module, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk); @@ -744,7 +747,7 @@ pub(crate) fn test_external_product_inplace( scratch.borrow(), ); - let mut ct_rhs_exec: GGSWCiphertextExec, B> = GGSWCiphertextExec::alloc(module, basek, k_ggsw, rows, digits, rank); + let mut ct_rhs_exec: GGSWCiphertextExec, B> = GGSWCiphertextExec::alloc(module, n, basek, k_ggsw, rows, digits, rank); ct_rhs_exec.prepare(module, &ct_ggsw_rhs, scratch.borrow()); ct_ggsw_lhs.external_product_inplace(module, &ct_rhs_exec, scratch.borrow()); @@ -754,13 +757,13 @@ pub(crate) fn test_external_product_inplace( let var_gct_err_lhs: f64 = sigma * sigma; let var_gct_err_rhs: f64 = 0f64; - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_msg: f64 = 1f64 / n as f64; // X^{k} let var_a0_err: f64 = sigma * sigma; let var_a1_err: f64 = 1f64 / 12f64; let max_noise = |_col_j: usize| -> f64 { noise_ggsw_product( - module.n() as f64, + n as f64, basek * digits, 0.5, var_msg, diff --git a/core/src/ggsw/test/mod.rs b/core/src/ggsw/test/mod.rs index ac22b00..c3e241f 100644 --- a/core/src/ggsw/test/mod.rs +++ b/core/src/ggsw/test/mod.rs @@ -1,2 +1,3 @@ mod cpu_spqlios; +mod generic_serialization; mod generic_tests; diff --git a/core/src/glwe/automorphism.rs b/core/src/glwe/automorphism.rs index ac731e2..ec4ae93 100644 --- a/core/src/glwe/automorphism.rs +++ b/core/src/glwe/automorphism.rs @@ -11,6 +11,7 @@ use crate::{AutomorphismKeyExec, GLWECiphertext, GLWEKeyswitchFamily, Infos, glw impl GLWECiphertext> { pub fn automorphism_scratch_space( module: &Module, + n: usize, basek: usize, k_out: usize, k_in: usize, @@ -21,11 +22,12 @@ impl GLWECiphertext> { where Module: GLWEKeyswitchFamily, { - Self::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank) + Self::keyswitch_scratch_space(module, n, basek, k_out, k_in, k_ksk, digits, rank, rank) } pub fn automorphism_inplace_scratch_space( module: &Module, + n: usize, basek: usize, k_out: usize, k_ksk: usize, @@ -35,7 +37,7 @@ impl GLWECiphertext> { where Module: GLWEKeyswitchFamily, { - Self::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank) + Self::keyswitch_inplace_scratch_space(module, n, basek, k_out, k_ksk, digits, rank) } } @@ -85,7 +87,7 @@ impl GLWECiphertext { { self.assert_keyswitch(module, lhs, &rhs.key, scratch); } - let (res_dft, scratch1) = scratch.take_vec_znx_dft(module, self.cols(), rhs.size()); // TODO: optimise size + let (res_dft, scratch1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size let mut res_big: VecZnxBig<_, B> = keyswitch(module, res_dft, lhs, &rhs.key, scratch1); (0..self.cols()).for_each(|i| { module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i); @@ -123,7 +125,7 @@ impl GLWECiphertext { { self.assert_keyswitch(module, lhs, &rhs.key, scratch); } - let (res_dft, scratch1) = scratch.take_vec_znx_dft(module, self.cols(), rhs.size()); // TODO: optimise size + let (res_dft, scratch1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size let mut res_big: VecZnxBig<_, B> = keyswitch(module, res_dft, lhs, &rhs.key, scratch1); (0..self.cols()).for_each(|i| { module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i); @@ -161,7 +163,7 @@ impl GLWECiphertext { { self.assert_keyswitch(module, lhs, &rhs.key, scratch); } - let (res_dft, scratch1) = scratch.take_vec_znx_dft(module, self.cols(), rhs.size()); // TODO: optimise size + let (res_dft, scratch1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size let mut res_big: VecZnxBig<_, B> = keyswitch(module, res_dft, lhs, &rhs.key, scratch1); (0..self.cols()).for_each(|i| { module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i); diff --git a/core/src/glwe/decryption.rs b/core/src/glwe/decryption.rs index fcd2e59..570a486 100644 --- a/core/src/glwe/decryption.rs +++ b/core/src/glwe/decryption.rs @@ -20,13 +20,13 @@ pub trait GLWEDecryptFamily = VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes; impl GLWECiphertext> { - pub fn decrypt_scratch_space(module: &Module, basek: usize, k: usize) -> usize + pub fn decrypt_scratch_space(module: &Module, n: usize, basek: usize, k: usize) -> usize where Module: GLWEDecryptFamily, { let size: usize = k.div_ceil(basek); - (module.vec_znx_normalize_tmp_bytes(module.n()) | module.vec_znx_dft_alloc_bytes(1, size)) - + module.vec_znx_dft_alloc_bytes(1, size) + (module.vec_znx_normalize_tmp_bytes(n) | module.vec_znx_dft_alloc_bytes(n, 1, size)) + + module.vec_znx_dft_alloc_bytes(n, 1, size) } } @@ -44,20 +44,19 @@ impl GLWECiphertext { #[cfg(debug_assertions)] { assert_eq!(self.rank(), sk.rank()); - assert_eq!(self.n(), module.n()); - assert_eq!(pt.n(), module.n()); - assert_eq!(sk.n(), module.n()); + assert_eq!(self.n(), sk.n()); + assert_eq!(pt.n(), sk.n()); } let cols: usize = self.rank() + 1; - let (mut c0_big, scratch_1) = scratch.take_vec_znx_big(module, 1, self.size()); // TODO optimize size when pt << ct + let (mut c0_big, scratch_1) = scratch.take_vec_znx_big(self.n(), 1, self.size()); // TODO optimize size when pt << ct c0_big.data_mut().fill(0); { (1..cols).for_each(|i| { // ci_dft = DFT(a[i]) * DFT(s[i]) - let (mut ci_dft, _) = scratch_1.take_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct + let (mut ci_dft, _) = scratch_1.take_vec_znx_dft(self.n(), 1, self.size()); // TODO optimize size when pt << ct module.vec_znx_dft_from_vec_znx(1, 0, &mut ci_dft, 0, &self.data, i); module.svp_apply_inplace(&mut ci_dft, 0, &sk.data, i - 1); let ci_big = module.vec_znx_dft_to_vec_znx_big_consume(ci_dft); diff --git a/core/src/glwe/encryption.rs b/core/src/glwe/encryption.rs index bc9a95e..0b5edbc 100644 --- a/core/src/glwe/encryption.rs +++ b/core/src/glwe/encryption.rs @@ -1,12 +1,11 @@ use backend::hal::{ api::{ - ScalarZnxAllocBytes, ScratchAvailable, SvpApply, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, - TakeSvpPPol, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAllocBytes, VecZnxBigAddNormal, - VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, - VecZnxDftToVecZnxBigConsume, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, - VecZnxSub, VecZnxSubABInplace, ZnxInfos, ZnxZero, + ScratchAvailable, SvpApply, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeSvpPPol, TakeVecZnx, + TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxBigAddNormal, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, + VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VecZnxFillUniform, + VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, ZnxInfos, ZnxZero, }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx, VecZnxBig}, + layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, VecZnx, VecZnxBig}, }; use sampling::source::Source; @@ -27,8 +26,7 @@ pub trait GLWEEncryptSkFamily = VecZnxDftAllocBytes + VecZnxNormalizeInplace + VecZnxAddNormal + VecZnxNormalize - + VecZnxSub - + VecZnxAllocBytes; + + VecZnxSub; pub trait GLWEEncryptPkFamily = VecZnxDftAllocBytes + VecZnxBigAllocBytes @@ -39,27 +37,24 @@ pub trait GLWEEncryptPkFamily = VecZnxDftAllocBytes + VecZnxBigAddNormal + VecZnxBigAddSmallInplace + VecZnxBigNormalize - + ScalarZnxAllocBytes + VecZnxNormalizeTmpBytes; impl GLWECiphertext> { - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize) -> usize where Module: GLWEEncryptSkFamily, { let size: usize = k.div_ceil(basek); - module.vec_znx_normalize_tmp_bytes(module.n()) - + 2 * module.vec_znx_alloc_bytes(1, size) - + module.vec_znx_dft_alloc_bytes(1, size) + module.vec_znx_normalize_tmp_bytes(n) + 2 * VecZnx::alloc_bytes(n, 1, size) + module.vec_znx_dft_alloc_bytes(n, 1, size) } - pub fn encrypt_pk_scratch_space(module: &Module, basek: usize, k: usize) -> usize + pub fn encrypt_pk_scratch_space(module: &Module, n: usize, basek: usize, k: usize) -> usize where Module: GLWEEncryptPkFamily, { let size: usize = k.div_ceil(basek); - ((module.vec_znx_dft_alloc_bytes(1, size) + module.vec_znx_big_alloc_bytes(1, size)) | module.scalar_znx_alloc_bytes(1)) - + module.svp_ppol_alloc_bytes(1) - + module.vec_znx_normalize_tmp_bytes(module.n()) + ((module.vec_znx_dft_alloc_bytes(n, 1, size) + module.vec_znx_big_alloc_bytes(n, 1, size)) | ScalarZnx::alloc_bytes(n, 1)) + + module.svp_ppol_alloc_bytes(n, 1) + + module.vec_znx_normalize_tmp_bytes(n) } } @@ -75,7 +70,7 @@ impl GLWECiphertext { scratch: &mut Scratch, ) where Module: GLWEEncryptSkFamily, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { #[cfg(debug_assertions)] { @@ -83,10 +78,10 @@ impl GLWECiphertext { assert_eq!(sk.n(), self.n()); assert_eq!(pt.n(), self.n()); assert!( - scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()), + scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self.n(), self.basek(), self.k()), "scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}", scratch.available(), - GLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()) + GLWECiphertext::encrypt_sk_scratch_space(module, self.n(), self.basek(), self.k()) ) } @@ -111,17 +106,17 @@ impl GLWECiphertext { scratch: &mut Scratch, ) where Module: GLWEEncryptSkFamily, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { #[cfg(debug_assertions)] { assert_eq!(self.rank(), sk.rank()); assert_eq!(sk.n(), self.n()); assert!( - scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()), + scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self.n(), self.basek(), self.k()), "scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}", scratch.available(), - GLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()) + GLWECiphertext::encrypt_sk_scratch_space(module, self.n(), self.basek(), self.k()) ) } self.encrypt_sk_internal( @@ -146,7 +141,7 @@ impl GLWECiphertext { scratch: &mut Scratch, ) where Module: GLWEEncryptSkFamily, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { let cols: usize = self.rank() + 1; encrypt_sk_internal( @@ -176,7 +171,7 @@ impl GLWECiphertext { scratch: &mut Scratch, ) where Module: GLWEEncryptPkFamily, - Scratch: TakeVecZnxDft + TakeSvpPPol + TakeScalarZnx, + Scratch: TakeVecZnxDft + TakeSvpPPol + TakeScalarZnx, { self.encrypt_pk_internal::( module, @@ -199,7 +194,7 @@ impl GLWECiphertext { scratch: &mut Scratch, ) where Module: GLWEEncryptPkFamily, - Scratch: TakeVecZnxDft + TakeSvpPPol + TakeScalarZnx, + Scratch: TakeVecZnxDft + TakeSvpPPol + TakeScalarZnx, { self.encrypt_pk_internal::, DataPk, B>( module, @@ -230,17 +225,16 @@ impl GLWECiphertext { + VecZnxBigAddNormal + VecZnxBigAddSmallInplace + VecZnxBigNormalize, - Scratch: TakeVecZnxDft + TakeSvpPPol + TakeScalarZnx, + Scratch: TakeVecZnxDft + TakeSvpPPol + TakeScalarZnx, { #[cfg(debug_assertions)] { assert_eq!(self.basek(), pk.basek()); - assert_eq!(self.n(), module.n()); - assert_eq!(pk.n(), module.n()); + assert_eq!(self.n(), pk.n()); assert_eq!(self.rank(), pk.rank()); if let Some((pt, _)) = pt { assert_eq!(pt.basek(), pk.basek()); - assert_eq!(pt.n(), module.n()); + assert_eq!(pt.n(), pk.n()); } } @@ -249,10 +243,10 @@ impl GLWECiphertext { let cols: usize = self.rank() + 1; // Generates u according to the underlying secret distribution. - let (mut u_dft, scratch_1) = scratch.take_svp_ppol(module, 1); + let (mut u_dft, scratch_1) = scratch.take_svp_ppol(self.n(), 1); { - let (mut u, _) = scratch_1.take_scalar_znx(module, 1); + let (mut u, _) = scratch_1.take_scalar_znx(self.n(), 1); match pk.dist { Distribution::NONE => panic!( "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through \ @@ -271,7 +265,7 @@ impl GLWECiphertext { // ct[i] = pk[i] * u + ei (+ m if col = i) (0..cols).for_each(|i| { - let (mut ci_dft, scratch_2) = scratch_1.take_vec_znx_dft(module, 1, size_pk); + let (mut ci_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n(), 1, size_pk); // ci_dft = DFT(u) * DFT(pk[i]) module.svp_apply(&mut ci_dft, 0, &u_dft, 0, &pk.data, i); @@ -303,11 +297,11 @@ impl GLWECiphertext { } impl GLWECiphertextCompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize) -> usize where Module: GLWEEncryptSkFamily, { - GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) + GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, k) } } impl GLWECiphertextCompressed { @@ -322,7 +316,7 @@ impl GLWECiphertextCompressed { scratch: &mut Scratch, ) where Module: GLWEEncryptSkFamily, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { self.encrypt_sk_internal( module, @@ -346,7 +340,7 @@ impl GLWECiphertextCompressed { scratch: &mut Scratch, ) where Module: GLWEEncryptSkFamily, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { let mut source_xa = Source::new(seed_xa); let cols: usize = self.rank() + 1; @@ -383,7 +377,7 @@ pub(crate) fn encrypt_sk_internal, ) where Module: GLWEEncryptSkFamily, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { #[cfg(debug_assertions)] { @@ -399,11 +393,11 @@ pub(crate) fn encrypt_sk_internal = VecZnxDftAllocBytes impl GLWECiphertext> { pub fn external_product_scratch_space( module: &Module, + n: usize, basek: usize, k_out: usize, k_in: usize, @@ -33,9 +34,10 @@ impl GLWECiphertext> { let in_size: usize = k_in.div_ceil(basek).div_ceil(digits); let out_size: usize = k_out.div_ceil(basek); let ggsw_size: usize = k_ggsw.div_ceil(basek); - let res_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, ggsw_size); - let a_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, in_size); + let res_dft: usize = module.vec_znx_dft_alloc_bytes(n, rank + 1, ggsw_size); + let a_dft: usize = module.vec_znx_dft_alloc_bytes(n, rank + 1, in_size); let vmp: usize = module.vmp_apply_tmp_bytes( + n, out_size, in_size, in_size, // rows @@ -49,6 +51,7 @@ impl GLWECiphertext> { pub fn external_product_inplace_scratch_space( module: &Module, + n: usize, basek: usize, k_out: usize, k_ggsw: usize, @@ -58,7 +61,7 @@ impl GLWECiphertext> { where Module: GLWEExternalProductFamily, { - Self::external_product_scratch_space(module, basek, k_out, k_out, k_ggsw, digits, rank) + Self::external_product_scratch_space(module, n, basek, k_out, k_out, k_ggsw, digits, rank) } } @@ -83,13 +86,13 @@ impl GLWECiphertext { assert_eq!(rhs.rank(), self.rank()); assert_eq!(self.basek(), basek); assert_eq!(lhs.basek(), basek); - assert_eq!(rhs.n(), module.n()); - assert_eq!(self.n(), module.n()); - assert_eq!(lhs.n(), module.n()); + assert_eq!(rhs.n(), self.n()); + assert_eq!(lhs.n(), self.n()); assert!( scratch.available() >= GLWECiphertext::external_product_scratch_space( module, + self.n(), self.basek(), self.k(), lhs.k(), @@ -103,8 +106,8 @@ impl GLWECiphertext { let cols: usize = rhs.rank() + 1; let digits: usize = rhs.digits(); - let (mut res_dft, scratch1) = scratch.take_vec_znx_dft(module, cols, rhs.size()); // Todo optimise - let (mut a_dft, scratch2) = scratch1.take_vec_znx_dft(module, cols, lhs.size().div_ceil(digits)); + let (mut res_dft, scratch1) = scratch.take_vec_znx_dft(self.n(), cols, rhs.size()); // Todo optimise + let (mut a_dft, scratch2) = scratch1.take_vec_znx_dft(self.n(), cols, lhs.size().div_ceil(digits)); a_dft.data_mut().fill(0); diff --git a/core/src/glwe/keyswitch.rs b/core/src/glwe/keyswitch.rs index a3b40aa..909910e 100644 --- a/core/src/glwe/keyswitch.rs +++ b/core/src/glwe/keyswitch.rs @@ -22,6 +22,7 @@ pub trait GLWEKeyswitchFamily = VecZnxDftAllocBytes impl GLWECiphertext> { pub fn keyswitch_scratch_space( module: &Module, + n: usize, basek: usize, k_out: usize, k_in: usize, @@ -36,16 +37,24 @@ impl GLWECiphertext> { let in_size: usize = k_in.div_ceil(basek).div_ceil(digits); let out_size: usize = k_out.div_ceil(basek); let ksk_size: usize = k_ksk.div_ceil(basek); - let res_dft: usize = module.vec_znx_dft_alloc_bytes(rank_out + 1, ksk_size); // TODO OPTIMIZE - let ai_dft: usize = module.vec_znx_dft_alloc_bytes(rank_in, in_size); - let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, rank_in, rank_out + 1, ksk_size) - + module.vec_znx_dft_alloc_bytes(rank_in, in_size); - let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(module.n()); + let res_dft: usize = module.vec_znx_dft_alloc_bytes(n, rank_out + 1, ksk_size); // TODO OPTIMIZE + let ai_dft: usize = module.vec_znx_dft_alloc_bytes(n, rank_in, in_size); + let vmp: usize = module.vmp_apply_tmp_bytes( + n, + out_size, + in_size, + in_size, + rank_in, + rank_out + 1, + ksk_size, + ) + module.vec_znx_dft_alloc_bytes(n, rank_in, in_size); + let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(n); return res_dft + ((ai_dft + vmp) | normalize); } pub fn keyswitch_from_fourier_scratch_space( module: &Module, + n: usize, basek: usize, k_out: usize, k_in: usize, @@ -57,11 +66,14 @@ impl GLWECiphertext> { where Module: GLWEKeyswitchFamily, { - Self::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out) + Self::keyswitch_scratch_space( + module, n, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out, + ) } pub fn keyswitch_inplace_scratch_space( module: &Module, + n: usize, basek: usize, k_out: usize, k_ksk: usize, @@ -71,7 +83,7 @@ impl GLWECiphertext> { where Module: GLWEKeyswitchFamily, { - Self::keyswitch_scratch_space(module, basek, k_out, k_out, k_ksk, digits, rank, rank) + Self::keyswitch_scratch_space(module, n, basek, k_out, k_out, k_ksk, digits, rank, rank) } } @@ -105,13 +117,13 @@ impl GLWECiphertext { ); assert_eq!(self.basek(), basek); assert_eq!(lhs.basek(), basek); - assert_eq!(rhs.n(), module.n()); - assert_eq!(self.n(), module.n()); - assert_eq!(lhs.n(), module.n()); + assert_eq!(rhs.n(), self.n()); + assert_eq!(lhs.n(), self.n()); assert!( scratch.available() >= GLWECiphertext::keyswitch_scratch_space( module, + self.n(), self.basek(), self.k(), lhs.k(), @@ -133,6 +145,7 @@ impl GLWECiphertext { scratch.available(), GLWECiphertext::keyswitch_scratch_space( module, + self.n(), self.basek(), self.k(), lhs.k(), @@ -160,7 +173,7 @@ impl GLWECiphertext { { self.assert_keyswitch(module, lhs, rhs, scratch); } - let (res_dft, scratch1) = scratch.take_vec_znx_dft(module, self.cols(), rhs.size()); // Todo optimise + let (res_dft, scratch1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // Todo optimise let res_big: VecZnxBig<_, B> = keyswitch(module, res_dft, lhs, rhs, scratch1); (0..self.cols()).for_each(|i| { module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch1); @@ -227,7 +240,7 @@ where Scratch: TakeVecZnxDft, { let cols: usize = a.cols(); - let (mut ai_dft, scratch1) = scratch.take_vec_znx_dft(module, cols - 1, a.size()); + let (mut ai_dft, scratch1) = scratch.take_vec_znx_dft(a.n(), cols - 1, a.size()); (0..cols - 1).for_each(|col_i| { module.vec_znx_dft_from_vec_znx(1, 0, &mut ai_dft, col_i, a, col_i + 1); }); @@ -259,7 +272,7 @@ where { let cols: usize = a.cols(); let size: usize = a.size(); - let (mut ai_dft, scratch1) = scratch.take_vec_znx_dft(module, cols - 1, size.div_ceil(digits)); + let (mut ai_dft, scratch1) = scratch.take_vec_znx_dft(a.n(), cols - 1, size.div_ceil(digits)); ai_dft.data_mut().fill(0); diff --git a/core/src/glwe/layout.rs b/core/src/glwe/layout.rs index ba50a95..ec8fbb3 100644 --- a/core/src/glwe/layout.rs +++ b/core/src/glwe/layout.rs @@ -1,12 +1,11 @@ -use std::fmt::Debug; - use backend::hal::{ - api::{FillUniform, VecZnxAlloc, VecZnxAllocBytes, VecZnxCopy, VecZnxFillUniform, ZnxInfos, ZnxZero}, + api::{FillUniform, Reset, VecZnxCopy, VecZnxFillUniform, ZnxInfos}, layouts::{Backend, Data, DataMut, DataRef, Module, ReaderFrom, VecZnx, VecZnxToMut, VecZnxToRef, WriterTo}, }; use sampling::source::Source; use crate::{Decompress, GLWEOps, Infos, SetMetaData}; +use std::fmt; #[derive(PartialEq, Eq, Clone)] pub struct GLWECiphertext { @@ -15,8 +14,14 @@ pub struct GLWECiphertext { pub k: usize, } -impl Debug for GLWECiphertext { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl fmt::Debug for GLWECiphertext { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self) + } +} + +impl fmt::Display for GLWECiphertext { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, "GLWECiphertext: basek={} k={}: {}", @@ -27,16 +32,14 @@ impl Debug for GLWECiphertext { } } -impl ZnxZero for GLWECiphertext +impl Reset for GLWECiphertext where - VecZnx: ZnxZero, + VecZnx: Reset, { - fn zero(&mut self) { - self.data.zero() - } - - fn zero_at(&mut self, i: usize, j: usize) { - self.data.zero_at(i, j); + fn reset(&mut self) { + self.data.reset(); + self.basek = 0; + self.k = 0; } } @@ -50,22 +53,16 @@ where } impl GLWECiphertext> { - pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self - where - Module: VecZnxAlloc, - { + pub fn alloc(n: usize, basek: usize, k: usize, rank: usize) -> Self { Self { - data: module.vec_znx_alloc(rank + 1, k.div_ceil(basek)), + data: VecZnx::alloc(n, rank + 1, k.div_ceil(basek)), basek, k, } } - pub fn bytes_of(module: &Module, basek: usize, k: usize, rank: usize) -> usize - where - Module: VecZnxAllocBytes, - { - module.vec_znx_alloc_bytes(rank + 1, k.div_ceil(basek)) + pub fn bytes_of(n: usize, basek: usize, k: usize, rank: usize) -> usize { + VecZnx::alloc_bytes(n, rank + 1, k.div_ceil(basek)) } } @@ -168,28 +165,36 @@ pub struct GLWECiphertextCompressed { pub(crate) seed: [u8; 32], } -impl Debug for GLWECiphertextCompressed { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl fmt::Debug for GLWECiphertextCompressed { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self) + } +} + +impl fmt::Display for GLWECiphertextCompressed { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "GLWECiphertext: basek={} k={}: {}", + "GLWECiphertextCompressed: basek={} k={} rank={} seed={:?}: {}", self.basek(), self.k(), + self.rank, + self.seed, self.data ) } } -impl ZnxZero for GLWECiphertextCompressed +impl Reset for GLWECiphertextCompressed where - VecZnx: ZnxZero, + VecZnx: Reset, { - fn zero(&mut self) { - self.data.zero() - } - - fn zero_at(&mut self, i: usize, j: usize) { - self.data.zero_at(i, j); + fn reset(&mut self) { + self.data.reset(); + self.basek = 0; + self.k = 0; + self.rank = 0; + self.seed = [0u8; 32]; } } @@ -225,12 +230,9 @@ impl GLWECiphertextCompressed { } impl GLWECiphertextCompressed> { - pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self - where - Module: VecZnxAlloc, - { + pub fn alloc(n: usize, basek: usize, k: usize, rank: usize) -> Self { Self { - data: module.vec_znx_alloc(1, k.div_ceil(basek)), + data: VecZnx::alloc(n, 1, k.div_ceil(basek)), basek, k, rank, @@ -238,11 +240,8 @@ impl GLWECiphertextCompressed> { } } - pub fn bytes_of(module: &Module, basek: usize, k: usize) -> usize - where - Module: VecZnxAllocBytes, - { - GLWECiphertext::bytes_of(module, basek, k, 1) + pub fn bytes_of(n: usize, basek: usize, k: usize) -> usize { + GLWECiphertext::bytes_of(n, basek, k, 1) } } diff --git a/core/src/glwe/noise.rs b/core/src/glwe/noise.rs index a8bb86f..d02c63a 100644 --- a/core/src/glwe/noise.rs +++ b/core/src/glwe/noise.rs @@ -1,5 +1,5 @@ use backend::hal::{ - api::{ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAlloc, VecZnxNormalizeInplace, VecZnxStd, VecZnxSubABInplace}, + api::{ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxNormalizeInplace, VecZnxStd, VecZnxSubABInplace}, layouts::{Backend, DataRef, Module, ScratchOwned}, oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl}, }; @@ -16,13 +16,14 @@ impl GLWECiphertext { ) where DataSk: DataRef, DataPt: DataRef, - Module: GLWEDecryptFamily + VecZnxSubABInplace + VecZnxNormalizeInplace + VecZnxStd + VecZnxAlloc, + Module: GLWEDecryptFamily + VecZnxSubABInplace + VecZnxNormalizeInplace + VecZnxStd, B: TakeVecZnxDftImpl + TakeVecZnxBigImpl + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, { - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(module, self.basek(), self.k()); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(self.n(), self.basek(), self.k()); let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space( module, + self.n(), self.basek(), self.k(), )); diff --git a/core/src/glwe/ops.rs b/core/src/glwe/ops.rs index d6ede6d..89c833c 100644 --- a/core/src/glwe/ops.rs +++ b/core/src/glwe/ops.rs @@ -18,9 +18,8 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData + Sized { { #[cfg(debug_assertions)] { - assert_eq!(a.n(), module.n()); - assert_eq!(b.n(), module.n()); - assert_eq!(self.n(), module.n()); + assert_eq!(a.n(), self.n()); + assert_eq!(b.n(), self.n()); assert_eq!(a.basek(), b.basek()); assert!(self.rank() >= a.rank().max(b.rank())); } @@ -65,8 +64,7 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData + Sized { { #[cfg(debug_assertions)] { - assert_eq!(a.n(), module.n()); - assert_eq!(self.n(), module.n()); + assert_eq!(a.n(), self.n()); assert_eq!(self.basek(), a.basek()); assert!(self.rank() >= a.rank()) } @@ -89,9 +87,8 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData + Sized { { #[cfg(debug_assertions)] { - assert_eq!(a.n(), module.n()); - assert_eq!(b.n(), module.n()); - assert_eq!(self.n(), module.n()); + assert_eq!(a.n(), self.n()); + assert_eq!(b.n(), self.n()); assert_eq!(a.basek(), b.basek()); assert!(self.rank() >= a.rank().max(b.rank())); } @@ -137,8 +134,7 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData + Sized { { #[cfg(debug_assertions)] { - assert_eq!(a.n(), module.n()); - assert_eq!(self.n(), module.n()); + assert_eq!(a.n(), self.n()); assert_eq!(self.basek(), a.basek()); assert!(self.rank() >= a.rank()) } @@ -160,8 +156,7 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData + Sized { { #[cfg(debug_assertions)] { - assert_eq!(a.n(), module.n()); - assert_eq!(self.n(), module.n()); + assert_eq!(a.n(), self.n()); assert_eq!(self.basek(), a.basek()); assert!(self.rank() >= a.rank()) } @@ -183,8 +178,7 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData + Sized { { #[cfg(debug_assertions)] { - assert_eq!(a.n(), module.n()); - assert_eq!(self.n(), module.n()); + assert_eq!(a.n(), self.n()); assert_eq!(self.rank(), a.rank()) } @@ -203,11 +197,6 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData + Sized { where Module: VecZnxRotateInplace, { - #[cfg(debug_assertions)] - { - assert_eq!(self.n(), module.n()); - } - let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); (0..self_mut.rank() + 1).for_each(|i| { @@ -222,8 +211,7 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData + Sized { { #[cfg(debug_assertions)] { - assert_eq!(a.n(), module.n()); - assert_eq!(self.n(), module.n()); + assert_eq!(a.n(), self.n()); assert_eq!(self.rank(), a.rank()) } @@ -242,11 +230,6 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData + Sized { where Module: VecZnxMulXpMinusOneInplace, { - #[cfg(debug_assertions)] - { - assert_eq!(self.n(), module.n()); - } - let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); (0..self_mut.rank() + 1).for_each(|i| { @@ -261,8 +244,7 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData + Sized { { #[cfg(debug_assertions)] { - assert_eq!(self.n(), module.n()); - assert_eq!(a.n(), module.n()); + assert_eq!(self.n(), a.n()); assert_eq!(self.rank(), a.rank()); } @@ -292,8 +274,7 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData + Sized { { #[cfg(debug_assertions)] { - assert_eq!(self.n(), module.n()); - assert_eq!(a.n(), module.n()); + assert_eq!(self.n(), a.n()); assert_eq!(self.rank(), a.rank()); } @@ -311,10 +292,6 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData + Sized { where Module: VecZnxNormalizeInplace, { - #[cfg(debug_assertions)] - { - assert_eq!(self.n(), module.n()); - } let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); (0..self_mut.rank() + 1).for_each(|i| { module.vec_znx_normalize_inplace(self_mut.basek(), &mut self_mut.data, i, scratch); @@ -323,8 +300,8 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData + Sized { } impl GLWECiphertext> { - pub fn rsh_scratch_space(module: &Module) -> usize { - VecZnx::rsh_scratch_space(module.n()) + pub fn rsh_scratch_space(n: usize) -> usize { + VecZnx::rsh_scratch_space(n) } } diff --git a/core/src/glwe/packing.rs b/core/src/glwe/packing.rs index e4ba401..602a6c4 100644 --- a/core/src/glwe/packing.rs +++ b/core/src/glwe/packing.rs @@ -2,9 +2,9 @@ use std::collections::HashMap; use backend::hal::{ api::{ - ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAlloc, VecZnxAllocBytes, VecZnxAutomorphismInplace, - VecZnxBigAutomorphismInplace, VecZnxBigSubSmallBInplace, VecZnxCopy, VecZnxNegateInplace, VecZnxNormalizeInplace, - VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubABInplace, + ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAutomorphismInplace, VecZnxBigAutomorphismInplace, + VecZnxBigSubSmallBInplace, VecZnxCopy, VecZnxNegateInplace, VecZnxNormalizeInplace, VecZnxRotate, VecZnxRotateInplace, + VecZnxRshInplace, VecZnxSub, VecZnxSubABInplace, }, layouts::{Backend, DataMut, DataRef, Module, Scratch}, }; @@ -52,12 +52,9 @@ impl Accumulator { /// * `basek`: base 2 logarithm of the GLWE ciphertext in memory digit representation. /// * `k`: base 2 precision of the GLWE ciphertext precision over the Torus. /// * `rank`: rank of the GLWE ciphertext. - pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self - where - Module: VecZnxAlloc, - { + pub fn alloc(n: usize, basek: usize, k: usize, rank: usize) -> Self { Self { - data: GLWECiphertext::alloc(module, basek, k, rank), + data: GLWECiphertext::alloc(n, basek, k, rank), value: false, control: false, } @@ -78,13 +75,10 @@ impl GLWEPacker { /// * `basek`: base 2 logarithm of the GLWE ciphertext in memory digit representation. /// * `k`: base 2 precision of the GLWE ciphertext precision over the Torus. /// * `rank`: rank of the GLWE ciphertext. - pub fn new(module: &Module, log_batch: usize, basek: usize, k: usize, rank: usize) -> Self - where - Module: VecZnxAlloc, - { + pub fn new(n: usize, log_batch: usize, basek: usize, k: usize, rank: usize) -> Self { let mut accumulators: Vec = Vec::::new(); - let log_n: usize = module.log_n(); - (0..log_n - log_batch).for_each(|_| accumulators.push(Accumulator::alloc(module, basek, k, rank))); + let log_n: usize = (usize::BITS - (n - 1).leading_zeros()) as _; + (0..log_n - log_batch).for_each(|_| accumulators.push(Accumulator::alloc(n, basek, k, rank))); Self { accumulators: accumulators, log_batch, @@ -104,6 +98,7 @@ impl GLWEPacker { /// Number of scratch space bytes required to call [Self::add]. pub fn scratch_space( module: &Module, + n: usize, basek: usize, ct_k: usize, k_ksk: usize, @@ -111,9 +106,9 @@ impl GLWEPacker { rank: usize, ) -> usize where - Module: GLWEKeyswitchFamily + VecZnxAllocBytes, + Module: GLWEKeyswitchFamily, { - pack_core_scratch_space(module, basek, ct_k, k_ksk, digits, rank) + pack_core_scratch_space(module, n, basek, ct_k, k_ksk, digits, rank) } pub fn galois_elements(module: &Module) -> Vec { @@ -137,12 +132,12 @@ impl GLWEPacker { scratch: &mut Scratch, ) where Module: GLWEPackingFamily, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { assert!( - self.counter < module.n(), + self.counter < self.accumulators[0].data.n(), "Packing limit of {} reached", - module.n() >> self.log_batch + self.accumulators[0].data.n() >> self.log_batch ); pack_core( @@ -161,7 +156,7 @@ impl GLWEPacker { where Module: VecZnxCopy, { - assert!(self.counter == module.n()); + assert!(self.counter == self.accumulators[0].data.n()); // Copy result GLWE into res GLWE res.copy( module, @@ -174,6 +169,7 @@ impl GLWEPacker { fn pack_core_scratch_space( module: &Module, + n: usize, basek: usize, ct_k: usize, k_ksk: usize, @@ -181,9 +177,9 @@ fn pack_core_scratch_space( rank: usize, ) -> usize where - Module: GLWEKeyswitchFamily + VecZnxAllocBytes, + Module: GLWEKeyswitchFamily, { - combine_scratch_space(module, basek, ct_k, k_ksk, digits, rank) + combine_scratch_space(module, n, basek, ct_k, k_ksk, digits, rank) } fn pack_core( @@ -195,7 +191,7 @@ fn pack_core( scratch: &mut Scratch, ) where Module: GLWEPackingFamily, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { let log_n: usize = module.log_n(); @@ -248,6 +244,7 @@ fn pack_core( fn combine_scratch_space( module: &Module, + n: usize, basek: usize, ct_k: usize, k_ksk: usize, @@ -255,11 +252,11 @@ fn combine_scratch_space( rank: usize, ) -> usize where - Module: GLWEKeyswitchFamily + VecZnxAllocBytes, + Module: GLWEKeyswitchFamily, { - GLWECiphertext::bytes_of(module, basek, ct_k, rank) - + (GLWECiphertext::rsh_scratch_space(module) - | GLWECiphertext::automorphism_scratch_space(module, basek, ct_k, ct_k, k_ksk, digits, rank)) + GLWECiphertext::bytes_of(n, basek, ct_k, rank) + + (GLWECiphertext::rsh_scratch_space(n) + | GLWECiphertext::automorphism_scratch_space(module, n, basek, ct_k, ct_k, k_ksk, digits, rank)) } /// [combine] merges two ciphertexts together. @@ -272,9 +269,10 @@ fn combine( scratch: &mut Scratch, ) where Module: GLWEPackingFamily, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { - let log_n: usize = module.log_n(); + let n: usize = acc.data.n(); + let log_n: usize = (u64::BITS - (n - 1).leading_zeros()) as _; let a: &mut GLWECiphertext> = &mut acc.data; let basek: usize = a.basek(); let k: usize = a.k(); @@ -302,7 +300,7 @@ fn combine( // since 2*(I(X) * Q/2) = I(X) * Q = 0 mod Q. if acc.value { if let Some(b) = b { - let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(module, basek, k, rank); + let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(n, basek, k, rank); // a = a * X^-t a.rotate_inplace(module, -t); @@ -343,7 +341,7 @@ fn combine( } } else { if let Some(b) = b { - let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(module, basek, k, rank); + let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(n, basek, k, rank); tmp_b.rotate(module, 1 << (log_n - i - 1), b); tmp_b.rsh(module, 1); diff --git a/core/src/glwe/plaintext.rs b/core/src/glwe/plaintext.rs index 114c488..60754eb 100644 --- a/core/src/glwe/plaintext.rs +++ b/core/src/glwe/plaintext.rs @@ -1,7 +1,4 @@ -use backend::hal::{ - api::{VecZnxAlloc, VecZnxAllocBytes}, - layouts::{Backend, Data, DataMut, DataRef, Module, VecZnx, VecZnxToMut, VecZnxToRef}, -}; +use backend::hal::layouts::{Data, DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef}; use crate::{GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEOps, Infos, SetMetaData}; @@ -38,22 +35,16 @@ impl SetMetaData for GLWEPlaintext { } impl GLWEPlaintext> { - pub fn alloc(module: &Module, basek: usize, k: usize) -> Self - where - Module: VecZnxAlloc, - { + pub fn alloc(n: usize, basek: usize, k: usize) -> Self { Self { - data: module.vec_znx_alloc(1, k.div_ceil(basek)), + data: VecZnx::alloc(n, 1, k.div_ceil(basek)), basek: basek, k, } } - pub fn byte_of(module: &Module, basek: usize, k: usize) -> usize - where - Module: VecZnxAllocBytes, - { - module.vec_znx_alloc_bytes(1, k.div_ceil(basek)) + pub fn byte_of(n: usize, basek: usize, k: usize) -> usize { + VecZnx::alloc_bytes(n, 1, k.div_ceil(basek)) } } diff --git a/core/src/glwe/public_key.rs b/core/src/glwe/public_key.rs index cd9bb34..d45f663 100644 --- a/core/src/glwe/public_key.rs +++ b/core/src/glwe/public_key.rs @@ -1,8 +1,5 @@ use backend::hal::{ - api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAlloc, VecZnxAllocBytes, VecZnxDftAlloc, VecZnxDftAllocBytes, - VecZnxDftFromVecZnx, - }, + api::{ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftFromVecZnx}, layouts::{Backend, Data, DataMut, DataRef, Module, ReaderFrom, Scratch, ScratchOwned, VecZnx, VecZnxDft, WriterTo}, oep::{ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxDftImpl, TakeVecZnxImpl}, }; @@ -21,23 +18,17 @@ pub struct GLWEPublicKey { } impl GLWEPublicKey> { - pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self - where - Module: VecZnxAlloc, - { + pub fn alloc(n: usize, basek: usize, k: usize, rank: usize) -> Self { Self { - data: module.vec_znx_alloc(rank + 1, k.div_ceil(basek)), + data: VecZnx::alloc(n, rank + 1, k.div_ceil(basek)), basek: basek, k: k, dist: Distribution::NONE, } } - pub fn bytes_of(module: &Module, basek: usize, k: usize, rank: usize) -> usize - where - Module: VecZnxAllocBytes, - { - module.vec_znx_alloc_bytes(rank + 1, k.div_ceil(basek)) + pub fn bytes_of(n: usize, basek: usize, k: usize, rank: usize) -> usize { + VecZnx::alloc_bytes(n, rank + 1, k.div_ceil(basek)) } } @@ -72,7 +63,7 @@ impl GLWEPublicKey { source_xe: &mut Source, sigma: f64, ) where - Module: GLWEPublicKeyFamily + VecZnxAlloc, + Module: GLWEPublicKeyFamily, B: ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl + TakeVecZnxDftImpl @@ -81,6 +72,8 @@ impl GLWEPublicKey { { #[cfg(debug_assertions)] { + assert_eq!(self.n(), sk.n()); + match sk.dist { Distribution::NONE => panic!("invalid sk: SecretDistribution::NONE"), _ => {} @@ -90,11 +83,12 @@ impl GLWEPublicKey { // Its ok to allocate scratch space here since pk is usually generated only once. let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWECiphertext::encrypt_sk_scratch_space( module, + self.n(), self.basek(), self.k(), )); - let mut tmp: GLWECiphertext> = GLWECiphertext::alloc(module, self.basek(), self.k(), self.rank()); + let mut tmp: GLWECiphertext> = GLWECiphertext::alloc(self.n(), self.basek(), self.k(), self.rank()); tmp.encrypt_zero_sk(module, sk, source_xa, source_xe, sigma, scratch.borrow()); self.dist = sk.dist; } @@ -157,23 +151,23 @@ impl GLWEPublicKeyExec { } impl GLWEPublicKeyExec, B> { - pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self + pub fn alloc(module: &Module, n: usize, basek: usize, k: usize, rank: usize) -> Self where Module: VecZnxDftAlloc, { Self { - data: module.vec_znx_dft_alloc(rank + 1, k.div_ceil(basek)), + data: module.vec_znx_dft_alloc(n, rank + 1, k.div_ceil(basek)), basek: basek, k: k, dist: Distribution::NONE, } } - pub fn bytes_of(module: &Module, basek: usize, k: usize, rank: usize) -> usize + pub fn bytes_of(module: &Module, n: usize, basek: usize, k: usize, rank: usize) -> usize where Module: VecZnxDftAllocBytes, { - module.vec_znx_dft_alloc_bytes(rank + 1, k.div_ceil(basek)) + module.vec_znx_dft_alloc_bytes(n, rank + 1, k.div_ceil(basek)) } pub fn from(module: &Module, other: &GLWEPublicKey, scratch: &mut Scratch) -> Self @@ -181,7 +175,8 @@ impl GLWEPublicKeyExec, B> { DataOther: DataRef, Module: VecZnxDftAlloc + VecZnxDftFromVecZnx, { - let mut pk_exec: GLWEPublicKeyExec, B> = GLWEPublicKeyExec::alloc(module, other.basek(), other.k(), other.rank()); + let mut pk_exec: GLWEPublicKeyExec, B> = + GLWEPublicKeyExec::alloc(module, other.n(), other.basek(), other.k(), other.rank()); pk_exec.prepare(module, other, scratch); pk_exec } @@ -195,8 +190,7 @@ impl GLWEPublicKeyExec { { #[cfg(debug_assertions)] { - assert_eq!(self.n(), module.n()); - assert_eq!(other.n(), module.n()); + assert_eq!(self.n(), other.n()); assert_eq!(self.size(), other.size()); } diff --git a/core/src/glwe/secret.rs b/core/src/glwe/secret.rs index 10232b2..90b4c72 100644 --- a/core/src/glwe/secret.rs +++ b/core/src/glwe/secret.rs @@ -1,5 +1,5 @@ use backend::hal::{ - api::{ScalarZnxAlloc, ScalarZnxAllocBytes, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, ZnxInfos, ZnxZero}, + api::{SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, ZnxInfos, ZnxZero}, layouts::{Backend, Data, DataMut, DataRef, Module, ReaderFrom, ScalarZnx, SvpPPol, WriterTo}, }; use sampling::source::Source; @@ -15,21 +15,15 @@ pub struct GLWESecret { } impl GLWESecret> { - pub fn alloc(module: &Module, rank: usize) -> Self - where - Module: ScalarZnxAlloc, - { + pub fn alloc(n: usize, rank: usize) -> Self { Self { - data: module.scalar_znx_alloc(rank), + data: ScalarZnx::alloc(n, rank), dist: Distribution::NONE, } } - pub fn bytes_of(module: &Module, rank: usize) -> usize - where - Module: ScalarZnxAllocBytes, - { - module.scalar_znx_alloc_bytes(rank) + pub fn bytes_of(n: usize, rank: usize) -> usize { + ScalarZnx::alloc_bytes(n, rank) } } @@ -115,21 +109,21 @@ pub struct GLWESecretExec { } impl GLWESecretExec, B> { - pub fn alloc(module: &Module, rank: usize) -> Self + pub fn alloc(module: &Module, n: usize, rank: usize) -> Self where - Module: GLWESecretFamily, + Module: SvpPPolAlloc, { Self { - data: module.svp_ppol_alloc(rank), + data: module.svp_ppol_alloc(n, rank), dist: Distribution::NONE, } } - pub fn bytes_of(module: &Module, rank: usize) -> usize + pub fn bytes_of(module: &Module, n: usize, rank: usize) -> usize where - Module: GLWESecretFamily, + Module: SvpPPolAllocBytes, { - module.svp_ppol_alloc_bytes(rank) + module.svp_ppol_alloc_bytes(n, rank) } } @@ -137,9 +131,9 @@ impl GLWESecretExec, B> { pub fn from(module: &Module, sk: &GLWESecret) -> Self where D: DataRef, - Module: GLWESecretFamily, + Module: SvpPrepare + SvpPPolAlloc, { - let mut sk_dft: GLWESecretExec, B> = Self::alloc(module, sk.rank()); + let mut sk_dft: GLWESecretExec, B> = Self::alloc(module, sk.n(), sk.rank()); sk_dft.prepare(module, sk); sk_dft } @@ -163,7 +157,7 @@ impl GLWESecretExec { pub(crate) fn prepare(&mut self, module: &Module, sk: &GLWESecret) where O: DataRef, - Module: GLWESecretFamily, + Module: SvpPrepare, { (0..self.rank()).for_each(|i| { module.svp_prepare(&mut self.data, i, &sk.data, i); diff --git a/core/src/glwe/tests/cpu_spqlios/fft64.rs b/core/src/glwe/tests/cpu_spqlios/fft64.rs index c3a7ffc..5f7135f 100644 --- a/core/src/glwe/tests/cpu_spqlios/fft64.rs +++ b/core/src/glwe/tests/cpu_spqlios/fft64.rs @@ -8,7 +8,6 @@ use crate::glwe::tests::{ generic_encryption::{test_encrypt_pk, test_encrypt_sk, test_encrypt_sk_compressed, test_encrypt_zero_sk}, generic_external_product::{test_external_product, test_external_product_inplace}, generic_keyswitch::{test_keyswitch, test_keyswitch_inplace}, - generic_serialization::{test_serialization, test_serialization_compressed}, packing::test_packing, trace::test_trace_inplace, }; @@ -175,17 +174,3 @@ fn packing() { let module: Module = Module::::new(1 << log_n); test_packing(&module); } - -#[test] -fn serialization() { - let log_n: usize = 5; - let module: Module = Module::::new(1 << log_n); - test_serialization(&module); -} - -#[test] -fn serialization_compressed() { - let log_n: usize = 5; - let module: Module = Module::::new(1 << log_n); - test_serialization_compressed(&module); -} diff --git a/core/src/glwe/tests/generic_automorphism.rs b/core/src/glwe/tests/generic_automorphism.rs index a3875a5..0ffb9bb 100644 --- a/core/src/glwe/tests/generic_automorphism.rs +++ b/core/src/glwe/tests/generic_automorphism.rs @@ -1,8 +1,7 @@ use backend::hal::{ api::{ - MatZnxAlloc, ScalarZnxAlloc, ScalarZnxAllocBytes, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddScalarInplace, - VecZnxAlloc, VecZnxAllocBytes, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxFillUniform, VecZnxStd, - VecZnxSwithcDegree, + ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, + VecZnxFillUniform, VecZnxStd, VecZnxSwithcDegree, }, layouts::{Backend, Module, ScratchOwned}, oep::{ @@ -22,13 +21,8 @@ pub(crate) trait AutomorphismTestModuleFamily = AutomorphismKeyEncry + GLWEDecryptFamily + GGLWEExecLayoutFamily + GLWEKeyswitchFamily - + MatZnxAlloc - + VecZnxAlloc - + ScalarZnxAllocBytes - + VecZnxAllocBytes + VecZnxAutomorphism + VecZnxSwithcDegree - + ScalarZnxAlloc + VecZnxAddScalarInplace + VecZnxAutomorphismInplace + VecZnxStd; @@ -55,12 +49,13 @@ pub(crate) fn test_automorphism( Module: AutomorphismTestModuleFamily, B: AutomorphismTestScratchFamily, { + let n: usize = module.n(); let rows: usize = k_in.div_ceil(basek * digits); - let mut autokey: AutomorphismKey> = AutomorphismKey::alloc(module, basek, k_ksk, rows, digits, rank); - let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(module, basek, k_in, rank); - let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(module, basek, k_out, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k_in); + let mut autokey: AutomorphismKey> = AutomorphismKey::alloc(n, basek, k_ksk, rows, digits, rank); + let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_in, rank); + let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_out, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_in); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -69,11 +64,12 @@ pub(crate) fn test_automorphism( module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, k_in, &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::alloc( - AutomorphismKey::encrypt_sk_scratch_space(module, basek, autokey.k(), rank) - | GLWECiphertext::decrypt_scratch_space(module, basek, ct_out.k()) - | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct_in.k()) + AutomorphismKey::encrypt_sk_scratch_space(module, n, basek, autokey.k(), rank) + | GLWECiphertext::decrypt_scratch_space(module, n, basek, ct_out.k()) + | GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, ct_in.k()) | GLWECiphertext::automorphism_scratch_space( module, + n, basek, ct_out.k(), ct_in.k(), @@ -83,7 +79,7 @@ pub(crate) fn test_automorphism( ), ); - let mut sk: GLWESecret> = GLWESecret::alloc(module, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk); @@ -107,7 +103,8 @@ pub(crate) fn test_automorphism( scratch.borrow(), ); - let mut autokey_exec: AutomorphismKeyExec, B> = AutomorphismKeyExec::alloc(module, basek, k_ksk, rows, digits, rank); + let mut autokey_exec: AutomorphismKeyExec, B> = + AutomorphismKeyExec::alloc(module, n, basek, k_ksk, rows, digits, rank); autokey_exec.prepare(module, &autokey, scratch.borrow()); ct_out.automorphism(module, &ct_in, &autokey_exec, scratch.borrow()); @@ -143,11 +140,12 @@ pub(crate) fn test_automorphism_inplace( Module: AutomorphismTestModuleFamily, B: AutomorphismTestScratchFamily, { + let n: usize = module.n(); let rows: usize = k_ct.div_ceil(basek * digits); - let mut autokey: AutomorphismKey> = AutomorphismKey::alloc(module, basek, k_ksk, rows, digits, rank); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(module, basek, k_ct, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k_ct); + let mut autokey: AutomorphismKey> = AutomorphismKey::alloc(n, basek, k_ksk, rows, digits, rank); + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_ct); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -156,13 +154,13 @@ pub(crate) fn test_automorphism_inplace( module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, k_ct, &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::alloc( - AutomorphismKey::encrypt_sk_scratch_space(module, basek, autokey.k(), rank) - | GLWECiphertext::decrypt_scratch_space(module, basek, ct.k()) - | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct.k()) - | GLWECiphertext::automorphism_inplace_scratch_space(module, basek, ct.k(), autokey.k(), digits, rank), + AutomorphismKey::encrypt_sk_scratch_space(module, n, basek, autokey.k(), rank) + | GLWECiphertext::decrypt_scratch_space(module, n, basek, ct.k()) + | GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, ct.k()) + | GLWECiphertext::automorphism_inplace_scratch_space(module, n, basek, ct.k(), autokey.k(), digits, rank), ); - let mut sk: GLWESecret> = GLWESecret::alloc(module, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk); @@ -186,7 +184,8 @@ pub(crate) fn test_automorphism_inplace( scratch.borrow(), ); - let mut autokey_exec: AutomorphismKeyExec, B> = AutomorphismKeyExec::alloc(module, basek, k_ksk, rows, digits, rank); + let mut autokey_exec: AutomorphismKeyExec, B> = + AutomorphismKeyExec::alloc(module, n, basek, k_ksk, rows, digits, rank); autokey_exec.prepare(module, &autokey, scratch.borrow()); ct.automorphism_inplace(module, &autokey_exec, scratch.borrow()); diff --git a/core/src/glwe/tests/generic_encryption.rs b/core/src/glwe/tests/generic_encryption.rs index f4eb063..32a1fd6 100644 --- a/core/src/glwe/tests/generic_encryption.rs +++ b/core/src/glwe/tests/generic_encryption.rs @@ -1,8 +1,5 @@ use backend::hal::{ - api::{ - ScalarZnxAlloc, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAlloc, VecZnxCopy, VecZnxDftAlloc, VecZnxFillUniform, - VecZnxStd, VecZnxSubABInplace, - }, + api::{ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxCopy, VecZnxDftAlloc, VecZnxFillUniform, VecZnxStd, VecZnxSubABInplace}, layouts::{Backend, Module, ScratchOwned}, oep::{ ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, @@ -16,8 +13,7 @@ use crate::{ GLWEPlaintext, GLWEPublicKey, GLWEPublicKeyExec, GLWESecret, GLWESecretExec, GLWESecretFamily, Infos, }; -pub(crate) trait EncryptionTestModuleFamily = - GLWEDecryptFamily + GLWESecretFamily + VecZnxAlloc + ScalarZnxAlloc + VecZnxStd; +pub(crate) trait EncryptionTestModuleFamily = GLWEDecryptFamily + GLWESecretFamily + VecZnxStd; pub(crate) trait EncryptionTestScratchFamily = TakeVecZnxDftImpl + TakeVecZnxBigImpl @@ -33,20 +29,21 @@ where Module: EncryptionTestModuleFamily + GLWEEncryptSkFamily, B: EncryptionTestScratchFamily, { - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(module, basek, k_ct, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k_pt); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k_pt); + let n = module.n(); + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_pt); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_pt); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct.k()) - | GLWECiphertext::decrypt_scratch_space(module, basek, ct.k()), + GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, ct.k()) + | GLWECiphertext::decrypt_scratch_space(module, n, basek, ct.k()), ); - let mut sk: GLWESecret> = GLWESecret::alloc(module, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk); @@ -83,21 +80,22 @@ pub(crate) fn test_encrypt_sk_compressed( Module: EncryptionTestModuleFamily + GLWEEncryptSkFamily + VecZnxCopy, B: EncryptionTestScratchFamily, { - let mut ct_compressed: GLWECiphertextCompressed> = GLWECiphertextCompressed::alloc(module, basek, k_ct, rank); + let n = module.n(); + let mut ct_compressed: GLWECiphertextCompressed> = GLWECiphertextCompressed::alloc(n, basek, k_ct, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k_pt); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k_pt); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_pt); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_pt); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertextCompressed::encrypt_sk_scratch_space(module, basek, k_ct) - | GLWECiphertext::decrypt_scratch_space(module, basek, k_ct), + GLWECiphertextCompressed::encrypt_sk_scratch_space(module, n, basek, k_ct) + | GLWECiphertext::decrypt_scratch_space(module, n, basek, k_ct), ); - let mut sk: GLWESecret> = GLWESecret::alloc(module, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk); @@ -115,7 +113,7 @@ pub(crate) fn test_encrypt_sk_compressed( scratch.borrow(), ); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(module, basek, k_ct, rank); + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); ct.decompress(module, &ct_compressed); ct.decrypt(module, &mut pt_have, &sk_exec, scratch.borrow()); @@ -138,21 +136,22 @@ where Module: EncryptionTestModuleFamily + GLWEEncryptSkFamily, B: EncryptionTestScratchFamily, { - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k_ct); + let n = module.n(); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_ct); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([1u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut sk: GLWESecret> = GLWESecret::alloc(module, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(module, basek, k_ct, rank); + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::decrypt_scratch_space(module, basek, k_ct) - | GLWECiphertext::encrypt_sk_scratch_space(module, basek, k_ct), + GLWECiphertext::decrypt_scratch_space(module, n, basek, k_ct) + | GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, k_ct), ); ct.encrypt_zero_sk( @@ -178,26 +177,27 @@ where + VecZnxSubABInplace, B: EncryptionTestScratchFamily, { - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(module, basek, k_ct, rank); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k_ct); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k_ct); + let n: usize = module.n(); + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_ct); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_ct); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut source_xu: Source = Source::new([0u8; 32]); - let mut sk: GLWESecret> = GLWESecret::alloc(module, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk); - let mut pk: GLWEPublicKey> = GLWEPublicKey::alloc(module, basek, k_pk, rank); + let mut pk: GLWEPublicKey> = GLWEPublicKey::alloc(n, basek, k_pk, rank); pk.generate_from_sk(module, &sk_exec, &mut source_xa, &mut source_xe, sigma); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct.k()) - | GLWECiphertext::decrypt_scratch_space(module, basek, ct.k()) - | GLWECiphertext::encrypt_pk_scratch_space(module, basek, pk.k()), + GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, ct.k()) + | GLWECiphertext::decrypt_scratch_space(module, n, basek, ct.k()) + | GLWECiphertext::encrypt_pk_scratch_space(module, n, basek, pk.k()), ); module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, k_ct, &mut source_xa); @@ -219,7 +219,7 @@ where pt_want.sub_inplace_ab(module, &pt_have); let noise_have: f64 = module.vec_znx_std(basek, &pt_want.data, 0).log2(); - let noise_want: f64 = ((((rank as f64) + 1.0) * module.n() as f64 * 0.5 * sigma * sigma).sqrt()).log2() - (k_ct as f64); + let noise_want: f64 = ((((rank as f64) + 1.0) * n as f64 * 0.5 * sigma * sigma).sqrt()).log2() - (k_ct as f64); assert!( noise_have <= noise_want + 0.2, diff --git a/core/src/glwe/tests/generic_external_product.rs b/core/src/glwe/tests/generic_external_product.rs index d881fbe..4150261 100644 --- a/core/src/glwe/tests/generic_external_product.rs +++ b/core/src/glwe/tests/generic_external_product.rs @@ -1,7 +1,7 @@ use backend::hal::{ api::{ - MatZnxAlloc, ScalarZnxAlloc, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddScalarInplace, VecZnxAlloc, - VecZnxAllocBytes, VecZnxFillUniform, VecZnxRotateInplace, VecZnxStd, ZnxViewMut, + ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddScalarInplace, VecZnxFillUniform, VecZnxRotateInplace, VecZnxStd, + ZnxViewMut, }, layouts::{Backend, Module, ScalarZnx, ScratchOwned}, oep::{ @@ -21,10 +21,6 @@ pub(crate) trait ExternalProductTestModuleFamily = GLWEEncryptSkFami + GLWESecretFamily + GLWEExternalProductFamily + GGSWLayoutFamily - + MatZnxAlloc - + VecZnxAlloc - + ScalarZnxAlloc - + VecZnxAllocBytes + VecZnxAddScalarInplace + VecZnxRotateInplace + VecZnxStd; @@ -51,13 +47,14 @@ pub(crate) fn test_external_product( Module: ExternalProductTestModuleFamily, B: ExternalProductTestScratchFamily, { + let n: usize = module.n(); let rows: usize = k_in.div_ceil(basek * digits); - let mut ct_ggsw: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k_ggsw, rows, digits, rank); - let mut ct_glwe_in: GLWECiphertext> = GLWECiphertext::alloc(module, basek, k_in, rank); - let mut ct_glwe_out: GLWECiphertext> = GLWECiphertext::alloc(module, basek, k_out, rank); - let mut pt_rgsw: ScalarZnx> = module.scalar_znx_alloc(1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k_in); + let mut ct_ggsw: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ggsw, rows, digits, rank); + let mut ct_glwe_in: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_in, rank); + let mut ct_glwe_out: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_out, rank); + let mut pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n, 1); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_in); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -73,10 +70,11 @@ pub(crate) fn test_external_product( pt_rgsw.raw_mut()[k] = 1; // X^{k} let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, basek, ct_ggsw.k(), rank) - | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct_glwe_in.k()) + GGSWCiphertext::encrypt_sk_scratch_space(module, n, basek, ct_ggsw.k(), rank) + | GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, ct_glwe_in.k()) | GLWECiphertext::external_product_scratch_space( module, + n, basek, ct_glwe_out.k(), ct_glwe_in.k(), @@ -86,7 +84,7 @@ pub(crate) fn test_external_product( ), ); - let mut sk: GLWESecret> = GLWESecret::alloc(module, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk); @@ -119,12 +117,12 @@ pub(crate) fn test_external_product( let var_gct_err_lhs: f64 = sigma * sigma; let var_gct_err_rhs: f64 = 0f64; - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_msg: f64 = 1f64 / n as f64; // X^{k} let var_a0_err: f64 = sigma * sigma; let var_a1_err: f64 = 1f64 / 12f64; let max_noise: f64 = noise_ggsw_product( - module.n() as f64, + n as f64, basek * digits, 0.5, var_msg, @@ -152,12 +150,13 @@ pub(crate) fn test_external_product_inplace( Module: ExternalProductTestModuleFamily, B: ExternalProductTestScratchFamily, { + let n: usize = module.n(); let rows: usize = k_ct.div_ceil(basek * digits); - let mut ct_ggsw: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k_ggsw, rows, digits, rank); - let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(module, basek, k_ct, rank); - let mut pt_rgsw: ScalarZnx> = module.scalar_znx_alloc(1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k_ct); + let mut ct_ggsw: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ggsw, rows, digits, rank); + let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); + let mut pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n, 1); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_ct); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -173,12 +172,12 @@ pub(crate) fn test_external_product_inplace( pt_rgsw.raw_mut()[k] = 1; // X^{k} let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, basek, ct_ggsw.k(), rank) - | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct_glwe.k()) - | GLWECiphertext::external_product_inplace_scratch_space(module, basek, ct_glwe.k(), ct_ggsw.k(), digits, rank), + GGSWCiphertext::encrypt_sk_scratch_space(module, n, basek, ct_ggsw.k(), rank) + | GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, ct_glwe.k()) + | GLWECiphertext::external_product_inplace_scratch_space(module, n, basek, ct_glwe.k(), ct_ggsw.k(), digits, rank), ); - let mut sk: GLWESecret> = GLWESecret::alloc(module, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk); @@ -211,12 +210,12 @@ pub(crate) fn test_external_product_inplace( let var_gct_err_lhs: f64 = sigma * sigma; let var_gct_err_rhs: f64 = 0f64; - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_msg: f64 = 1f64 / n as f64; // X^{k} let var_a0_err: f64 = sigma * sigma; let var_a1_err: f64 = 1f64 / 12f64; let max_noise: f64 = noise_ggsw_product( - module.n() as f64, + n as f64, basek * digits, 0.5, var_msg, diff --git a/core/src/glwe/tests/generic_keyswitch.rs b/core/src/glwe/tests/generic_keyswitch.rs index 587470a..733b940 100644 --- a/core/src/glwe/tests/generic_keyswitch.rs +++ b/core/src/glwe/tests/generic_keyswitch.rs @@ -1,8 +1,5 @@ use backend::hal::{ - api::{ - MatZnxAlloc, ScalarZnxAlloc, ScalarZnxAllocBytes, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddScalarInplace, - VecZnxAlloc, VecZnxAllocBytes, VecZnxFillUniform, VecZnxStd, VecZnxSwithcDegree, - }, + api::{ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddScalarInplace, VecZnxFillUniform, VecZnxStd, VecZnxSwithcDegree}, layouts::{Backend, Module, ScratchOwned}, oep::{ ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, @@ -22,11 +19,6 @@ pub(crate) trait KeySwitchTestModuleFamily = GLWESecretFamily + GLWEKeyswitchFamily + GLWEDecryptFamily + GGLWEExecLayoutFamily - + MatZnxAlloc - + VecZnxAlloc - + ScalarZnxAlloc - + ScalarZnxAllocBytes - + VecZnxAllocBytes + VecZnxStd + VecZnxSwithcDegree + VecZnxAddScalarInplace; @@ -54,12 +46,13 @@ pub(crate) fn test_keyswitch( Module: KeySwitchTestModuleFamily, B: KeySwitchTestScratchFamily, { + let n = module.n(); let rows: usize = k_in.div_ceil(basek * digits); - let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc(module, basek, k_ksk, rows, digits, rank_in, rank_out); - let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(module, basek, k_in, rank_in); - let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(module, basek, k_out, rank_out); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k_in); + let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc(n, basek, k_ksk, rows, digits, rank_in, rank_out); + let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_in, rank_in); + let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_out, rank_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_in); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -68,10 +61,11 @@ pub(crate) fn test_keyswitch( module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, k_in, &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, ksk.k(), rank_in, rank_out) - | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct_in.k()) + GLWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, ksk.k(), rank_in, rank_out) + | GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, ct_in.k()) | GLWECiphertext::keyswitch_scratch_space( module, + n, basek, ct_out.k(), ct_in.k(), @@ -82,11 +76,11 @@ pub(crate) fn test_keyswitch( ), ); - let mut sk_in: GLWESecret> = GLWESecret::alloc(module, rank_in); + let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank_in); sk_in.fill_ternary_prob(0.5, &mut source_xs); let sk_in_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_in); - let mut sk_out: GLWESecret> = GLWESecret::alloc(module, rank_out); + let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank_out); sk_out.fill_ternary_prob(0.5, &mut source_xs); let sk_out_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_out); @@ -142,11 +136,12 @@ pub(crate) fn test_keyswitch_inplace( Module: KeySwitchTestModuleFamily, B: KeySwitchTestScratchFamily, { + let n: usize = module.n(); let rows: usize = k_ct.div_ceil(basek * digits); - let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc(module, basek, k_ksk, rows, digits, rank, rank); - let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(module, basek, k_ct, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k_ct); + let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc(n, basek, k_ksk, rows, digits, rank, rank); + let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_ct); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -155,16 +150,16 @@ pub(crate) fn test_keyswitch_inplace( module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, k_ct, &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, ksk.k(), rank, rank) - | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct_glwe.k()) - | GLWECiphertext::keyswitch_inplace_scratch_space(module, basek, ct_glwe.k(), ksk.k(), digits, rank), + GLWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, ksk.k(), rank, rank) + | GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, ct_glwe.k()) + | GLWECiphertext::keyswitch_inplace_scratch_space(module, n, basek, ct_glwe.k(), ksk.k(), digits, rank), ); - let mut sk_in: GLWESecret> = GLWESecret::alloc(module, rank); + let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank); sk_in.fill_ternary_prob(0.5, &mut source_xs); let sk_in_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_in); - let mut sk_out: GLWESecret> = GLWESecret::alloc(module, rank); + let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank); sk_out.fill_ternary_prob(0.5, &mut source_xs); let sk_out_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_out); diff --git a/core/src/glwe/tests/generic_serialization.rs b/core/src/glwe/tests/generic_serialization.rs index 0a55716..dd2084f 100644 --- a/core/src/glwe/tests/generic_serialization.rs +++ b/core/src/glwe/tests/generic_serialization.rs @@ -1,23 +1,15 @@ -use backend::hal::{ - api::VecZnxAlloc, - layouts::{Backend, Module}, - tests::serialization::test_reader_writer_interface, -}; +use backend::hal::tests::serialization::test_reader_writer_interface; use crate::{GLWECiphertext, GLWECiphertextCompressed}; -pub(crate) fn test_serialization(module: &Module) -where - Module: VecZnxAlloc, -{ - let original: GLWECiphertext> = GLWECiphertext::alloc(module, 12, 54, 3); +#[test] +fn test_serialization() { + let original: GLWECiphertext> = GLWECiphertext::alloc(1024, 12, 54, 3); test_reader_writer_interface(original); } -pub(crate) fn test_serialization_compressed(module: &Module) -where - Module: VecZnxAlloc, -{ - let original: GLWECiphertextCompressed> = GLWECiphertextCompressed::alloc(module, 12, 54, 3); +#[test] +fn test_serialization_compressed() { + let original: GLWECiphertextCompressed> = GLWECiphertextCompressed::alloc(1024, 12, 54, 3); test_reader_writer_interface(original); } diff --git a/core/src/glwe/tests/packing.rs b/core/src/glwe/tests/packing.rs index 2066c6a..806fb4f 100644 --- a/core/src/glwe/tests/packing.rs +++ b/core/src/glwe/tests/packing.rs @@ -2,9 +2,8 @@ use std::collections::HashMap; use backend::hal::{ api::{ - MatZnxAlloc, ScalarZnxAlloc, ScalarZnxAllocBytes, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddScalarInplace, - VecZnxAlloc, VecZnxAllocBytes, VecZnxAutomorphism, VecZnxBigSubSmallBInplace, VecZnxEncodeVeci64, VecZnxRotateInplace, - VecZnxStd, VecZnxSwithcDegree, + ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxBigSubSmallBInplace, + VecZnxEncodeVeci64, VecZnxRotateInplace, VecZnxStd, VecZnxSwithcDegree, }, layouts::{Backend, Module, ScratchOwned}, oep::{ @@ -25,11 +24,6 @@ pub(crate) trait PackingTestModuleFamily = GLWEPackingFamily + GLWEKeyswitchFamily + GLWEDecryptFamily + GGLWEExecLayoutFamily - + MatZnxAlloc - + VecZnxAlloc - + ScalarZnxAlloc - + ScalarZnxAllocBytes - + VecZnxAllocBytes + VecZnxStd + VecZnxSwithcDegree + VecZnxAddScalarInplace @@ -56,6 +50,7 @@ where let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); + let n: usize = module.n(); let basek: usize = 18; let k_ct: usize = 36; let pt_k: usize = 18; @@ -67,17 +62,17 @@ where let rows: usize = k_ct.div_ceil(basek * digits); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::encrypt_sk_scratch_space(module, basek, k_ct) - | AutomorphismKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank) - | GLWEPacker::scratch_space(module, basek, k_ct, k_ksk, digits, rank), + GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, k_ct) + | AutomorphismKey::encrypt_sk_scratch_space(module, n, basek, k_ksk, rank) + | GLWEPacker::scratch_space(module, n, basek, k_ct, k_ksk, digits, rank), ); - let mut sk: GLWESecret> = GLWESecret::alloc(module, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_dft: GLWESecretExec, B> = GLWESecretExec::from(module, &sk); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k_ct); - let mut data: Vec = vec![0i64; module.n()]; + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_ct); + let mut data: Vec = vec![0i64; n]; data.iter_mut().enumerate().for_each(|(i, x)| { *x = i as i64; }); @@ -87,7 +82,7 @@ where let gal_els: Vec = GLWEPacker::galois_elements(module); let mut auto_keys: HashMap, B>> = HashMap::new(); - let mut tmp: AutomorphismKey> = AutomorphismKey::alloc(module, basek, k_ksk, rows, digits, rank); + let mut tmp: AutomorphismKey> = AutomorphismKey::alloc(n, basek, k_ksk, rows, digits, rank); gal_els.iter().for_each(|gal_el| { tmp.encrypt_sk( module, @@ -104,9 +99,9 @@ where let log_batch: usize = 0; - let mut packer: GLWEPacker = GLWEPacker::new(module, log_batch, basek, k_ct, rank); + let mut packer: GLWEPacker = GLWEPacker::new(n, log_batch, basek, k_ct, rank); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(module, basek, k_ct, rank); + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); ct.encrypt_sk( module, @@ -120,7 +115,7 @@ where let log_n: usize = module.log_n(); - (0..module.n() >> log_batch).for_each(|i| { + (0..n >> log_batch).for_each(|i| { ct.encrypt_sk( module, &pt, @@ -145,11 +140,11 @@ where } }); - let mut res = GLWECiphertext::alloc(module, basek, k_ct, rank); + let mut res = GLWECiphertext::alloc(n, basek, k_ct, rank); packer.flush(module, &mut res); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k_ct); - let mut data: Vec = vec![0i64; module.n()]; + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_ct); + let mut data: Vec = vec![0i64; n]; data.iter_mut().enumerate().for_each(|(i, x)| { if i % 5 == 0 { *x = reverse_bits_msb(i, log_n as u32) as i64; diff --git a/core/src/glwe/tests/trace.rs b/core/src/glwe/tests/trace.rs index f6313a4..7465a07 100644 --- a/core/src/glwe/tests/trace.rs +++ b/core/src/glwe/tests/trace.rs @@ -2,10 +2,9 @@ use std::collections::HashMap; use backend::hal::{ api::{ - MatZnxAlloc, ScalarZnxAlloc, ScalarZnxAllocBytes, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddScalarInplace, - VecZnxAlloc, VecZnxAllocBytes, VecZnxAutomorphism, VecZnxBigAutomorphismInplace, VecZnxBigSubSmallBInplace, VecZnxCopy, - VecZnxEncodeVeci64, VecZnxFillUniform, VecZnxNormalizeInplace, VecZnxRotateInplace, VecZnxRshInplace, VecZnxStd, - VecZnxSubABInplace, VecZnxSwithcDegree, ZnxView, ZnxViewMut, + ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxBigAutomorphismInplace, + VecZnxBigSubSmallBInplace, VecZnxCopy, VecZnxEncodeVeci64, VecZnxFillUniform, VecZnxNormalizeInplace, + VecZnxRotateInplace, VecZnxRshInplace, VecZnxStd, VecZnxSubABInplace, VecZnxSwithcDegree, ZnxView, ZnxViewMut, }, layouts::{Backend, Module, ScratchOwned}, oep::{ @@ -26,11 +25,6 @@ pub(crate) trait TraceTestModuleFamily = GLWESecretFamily + GLWEKeyswitchFamily + GLWEDecryptFamily + GGLWEExecLayoutFamily - + MatZnxAlloc - + VecZnxAlloc - + ScalarZnxAlloc - + ScalarZnxAllocBytes - + VecZnxAllocBytes + VecZnxStd + VecZnxSwithcDegree + VecZnxAddScalarInplace @@ -56,31 +50,32 @@ where Module: TraceTestModuleFamily, B: TraceTestScratchFamily, { + let n: usize = module.n(); let k_autokey: usize = k + basek; let digits: usize = 1; let rows: usize = k.div_ceil(basek * digits); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(module, basek, k, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k); + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct.k()) - | GLWECiphertext::decrypt_scratch_space(module, basek, ct.k()) - | AutomorphismKey::encrypt_sk_scratch_space(module, basek, k_autokey, rank) - | GLWECiphertext::trace_inplace_scratch_space(module, basek, ct.k(), k_autokey, digits, rank), + GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, ct.k()) + | GLWECiphertext::decrypt_scratch_space(module, n, basek, ct.k()) + | AutomorphismKey::encrypt_sk_scratch_space(module, n, basek, k_autokey, rank) + | GLWECiphertext::trace_inplace_scratch_space(module, n, basek, ct.k(), k_autokey, digits, rank), ); - let mut sk: GLWESecret> = GLWESecret::alloc(module, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_dft: GLWESecretExec, B> = GLWESecretExec::from(module, &sk); - let mut data_want: Vec = vec![0i64; module.n()]; + let mut data_want: Vec = vec![0i64; n]; data_want .iter_mut() @@ -100,7 +95,7 @@ where let mut auto_keys: HashMap, B>> = HashMap::new(); let gal_els: Vec = GLWECiphertext::trace_galois_elements(module); - let mut tmp: AutomorphismKey> = AutomorphismKey::alloc(module, basek, k_autokey, rows, digits, rank); + let mut tmp: AutomorphismKey> = AutomorphismKey::alloc(n, basek, k_autokey, rows, digits, rank); gal_els.iter().for_each(|gal_el| { tmp.encrypt_sk( module, @@ -128,7 +123,7 @@ where let noise_have: f64 = module.vec_znx_std(basek, &pt_want.data, 0).log2(); let mut noise_want: f64 = var_noise_gglwe_product( - module.n() as f64, + n as f64, basek, 0.5, 0.5, @@ -140,7 +135,7 @@ where k_autokey, ); noise_want += sigma * sigma * (-2.0 * (k) as f64).exp2(); - noise_want += module.n() as f64 * 1.0 / 12.0 * 0.5 * rank as f64 * (-2.0 * (k) as f64).exp2(); + noise_want += n as f64 * 1.0 / 12.0 * 0.5 * rank as f64 * (-2.0 * (k) as f64).exp2(); noise_want = noise_want.sqrt().log2(); assert!( diff --git a/core/src/glwe/trace.rs b/core/src/glwe/trace.rs index 407c960..3ae9869 100644 --- a/core/src/glwe/trace.rs +++ b/core/src/glwe/trace.rs @@ -27,6 +27,7 @@ impl GLWECiphertext> { pub fn trace_scratch_space( module: &Module, + n: usize, basek: usize, out_k: usize, in_k: usize, @@ -37,11 +38,12 @@ impl GLWECiphertext> { where Module: GLWEKeyswitchFamily, { - Self::automorphism_inplace_scratch_space(module, basek, out_k.min(in_k), ksk_k, digits, rank) + Self::automorphism_inplace_scratch_space(module, n, basek, out_k.min(in_k), ksk_k, digits, rank) } pub fn trace_inplace_scratch_space( module: &Module, + n: usize, basek: usize, out_k: usize, ksk_k: usize, @@ -51,7 +53,7 @@ impl GLWECiphertext> { where Module: GLWEKeyswitchFamily, { - Self::automorphism_inplace_scratch_space(module, basek, out_k, ksk_k, digits, rank) + Self::automorphism_inplace_scratch_space(module, n, basek, out_k, ksk_k, digits, rank) } } diff --git a/core/src/lwe/conversion.rs b/core/src/lwe/conversion.rs new file mode 100644 index 0000000..7dfb741 --- /dev/null +++ b/core/src/lwe/conversion.rs @@ -0,0 +1,140 @@ +use backend::hal::{ + api::{ScratchAvailable, TakeVecZnx, TakeVecZnxDft, ZnxView, ZnxViewMut, ZnxZero}, + layouts::{Backend, DataMut, DataRef, Module, Scratch}, +}; + +use crate::{ + GLWECiphertext, GLWEKeyswitchFamily, GLWEToLWESwitchingKeyExec, Infos, LWECiphertext, LWESwitchingKeyExec, + LWEToGLWESwitchingKeyExec, TakeGLWECt, +}; + +impl LWECiphertext { + pub fn sample_extract(&mut self, a: &GLWECiphertext) { + #[cfg(debug_assertions)] + { + assert!(self.n() <= a.n()); + } + + let min_size: usize = self.size().min(a.size()); + let n: usize = self.n(); + + self.data.zero(); + (0..min_size).for_each(|i| { + let data_lwe: &mut [i64] = self.data.at_mut(0, i); + data_lwe[0] = a.data.at(0, i)[0]; + data_lwe[1..].copy_from_slice(&a.data.at(1, i)[..n]); + }); + } + + pub fn from_glwe( + &mut self, + module: &Module, + a: &GLWECiphertext, + ks: &GLWEToLWESwitchingKeyExec, + scratch: &mut Scratch, + ) where + DGlwe: DataRef, + DKs: DataRef, + Module: GLWEKeyswitchFamily, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.basek(), a.basek()); + assert_eq!(a.n(), ks.n()); + } + let (mut tmp_glwe, scratch1) = scratch.take_glwe_ct(a.n(), a.basek(), self.k(), 1); + tmp_glwe.keyswitch(module, a, &ks.0, scratch1); + self.sample_extract(&tmp_glwe); + } + + pub fn keyswitch( + &mut self, + module: &Module, + a: &LWECiphertext, + ksk: &LWESwitchingKeyExec, + scratch: &mut Scratch, + ) where + A: DataRef, + DKs: DataRef, + Module: GLWEKeyswitchFamily, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + { + #[cfg(debug_assertions)] + { + assert!(self.n() <= module.n()); + assert!(a.n() <= module.n()); + assert_eq!(self.basek(), a.basek()); + } + + let max_k: usize = self.k().max(a.k()); + let basek: usize = self.basek(); + + let (mut glwe, scratch1) = scratch.take_glwe_ct(ksk.n(), basek, max_k, 1); + glwe.data.zero(); + + let n_lwe: usize = a.n(); + + (0..a.size()).for_each(|i| { + let data_lwe: &[i64] = a.data.at(0, i); + glwe.data.at_mut(0, i)[0] = data_lwe[0]; + glwe.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]); + }); + + glwe.keyswitch_inplace(module, &ksk.0, scratch1); + + self.sample_extract(&glwe); + } +} + +impl GLWECiphertext> { + pub fn from_lwe_scratch_space( + module: &Module, + n: usize, + basek: usize, + k_lwe: usize, + k_glwe: usize, + k_ksk: usize, + rank: usize, + ) -> usize + where + Module: GLWEKeyswitchFamily, + { + GLWECiphertext::keyswitch_scratch_space(module, n, basek, k_glwe, k_lwe, k_ksk, 1, 1, rank) + + GLWECiphertext::bytes_of(n, basek, k_lwe, 1) + } +} + +impl GLWECiphertext { + pub fn from_lwe( + &mut self, + module: &Module, + lwe: &LWECiphertext, + ksk: &LWEToGLWESwitchingKeyExec, + scratch: &mut Scratch, + ) where + DLwe: DataRef, + DKsk: DataRef, + Module: GLWEKeyswitchFamily, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + { + #[cfg(debug_assertions)] + { + assert!(lwe.n() <= self.n()); + assert_eq!(self.basek(), self.basek()); + } + + let (mut glwe, scratch1) = scratch.take_glwe_ct(ksk.n(), lwe.basek(), lwe.k(), 1); + glwe.data.zero(); + + let n_lwe: usize = lwe.n(); + + (0..lwe.size()).for_each(|i| { + let data_lwe: &[i64] = lwe.data.at(0, i); + glwe.data.at_mut(0, i)[0] = data_lwe[0]; + glwe.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]); + }); + + self.keyswitch(module, &glwe, &ksk.0, scratch1); + } +} diff --git a/core/src/lwe/encryption.rs b/core/src/lwe/encryption.rs index 23d6826..21db6d8 100644 --- a/core/src/lwe/encryption.rs +++ b/core/src/lwe/encryption.rs @@ -34,7 +34,7 @@ impl LWECiphertext { module.vec_znx_fill_uniform(basek, &mut self.data, 0, k, source_xa); - let mut tmp_znx: VecZnx> = VecZnx::>::alloc::(1, 1, self.size()); + let mut tmp_znx: VecZnx> = VecZnx::alloc(1, 1, self.size()); let min_size = self.size().min(pt.size()); diff --git a/core/src/lwe/keyswitch_layouts_exec.rs b/core/src/lwe/keyswitch_layouts_exec.rs new file mode 100644 index 0000000..eb4faba --- /dev/null +++ b/core/src/lwe/keyswitch_layouts_exec.rs @@ -0,0 +1,257 @@ +use backend::hal::layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat}; + +use crate::{ + GGLWEExecLayoutFamily, GLWESwitchingKeyExec, Infos, + lwe::keyswtich_layouts::{GLWEToLWESwitchingKey, LWESwitchingKey, LWEToGLWESwitchingKey}, +}; + +#[derive(PartialEq, Eq)] +pub struct GLWEToLWESwitchingKeyExec(pub(crate) GLWESwitchingKeyExec); + +impl Infos for GLWEToLWESwitchingKeyExec { + type Inner = VmpPMat; + + fn inner(&self) -> &Self::Inner { + &self.0.inner() + } + + fn basek(&self) -> usize { + self.0.basek() + } + + fn k(&self) -> usize { + self.0.k() + } +} + +impl GLWEToLWESwitchingKeyExec { + pub fn digits(&self) -> usize { + self.0.digits() + } + + pub fn rank(&self) -> usize { + self.0.rank() + } + + pub fn rank_in(&self) -> usize { + self.0.rank_in() + } + + pub fn rank_out(&self) -> usize { + self.0.rank_out() + } +} + +impl GLWEToLWESwitchingKeyExec, B> { + pub fn alloc(module: &Module, n: usize, basek: usize, k: usize, rows: usize, rank_in: usize) -> Self + where + Module: GGLWEExecLayoutFamily, + { + Self(GLWESwitchingKeyExec::alloc( + module, n, basek, k, rows, 1, rank_in, 1, + )) + } + + pub fn bytes_of(module: &Module, n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize) -> usize + where + Module: GGLWEExecLayoutFamily, + { + GLWESwitchingKeyExec::, B>::bytes_of(module, n, basek, k, rows, digits, rank_in, 1) + } + + pub fn from( + module: &Module, + other: &GLWEToLWESwitchingKey, + scratch: &mut Scratch, + ) -> Self + where + Module: GGLWEExecLayoutFamily, + { + let mut ksk_exec: GLWEToLWESwitchingKeyExec, B> = Self::alloc( + module, + other.0.n(), + other.0.basek(), + other.0.k(), + other.0.rows(), + other.0.rank_in(), + ); + ksk_exec.prepare(module, other, scratch); + ksk_exec + } +} + +impl GLWEToLWESwitchingKeyExec { + pub fn prepare(&mut self, module: &Module, other: &GLWEToLWESwitchingKey, scratch: &mut Scratch) + where + DataOther: DataRef, + Module: GGLWEExecLayoutFamily, + { + self.0.prepare(module, &other.0, scratch); + } +} + +/// A special [GLWESwitchingKey] required to for the conversion from [LWECiphertext] to [GLWECiphertext]. +#[derive(PartialEq, Eq)] +pub struct LWEToGLWESwitchingKeyExec(pub(crate) GLWESwitchingKeyExec); + +impl Infos for LWEToGLWESwitchingKeyExec { + type Inner = VmpPMat; + + fn inner(&self) -> &Self::Inner { + &self.0.inner() + } + + fn basek(&self) -> usize { + self.0.basek() + } + + fn k(&self) -> usize { + self.0.k() + } +} + +impl LWEToGLWESwitchingKeyExec { + pub fn digits(&self) -> usize { + self.0.digits() + } + + pub fn rank(&self) -> usize { + self.0.rank() + } + + pub fn rank_in(&self) -> usize { + self.0.rank_in() + } + + pub fn rank_out(&self) -> usize { + self.0.rank_out() + } +} + +impl LWEToGLWESwitchingKeyExec, B> { + pub fn alloc(module: &Module, n: usize, basek: usize, k: usize, rows: usize, rank_out: usize) -> Self + where + Module: GGLWEExecLayoutFamily, + { + Self(GLWESwitchingKeyExec::alloc( + module, n, basek, k, rows, 1, 1, rank_out, + )) + } + + pub fn bytes_of(module: &Module, n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank_out: usize) -> usize + where + Module: GGLWEExecLayoutFamily, + { + GLWESwitchingKeyExec::, B>::bytes_of(module, n, basek, k, rows, digits, 1, rank_out) + } + + pub fn from( + module: &Module, + other: &LWEToGLWESwitchingKey, + scratch: &mut Scratch, + ) -> Self + where + Module: GGLWEExecLayoutFamily, + { + let mut ksk_exec: LWEToGLWESwitchingKeyExec, B> = Self::alloc( + module, + other.0.n(), + other.0.basek(), + other.0.k(), + other.0.rows(), + other.0.rank(), + ); + ksk_exec.prepare(module, other, scratch); + ksk_exec + } +} + +impl LWEToGLWESwitchingKeyExec { + pub fn prepare(&mut self, module: &Module, other: &LWEToGLWESwitchingKey, scratch: &mut Scratch) + where + DataOther: DataRef, + Module: GGLWEExecLayoutFamily, + { + self.0.prepare(module, &other.0, scratch); + } +} + +#[derive(PartialEq, Eq)] +pub struct LWESwitchingKeyExec(pub(crate) GLWESwitchingKeyExec); + +impl Infos for LWESwitchingKeyExec { + type Inner = VmpPMat; + + fn inner(&self) -> &Self::Inner { + &self.0.inner() + } + + fn basek(&self) -> usize { + self.0.basek() + } + + fn k(&self) -> usize { + self.0.k() + } +} + +impl LWESwitchingKeyExec { + pub fn digits(&self) -> usize { + self.0.digits() + } + + pub fn rank(&self) -> usize { + self.0.rank() + } + + pub fn rank_in(&self) -> usize { + self.0.rank_in() + } + + pub fn rank_out(&self) -> usize { + self.0.rank_out() + } +} + +impl LWESwitchingKeyExec, B> { + pub fn alloc(module: &Module, n: usize, basek: usize, k: usize, rows: usize) -> Self + where + Module: GGLWEExecLayoutFamily, + { + Self(GLWESwitchingKeyExec::alloc( + module, n, basek, k, rows, 1, 1, 1, + )) + } + + pub fn bytes_of(module: &Module, n: usize, basek: usize, k: usize, rows: usize, digits: usize) -> usize + where + Module: GGLWEExecLayoutFamily, + { + GLWESwitchingKeyExec::, B>::bytes_of(module, n, basek, k, rows, digits, 1, 1) + } + + pub fn from(module: &Module, other: &LWESwitchingKey, scratch: &mut Scratch) -> Self + where + Module: GGLWEExecLayoutFamily, + { + let mut ksk_exec: LWESwitchingKeyExec, B> = Self::alloc( + module, + other.0.n(), + other.0.basek(), + other.0.k(), + other.0.rows(), + ); + ksk_exec.prepare(module, other, scratch); + ksk_exec + } +} + +impl LWESwitchingKeyExec { + pub fn prepare(&mut self, module: &Module, other: &LWESwitchingKey, scratch: &mut Scratch) + where + DataOther: DataRef, + Module: GGLWEExecLayoutFamily, + { + self.0.prepare(module, &other.0, scratch); + } +} diff --git a/core/src/lwe/keyswtich.rs b/core/src/lwe/keyswtich.rs deleted file mode 100644 index 2650f8f..0000000 --- a/core/src/lwe/keyswtich.rs +++ /dev/null @@ -1,547 +0,0 @@ -use backend::hal::{ - api::{ - MatZnxAlloc, ScalarZnxAllocBytes, ScratchAvailable, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, VecZnxAddScalarInplace, - VecZnxAllocBytes, VecZnxAutomorphismInplace, VecZnxSwithcDegree, ZnxView, ZnxViewMut, ZnxZero, - }, - layouts::{Backend, Data, DataMut, DataRef, Module, ReaderFrom, Scratch, WriterTo}, -}; -use sampling::source::Source; - -use crate::{ - GGLWEEncryptSkFamily, GGLWEExecLayoutFamily, GLWECiphertext, GLWEKeyswitchFamily, GLWESecret, GLWESecretExec, - GLWESwitchingKey, GLWESwitchingKeyExec, Infos, LWECiphertext, LWESecret, TakeGLWECt, TakeGLWESecret, TakeGLWESecretExec, -}; - -/// A special [GLWESwitchingKey] required to for the conversion from [GLWECiphertext] to [LWECiphertext]. -#[derive(PartialEq, Eq)] -pub struct GLWEToLWESwitchingKey(GLWESwitchingKey); - -impl ReaderFrom for GLWEToLWESwitchingKey { - fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { - self.0.read_from(reader) - } -} - -impl WriterTo for GLWEToLWESwitchingKey { - fn write_to(&self, writer: &mut W) -> std::io::Result<()> { - self.0.write_to(writer) - } -} - -#[derive(PartialEq, Eq)] -pub struct GLWEToLWESwitchingKeyExec(GLWESwitchingKeyExec); - -impl GLWEToLWESwitchingKeyExec, B> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, rank_in: usize) -> Self - where - Module: GGLWEExecLayoutFamily, - { - Self(GLWESwitchingKeyExec::alloc( - module, basek, k, rows, 1, rank_in, 1, - )) - } - - pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize) -> usize - where - Module: GGLWEExecLayoutFamily, - { - GLWESwitchingKeyExec::, B>::bytes_of(module, basek, k, rows, digits, rank_in, 1) - } - - pub fn from( - module: &Module, - other: &GLWEToLWESwitchingKey, - scratch: &mut Scratch, - ) -> Self - where - Module: GGLWEExecLayoutFamily, - { - let mut ksk_exec: GLWEToLWESwitchingKeyExec, B> = Self::alloc( - module, - other.0.basek(), - other.0.k(), - other.0.rows(), - other.0.rank_in(), - ); - ksk_exec.prepare(module, other, scratch); - ksk_exec - } -} - -impl GLWEToLWESwitchingKeyExec { - pub fn prepare(&mut self, module: &Module, other: &GLWEToLWESwitchingKey, scratch: &mut Scratch) - where - DataOther: DataRef, - Module: GGLWEExecLayoutFamily, - { - self.0.prepare(module, &other.0, scratch); - } -} - -impl GLWEToLWESwitchingKey> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, rank_in: usize) -> Self - where - Module: MatZnxAlloc, - { - Self(GLWESwitchingKey::alloc( - module, basek, k, rows, 1, rank_in, 1, - )) - } - - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank_in: usize) -> usize - where - Module: GGLWEEncryptSkFamily + ScalarZnxAllocBytes + VecZnxAllocBytes, - { - GLWESecretExec::bytes_of(module, rank_in) - + (GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, rank_in, 1) | GLWESecret::bytes_of(module, rank_in)) - } -} - -impl GLWEToLWESwitchingKey { - pub fn encrypt_sk( - &mut self, - module: &Module, - sk_lwe: &LWESecret, - sk_glwe: &GLWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - scratch: &mut Scratch, - ) where - DLwe: DataRef, - DGlwe: DataRef, - Module: GGLWEEncryptSkFamily - + VecZnxAutomorphismInplace - + ScalarZnxAllocBytes - + VecZnxSwithcDegree - + VecZnxAllocBytes - + VecZnxAddScalarInplace, - Scratch: ScratchAvailable + TakeScalarZnx + TakeVecZnxDft + TakeGLWESecretExec + TakeVecZnx, - { - #[cfg(debug_assertions)] - { - assert!(sk_lwe.n() <= module.n()); - } - - let (mut sk_lwe_as_glwe, scratch1) = scratch.take_glwe_secret(module, 1); - sk_lwe_as_glwe.data.zero(); - sk_lwe_as_glwe.data.at_mut(0, 0)[..sk_lwe.n()].copy_from_slice(sk_lwe.data.at(0, 0)); - module.vec_znx_automorphism_inplace(-1, &mut sk_lwe_as_glwe.data.as_vec_znx_mut(), 0); - - self.0.encrypt_sk( - module, - sk_glwe, - &sk_lwe_as_glwe, - source_xa, - source_xe, - sigma, - scratch1, - ); - } -} - -/// A special [GLWESwitchingKey] required to for the conversion from [LWECiphertext] to [GLWECiphertext]. -#[derive(PartialEq, Eq)] -pub struct LWEToGLWESwitchingKeyExec(GLWESwitchingKeyExec); - -impl LWEToGLWESwitchingKeyExec, B> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, rank_out: usize) -> Self - where - Module: GGLWEExecLayoutFamily, - { - Self(GLWESwitchingKeyExec::alloc( - module, basek, k, rows, 1, 1, rank_out, - )) - } - - pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank_out: usize) -> usize - where - Module: GGLWEExecLayoutFamily, - { - GLWESwitchingKeyExec::, B>::bytes_of(module, basek, k, rows, digits, 1, rank_out) - } - - pub fn from( - module: &Module, - other: &LWEToGLWESwitchingKey, - scratch: &mut Scratch, - ) -> Self - where - Module: GGLWEExecLayoutFamily, - { - let mut ksk_exec: LWEToGLWESwitchingKeyExec, B> = Self::alloc( - module, - other.0.basek(), - other.0.k(), - other.0.rows(), - other.0.rank(), - ); - ksk_exec.prepare(module, other, scratch); - ksk_exec - } -} - -impl LWEToGLWESwitchingKeyExec { - pub fn prepare(&mut self, module: &Module, other: &LWEToGLWESwitchingKey, scratch: &mut Scratch) - where - DataOther: DataRef, - Module: GGLWEExecLayoutFamily, - { - self.0.prepare(module, &other.0, scratch); - } -} -#[derive(PartialEq, Eq)] -pub struct LWEToGLWESwitchingKey(GLWESwitchingKey); - -impl ReaderFrom for LWEToGLWESwitchingKey { - fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { - self.0.read_from(reader) - } -} - -impl WriterTo for LWEToGLWESwitchingKey { - fn write_to(&self, writer: &mut W) -> std::io::Result<()> { - self.0.write_to(writer) - } -} - -impl LWEToGLWESwitchingKey> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, rank_out: usize) -> Self - where - Module: MatZnxAlloc, - { - Self(GLWESwitchingKey::alloc( - module, basek, k, rows, 1, 1, rank_out, - )) - } - - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank_out: usize) -> usize - where - Module: GGLWEEncryptSkFamily + ScalarZnxAllocBytes + VecZnxAllocBytes, - { - GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, 1, rank_out) + GLWESecret::bytes_of(module, 1) - } -} - -impl LWEToGLWESwitchingKey { - pub fn encrypt_sk( - &mut self, - module: &Module, - sk_lwe: &LWESecret, - sk_glwe: &GLWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - scratch: &mut Scratch, - ) where - DLwe: DataRef, - DGlwe: DataRef, - Module: GGLWEEncryptSkFamily - + VecZnxAutomorphismInplace - + ScalarZnxAllocBytes - + VecZnxSwithcDegree - + VecZnxAllocBytes - + VecZnxAddScalarInplace, - Scratch: ScratchAvailable + TakeScalarZnx + TakeVecZnxDft + TakeGLWESecretExec + TakeVecZnx, - { - #[cfg(debug_assertions)] - { - assert!(sk_lwe.n() <= module.n()); - } - - let (mut sk_lwe_as_glwe, scratch1) = scratch.take_glwe_secret(module, 1); - sk_lwe_as_glwe.data.at_mut(0, 0)[..sk_lwe.n()].copy_from_slice(sk_lwe.data.at(0, 0)); - sk_lwe_as_glwe.data.at_mut(0, 0)[sk_lwe.n()..].fill(0); - module.vec_znx_automorphism_inplace(-1, &mut sk_lwe_as_glwe.data.as_vec_znx_mut(), 0); - - self.0.encrypt_sk( - module, - &sk_lwe_as_glwe, - &sk_glwe, - source_xa, - source_xe, - sigma, - scratch1, - ); - } -} - -#[derive(PartialEq, Eq)] -pub struct LWESwitchingKeyExec(GLWESwitchingKeyExec); - -impl LWESwitchingKeyExec, B> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize) -> Self - where - Module: GGLWEExecLayoutFamily, - { - Self(GLWESwitchingKeyExec::alloc(module, basek, k, rows, 1, 1, 1)) - } - - pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize) -> usize - where - Module: GGLWEExecLayoutFamily, - { - GLWESwitchingKeyExec::, B>::bytes_of(module, basek, k, rows, digits, 1, 1) - } - - pub fn from(module: &Module, other: &LWESwitchingKey, scratch: &mut Scratch) -> Self - where - Module: GGLWEExecLayoutFamily, - { - let mut ksk_exec: LWESwitchingKeyExec, B> = Self::alloc(module, other.0.basek(), other.0.k(), other.0.rows()); - ksk_exec.prepare(module, other, scratch); - ksk_exec - } -} - -impl LWESwitchingKeyExec { - pub fn prepare(&mut self, module: &Module, other: &LWESwitchingKey, scratch: &mut Scratch) - where - DataOther: DataRef, - Module: GGLWEExecLayoutFamily, - { - self.0.prepare(module, &other.0, scratch); - } -} -#[derive(PartialEq, Eq)] -pub struct LWESwitchingKey(GLWESwitchingKey); - -impl ReaderFrom for LWESwitchingKey { - fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { - self.0.read_from(reader) - } -} - -impl WriterTo for LWESwitchingKey { - fn write_to(&self, writer: &mut W) -> std::io::Result<()> { - self.0.write_to(writer) - } -} - -impl LWESwitchingKey> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize) -> Self - where - Module: MatZnxAlloc, - { - Self(GLWESwitchingKey::alloc(module, basek, k, rows, 1, 1, 1)) - } - - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize) -> usize - where - Module: GGLWEEncryptSkFamily + ScalarZnxAllocBytes + VecZnxAllocBytes, - { - GLWESecret::bytes_of(module, 1) - + GLWESecretExec::bytes_of(module, 1) - + GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, 1, 1) - } -} - -impl LWESwitchingKey { - pub fn encrypt_sk( - &mut self, - module: &Module, - sk_lwe_in: &LWESecret, - sk_lwe_out: &LWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - scratch: &mut Scratch, - ) where - DIn: DataRef, - DOut: DataRef, - Module: GGLWEEncryptSkFamily - + VecZnxAutomorphismInplace - + ScalarZnxAllocBytes - + VecZnxSwithcDegree - + VecZnxAllocBytes - + VecZnxAddScalarInplace, - Scratch: ScratchAvailable + TakeScalarZnx + TakeVecZnxDft + TakeGLWESecretExec + TakeVecZnx, - { - #[cfg(debug_assertions)] - { - assert!(sk_lwe_in.n() <= module.n()); - assert!(sk_lwe_out.n() <= module.n()); - } - - let (mut sk_in_glwe, scratch1) = scratch.take_glwe_secret(module, 1); - let (mut sk_out_glwe, scratch2) = scratch1.take_glwe_secret(module, 1); - - sk_out_glwe.data.at_mut(0, 0)[..sk_lwe_out.n()].copy_from_slice(sk_lwe_out.data.at(0, 0)); - sk_out_glwe.data.at_mut(0, 0)[sk_lwe_out.n()..].fill(0); - module.vec_znx_automorphism_inplace(-1, &mut sk_out_glwe.data.as_vec_znx_mut(), 0); - - sk_in_glwe.data.at_mut(0, 0)[..sk_lwe_in.n()].copy_from_slice(sk_lwe_in.data.at(0, 0)); - sk_in_glwe.data.at_mut(0, 0)[sk_lwe_in.n()..].fill(0); - module.vec_znx_automorphism_inplace(-1, &mut sk_in_glwe.data.as_vec_znx_mut(), 0); - - self.0.encrypt_sk( - module, - &sk_in_glwe, - &sk_out_glwe, - source_xa, - source_xe, - sigma, - scratch2, - ); - } -} - -impl LWECiphertext> { - pub fn from_glwe_scratch_space( - module: &Module, - basek: usize, - k_lwe: usize, - k_glwe: usize, - k_ksk: usize, - rank: usize, - ) -> usize - where - Module: GLWEKeyswitchFamily + VecZnxAllocBytes, - { - GLWECiphertext::bytes_of(module, basek, k_lwe, 1) - + GLWECiphertext::keyswitch_scratch_space(module, basek, k_lwe, k_glwe, k_ksk, 1, rank, 1) - } - - pub fn keyswitch_scratch_space( - module: &Module, - basek: usize, - k_lwe_out: usize, - k_lwe_in: usize, - k_ksk: usize, - ) -> usize - where - Module: GLWEKeyswitchFamily + ScalarZnxAllocBytes + VecZnxAllocBytes, - { - GLWECiphertext::bytes_of(module, basek, k_lwe_out.max(k_lwe_in), 1) - + GLWECiphertext::keyswitch_inplace_scratch_space(module, basek, k_lwe_out, k_ksk, 1, 1) - } -} - -impl LWECiphertext { - pub fn sample_extract(&mut self, a: &GLWECiphertext) { - #[cfg(debug_assertions)] - { - assert!(self.n() <= a.n()); - } - - let min_size: usize = self.size().min(a.size()); - let n: usize = self.n(); - - self.data.zero(); - (0..min_size).for_each(|i| { - let data_lwe: &mut [i64] = self.data.at_mut(0, i); - data_lwe[0] = a.data.at(0, i)[0]; - data_lwe[1..].copy_from_slice(&a.data.at(1, i)[..n]); - }); - } - - pub fn from_glwe( - &mut self, - module: &Module, - a: &GLWECiphertext, - ks: &GLWEToLWESwitchingKeyExec, - scratch: &mut Scratch, - ) where - DGlwe: DataRef, - DKs: DataRef, - Module: GLWEKeyswitchFamily, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, - { - #[cfg(debug_assertions)] - { - assert_eq!(self.basek(), a.basek()); - } - let (mut tmp_glwe, scratch1) = scratch.take_glwe_ct(module, a.basek(), self.k(), 1); - tmp_glwe.keyswitch(module, a, &ks.0, scratch1); - self.sample_extract(&tmp_glwe); - } - - pub fn keyswitch( - &mut self, - module: &Module, - a: &LWECiphertext, - ksk: &LWESwitchingKeyExec, - scratch: &mut Scratch, - ) where - A: DataRef, - DKs: DataRef, - Module: GLWEKeyswitchFamily, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, - { - #[cfg(debug_assertions)] - { - assert!(self.n() <= module.n()); - assert!(a.n() <= module.n()); - assert_eq!(self.basek(), a.basek()); - } - - let max_k: usize = self.k().max(a.k()); - let basek: usize = self.basek(); - - let (mut glwe, scratch1) = scratch.take_glwe_ct(&module, basek, max_k, 1); - glwe.data.zero(); - - let n_lwe: usize = a.n(); - - (0..a.size()).for_each(|i| { - let data_lwe: &[i64] = a.data.at(0, i); - glwe.data.at_mut(0, i)[0] = data_lwe[0]; - glwe.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]); - }); - - glwe.keyswitch_inplace(module, &ksk.0, scratch1); - - self.sample_extract(&glwe); - } -} - -impl GLWECiphertext> { - pub fn from_lwe_scratch_space( - module: &Module, - basek: usize, - k_lwe: usize, - k_glwe: usize, - k_ksk: usize, - rank: usize, - ) -> usize - where - Module: GLWEKeyswitchFamily + VecZnxAllocBytes, - { - GLWECiphertext::keyswitch_scratch_space(module, basek, k_glwe, k_lwe, k_ksk, 1, 1, rank) - + GLWECiphertext::bytes_of(module, basek, k_lwe, 1) - } -} - -impl GLWECiphertext { - pub fn from_lwe( - &mut self, - module: &Module, - lwe: &LWECiphertext, - ksk: &LWEToGLWESwitchingKeyExec, - scratch: &mut Scratch, - ) where - DLwe: DataRef, - DKsk: DataRef, - Module: GLWEKeyswitchFamily, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, - { - #[cfg(debug_assertions)] - { - assert!(lwe.n() <= self.n()); - assert_eq!(self.basek(), self.basek()); - } - - let (mut glwe, scratch1) = scratch.take_glwe_ct(module, lwe.basek(), lwe.k(), 1); - glwe.data.zero(); - - let n_lwe: usize = lwe.n(); - - (0..lwe.size()).for_each(|i| { - let data_lwe: &[i64] = lwe.data.at(0, i); - glwe.data.at_mut(0, i)[0] = data_lwe[0]; - glwe.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]); - }); - - self.keyswitch(module, &glwe, &ksk.0, scratch1); - } -} diff --git a/core/src/lwe/keyswtich_layouts.rs b/core/src/lwe/keyswtich_layouts.rs new file mode 100644 index 0000000..1348b87 --- /dev/null +++ b/core/src/lwe/keyswtich_layouts.rs @@ -0,0 +1,358 @@ +use backend::hal::{ + api::{ + ScratchAvailable, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, + VecZnxSwithcDegree, ZnxView, ZnxViewMut, ZnxZero, + }, + layouts::{Backend, Data, DataMut, DataRef, MatZnx, Module, ReaderFrom, Scratch, WriterTo}, +}; +use sampling::source::Source; + +use crate::{ + GGLWEEncryptSkFamily, GLWECiphertext, GLWEKeyswitchFamily, GLWESecret, GLWESecretExec, GLWESwitchingKey, Infos, + LWECiphertext, LWESecret, TakeGLWESecret, TakeGLWESecretExec, +}; + +/// A special [GLWESwitchingKey] required to for the conversion from [GLWECiphertext] to [LWECiphertext]. +#[derive(PartialEq, Eq)] +pub struct GLWEToLWESwitchingKey(pub(crate) GLWESwitchingKey); + +impl Infos for GLWEToLWESwitchingKey { + type Inner = MatZnx; + + fn inner(&self) -> &Self::Inner { + &self.0.inner() + } + + fn basek(&self) -> usize { + self.0.basek() + } + + fn k(&self) -> usize { + self.0.k() + } +} + +impl GLWEToLWESwitchingKey { + pub fn digits(&self) -> usize { + self.0.digits() + } + + pub fn rank(&self) -> usize { + self.0.rank() + } + + pub fn rank_in(&self) -> usize { + self.0.rank_in() + } + + pub fn rank_out(&self) -> usize { + self.0.rank_out() + } +} + +impl ReaderFrom for GLWEToLWESwitchingKey { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + self.0.read_from(reader) + } +} + +impl WriterTo for GLWEToLWESwitchingKey { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + self.0.write_to(writer) + } +} + +impl GLWEToLWESwitchingKey> { + pub fn alloc(n: usize, basek: usize, k: usize, rows: usize, rank_in: usize) -> Self { + Self(GLWESwitchingKey::alloc(n, basek, k, rows, 1, rank_in, 1)) + } + + pub fn encrypt_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize, rank_in: usize) -> usize + where + Module: GGLWEEncryptSkFamily, + { + GLWESecretExec::bytes_of(module, n, rank_in) + + (GLWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, k, rank_in, 1) | GLWESecret::bytes_of(n, rank_in)) + } +} + +impl GLWEToLWESwitchingKey { + pub fn encrypt_sk( + &mut self, + module: &Module, + sk_lwe: &LWESecret, + sk_glwe: &GLWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + DLwe: DataRef, + DGlwe: DataRef, + Module: GGLWEEncryptSkFamily + VecZnxAutomorphismInplace + VecZnxSwithcDegree + VecZnxAddScalarInplace, + Scratch: ScratchAvailable + TakeScalarZnx + TakeVecZnxDft + TakeGLWESecretExec + TakeVecZnx, + { + #[cfg(debug_assertions)] + { + assert!(sk_lwe.n() <= module.n()); + } + + let (mut sk_lwe_as_glwe, scratch1) = scratch.take_glwe_secret(sk_glwe.n(), 1); + sk_lwe_as_glwe.data.zero(); + sk_lwe_as_glwe.data.at_mut(0, 0)[..sk_lwe.n()].copy_from_slice(sk_lwe.data.at(0, 0)); + module.vec_znx_automorphism_inplace(-1, &mut sk_lwe_as_glwe.data.as_vec_znx_mut(), 0); + + self.0.encrypt_sk( + module, + sk_glwe, + &sk_lwe_as_glwe, + source_xa, + source_xe, + sigma, + scratch1, + ); + } +} + +#[derive(PartialEq, Eq)] +pub struct LWEToGLWESwitchingKey(pub(crate) GLWESwitchingKey); + +impl Infos for LWEToGLWESwitchingKey { + type Inner = MatZnx; + + fn inner(&self) -> &Self::Inner { + &self.0.inner() + } + + fn basek(&self) -> usize { + self.0.basek() + } + + fn k(&self) -> usize { + self.0.k() + } +} + +impl LWEToGLWESwitchingKey { + pub fn digits(&self) -> usize { + self.0.digits() + } + + pub fn rank(&self) -> usize { + self.0.rank() + } + + pub fn rank_in(&self) -> usize { + self.0.rank_in() + } + + pub fn rank_out(&self) -> usize { + self.0.rank_out() + } +} + +impl ReaderFrom for LWEToGLWESwitchingKey { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + self.0.read_from(reader) + } +} + +impl WriterTo for LWEToGLWESwitchingKey { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + self.0.write_to(writer) + } +} + +impl LWEToGLWESwitchingKey> { + pub fn alloc(n: usize, basek: usize, k: usize, rows: usize, rank_out: usize) -> Self { + Self(GLWESwitchingKey::alloc(n, basek, k, rows, 1, 1, rank_out)) + } + + pub fn encrypt_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize, rank_out: usize) -> usize + where + Module: GGLWEEncryptSkFamily, + { + GLWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, k, 1, rank_out) + GLWESecret::bytes_of(n, 1) + } +} + +impl LWEToGLWESwitchingKey { + pub fn encrypt_sk( + &mut self, + module: &Module, + sk_lwe: &LWESecret, + sk_glwe: &GLWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + DLwe: DataRef, + DGlwe: DataRef, + Module: GGLWEEncryptSkFamily + VecZnxAutomorphismInplace + VecZnxSwithcDegree + VecZnxAddScalarInplace, + Scratch: ScratchAvailable + TakeScalarZnx + TakeVecZnxDft + TakeGLWESecretExec + TakeVecZnx, + { + #[cfg(debug_assertions)] + { + assert!(sk_lwe.n() <= module.n()); + } + + let (mut sk_lwe_as_glwe, scratch1) = scratch.take_glwe_secret(sk_glwe.n(), 1); + sk_lwe_as_glwe.data.at_mut(0, 0)[..sk_lwe.n()].copy_from_slice(sk_lwe.data.at(0, 0)); + sk_lwe_as_glwe.data.at_mut(0, 0)[sk_lwe.n()..].fill(0); + module.vec_znx_automorphism_inplace(-1, &mut sk_lwe_as_glwe.data.as_vec_znx_mut(), 0); + + self.0.encrypt_sk( + module, + &sk_lwe_as_glwe, + &sk_glwe, + source_xa, + source_xe, + sigma, + scratch1, + ); + } +} + +#[derive(PartialEq, Eq)] +pub struct LWESwitchingKey(pub(crate) GLWESwitchingKey); + +impl Infos for LWESwitchingKey { + type Inner = MatZnx; + + fn inner(&self) -> &Self::Inner { + &self.0.inner() + } + + fn basek(&self) -> usize { + self.0.basek() + } + + fn k(&self) -> usize { + self.0.k() + } +} + +impl LWESwitchingKey { + pub fn digits(&self) -> usize { + self.0.digits() + } + + pub fn rank(&self) -> usize { + self.0.rank() + } + + pub fn rank_in(&self) -> usize { + self.0.rank_in() + } + + pub fn rank_out(&self) -> usize { + self.0.rank_out() + } +} + +impl ReaderFrom for LWESwitchingKey { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + self.0.read_from(reader) + } +} + +impl WriterTo for LWESwitchingKey { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + self.0.write_to(writer) + } +} + +impl LWESwitchingKey> { + pub fn alloc(n: usize, basek: usize, k: usize, rows: usize) -> Self { + Self(GLWESwitchingKey::alloc(n, basek, k, rows, 1, 1, 1)) + } + + pub fn encrypt_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize) -> usize + where + Module: GGLWEEncryptSkFamily, + { + GLWESecret::bytes_of(n, 1) + + GLWESecretExec::bytes_of(module, n, 1) + + GLWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, k, 1, 1) + } +} + +impl LWESwitchingKey { + pub fn encrypt_sk( + &mut self, + module: &Module, + sk_lwe_in: &LWESecret, + sk_lwe_out: &LWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + DIn: DataRef, + DOut: DataRef, + Module: GGLWEEncryptSkFamily + VecZnxAutomorphismInplace + VecZnxSwithcDegree + VecZnxAddScalarInplace, + Scratch: ScratchAvailable + TakeScalarZnx + TakeVecZnxDft + TakeGLWESecretExec + TakeVecZnx, + { + #[cfg(debug_assertions)] + { + assert!(sk_lwe_in.n() <= self.n()); + assert!(sk_lwe_out.n() <= self.n()); + assert!(self.n() <= module.n()); + } + + let (mut sk_in_glwe, scratch1) = scratch.take_glwe_secret(self.n(), 1); + let (mut sk_out_glwe, scratch2) = scratch1.take_glwe_secret(self.n(), 1); + + sk_out_glwe.data.at_mut(0, 0)[..sk_lwe_out.n()].copy_from_slice(sk_lwe_out.data.at(0, 0)); + sk_out_glwe.data.at_mut(0, 0)[sk_lwe_out.n()..].fill(0); + module.vec_znx_automorphism_inplace(-1, &mut sk_out_glwe.data.as_vec_znx_mut(), 0); + + sk_in_glwe.data.at_mut(0, 0)[..sk_lwe_in.n()].copy_from_slice(sk_lwe_in.data.at(0, 0)); + sk_in_glwe.data.at_mut(0, 0)[sk_lwe_in.n()..].fill(0); + module.vec_znx_automorphism_inplace(-1, &mut sk_in_glwe.data.as_vec_znx_mut(), 0); + + self.0.encrypt_sk( + module, + &sk_in_glwe, + &sk_out_glwe, + source_xa, + source_xe, + sigma, + scratch2, + ); + } +} + +impl LWECiphertext> { + pub fn from_glwe_scratch_space( + module: &Module, + n: usize, + basek: usize, + k_lwe: usize, + k_glwe: usize, + k_ksk: usize, + rank: usize, + ) -> usize + where + Module: GLWEKeyswitchFamily, + { + GLWECiphertext::bytes_of(n, basek, k_lwe, 1) + + GLWECiphertext::keyswitch_scratch_space(module, n, basek, k_lwe, k_glwe, k_ksk, 1, rank, 1) + } + + pub fn keyswitch_scratch_space( + module: &Module, + n: usize, + basek: usize, + k_lwe_out: usize, + k_lwe_in: usize, + k_ksk: usize, + ) -> usize + where + Module: GLWEKeyswitchFamily, + { + GLWECiphertext::bytes_of(n, basek, k_lwe_out.max(k_lwe_in), 1) + + GLWECiphertext::keyswitch_inplace_scratch_space(module, n, basek, k_lwe_out, k_ksk, 1, 1) + } +} diff --git a/core/src/lwe/ciphertext.rs b/core/src/lwe/layouts.rs similarity index 70% rename from core/src/lwe/ciphertext.rs rename to core/src/lwe/layouts.rs index ef7c6d1..2b0c773 100644 --- a/core/src/lwe/ciphertext.rs +++ b/core/src/lwe/layouts.rs @@ -1,20 +1,62 @@ +use std::fmt; + use backend::hal::{ - api::ZnxInfos, + api::{FillUniform, Reset, ZnxInfos}, layouts::{Data, DataMut, DataRef, ReaderFrom, VecZnx, VecZnxToMut, VecZnxToRef, WriterTo}, }; +use sampling::source::Source; use crate::{Infos, SetMetaData}; +#[derive(PartialEq, Eq, Clone)] pub struct LWECiphertext { pub(crate) data: VecZnx, pub(crate) k: usize, pub(crate) basek: usize, } +impl fmt::Debug for LWECiphertext { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self) + } +} + +impl fmt::Display for LWECiphertext { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "LWECiphertext: basek={} k={}: {}", + self.basek(), + self.k(), + self.data + ) + } +} + +impl Reset for LWECiphertext +where + VecZnx: Reset, +{ + fn reset(&mut self) { + self.data.reset(); + self.basek = 0; + self.k = 0; + } +} + +impl FillUniform for LWECiphertext +where + VecZnx: FillUniform, +{ + fn fill_uniform(&mut self, source: &mut Source) { + self.data.fill_uniform(source); + } +} + impl LWECiphertext> { pub fn alloc(n: usize, basek: usize, k: usize) -> Self { Self { - data: VecZnx::alloc::(n + 1, 1, k.div_ceil(basek)), + data: VecZnx::alloc(n + 1, 1, k.div_ceil(basek)), k: k, basek: basek, } diff --git a/core/src/lwe/layouts_compressed.rs b/core/src/lwe/layouts_compressed.rs new file mode 100644 index 0000000..28dd8b1 --- /dev/null +++ b/core/src/lwe/layouts_compressed.rs @@ -0,0 +1,134 @@ +use std::fmt; + +use backend::hal::{ + api::{FillUniform, Reset, VecZnxFillUniform, ZnxInfos, ZnxView, ZnxViewMut}, + layouts::{Backend, Data, DataMut, DataRef, Module, ReaderFrom, VecZnx, WriterTo}, +}; +use sampling::source::Source; + +use crate::{Decompress, Infos, LWECiphertext, SetMetaData}; + +#[derive(PartialEq, Eq, Clone)] +pub struct LWECiphertextCompressed { + pub(crate) data: VecZnx, + pub(crate) k: usize, + pub(crate) basek: usize, + pub(crate) seed: [u8; 32], +} + +impl fmt::Debug for LWECiphertextCompressed { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self) + } +} + +impl fmt::Display for LWECiphertextCompressed { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "LWECiphertextCompressed: basek={} k={} seed={:?}: {}", + self.basek(), + self.k(), + self.seed, + self.data + ) + } +} + +impl Reset for LWECiphertextCompressed +where + VecZnx: Reset, +{ + fn reset(&mut self) { + self.data.reset(); + self.basek = 0; + self.k = 0; + self.seed = [0u8; 32]; + } +} + +impl FillUniform for LWECiphertextCompressed +where + VecZnx: FillUniform, +{ + fn fill_uniform(&mut self, source: &mut Source) { + self.data.fill_uniform(source); + } +} + +impl LWECiphertextCompressed> { + pub fn alloc(basek: usize, k: usize) -> Self { + Self { + data: VecZnx::alloc(1, 1, k.div_ceil(basek)), + k: k, + basek: basek, + seed: [0u8; 32], + } + } +} + +impl Infos for LWECiphertextCompressed +where + VecZnx: ZnxInfos, +{ + type Inner = VecZnx; + + fn n(&self) -> usize { + &self.inner().n() - 1 + } + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.basek + } + + fn k(&self) -> usize { + self.k + } +} + +impl SetMetaData for LWECiphertextCompressed { + fn set_k(&mut self, k: usize) { + self.k = k + } + + fn set_basek(&mut self, basek: usize) { + self.basek = basek + } +} + +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; + +impl ReaderFrom for LWECiphertextCompressed { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + self.k = reader.read_u64::()? as usize; + self.basek = reader.read_u64::()? as usize; + reader.read(&mut self.seed)?; + self.data.read_from(reader) + } +} + +impl WriterTo for LWECiphertextCompressed { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + writer.write_u64::(self.k as u64)?; + writer.write_u64::(self.basek as u64)?; + writer.write_all(&self.seed)?; + self.data.write_to(writer) + } +} + +impl Decompress> for LWECiphertext { + fn decompress(&mut self, module: &Module, other: &LWECiphertextCompressed) + where + Module: VecZnxFillUniform, + { + let mut source = Source::new(other.seed); + module.vec_znx_fill_uniform(other.basek(), &mut self.data, 0, other.k(), &mut source); + (0..self.size()).for_each(|i| { + self.data.at_mut(0, i)[0] = other.data.at(0, i)[0]; + }); + } +} diff --git a/core/src/lwe/mod.rs b/core/src/lwe/mod.rs index 1e3d351..0f91f3c 100644 --- a/core/src/lwe/mod.rs +++ b/core/src/lwe/mod.rs @@ -1,13 +1,19 @@ -pub mod ciphertext; -pub mod decryption; -pub mod encryption; -pub mod keyswtich; -pub mod plaintext; -pub mod secret; +mod conversion; +mod decryption; +mod encryption; +mod keyswitch_layouts_exec; +mod keyswtich_layouts; +mod layouts; +mod layouts_compressed; +mod plaintext; +mod secret; -pub use ciphertext::LWECiphertext; -pub use plaintext::LWEPlaintext; -pub use secret::LWESecret; +pub use keyswitch_layouts_exec::*; +pub use keyswtich_layouts::*; +pub use layouts::*; +pub use layouts_compressed::*; +pub use plaintext::*; +pub use secret::*; #[cfg(test)] -pub mod test_fft64; +pub mod tests; diff --git a/core/src/lwe/plaintext.rs b/core/src/lwe/plaintext.rs index fe1c6d1..d5f1aa2 100644 --- a/core/src/lwe/plaintext.rs +++ b/core/src/lwe/plaintext.rs @@ -11,7 +11,7 @@ pub struct LWEPlaintext { impl LWEPlaintext> { pub fn alloc(basek: usize, k: usize) -> Self { Self { - data: VecZnx::alloc::(1, 1, k.div_ceil(basek)), + data: VecZnx::alloc(1, 1, k.div_ceil(basek)), k: k, basek: basek, } diff --git a/core/src/lwe/test_fft64/mod.rs b/core/src/lwe/test_fft64/mod.rs deleted file mode 100644 index 11eb2fc..0000000 --- a/core/src/lwe/test_fft64/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod conversion; diff --git a/core/src/lwe/tests/cpu_spqlios/fft64.rs b/core/src/lwe/tests/cpu_spqlios/fft64.rs new file mode 100644 index 0000000..0620b5a --- /dev/null +++ b/core/src/lwe/tests/cpu_spqlios/fft64.rs @@ -0,0 +1,27 @@ +use backend::{ + hal::{api::ModuleNew, layouts::Module}, + implementation::cpu_spqlios::FFT64, +}; + +use crate::tests::generic_conversion::{test_glwe_to_lwe, test_keyswitch, test_lwe_to_glwe}; + +#[test] +fn lwe_to_glwe() { + let log_n: usize = 5; + let module: Module = Module::::new(1 << log_n); + test_lwe_to_glwe(&module) +} + +#[test] +fn glwe_to_lwe() { + let log_n: usize = 5; + let module: Module = Module::::new(1 << log_n); + test_glwe_to_lwe(&module) +} + +#[test] +fn keyswitch() { + let log_n: usize = 5; + let module: Module = Module::::new(1 << log_n); + test_keyswitch(&module) +} diff --git a/core/src/lwe/tests/cpu_spqlios/mod.rs b/core/src/lwe/tests/cpu_spqlios/mod.rs new file mode 100644 index 0000000..aebaafb --- /dev/null +++ b/core/src/lwe/tests/cpu_spqlios/mod.rs @@ -0,0 +1 @@ +mod fft64; diff --git a/core/src/lwe/test_fft64/conversion.rs b/core/src/lwe/tests/generic_conversion.rs similarity index 72% rename from core/src/lwe/test_fft64/conversion.rs rename to core/src/lwe/tests/generic_conversion.rs index 3d3a2f1..a54f521 100644 --- a/core/src/lwe/test_fft64/conversion.rs +++ b/core/src/lwe/tests/generic_conversion.rs @@ -1,51 +1,29 @@ -use backend::{ - hal::{ - api::{ - MatZnxAlloc, ModuleNew, ScalarZnxAlloc, ScalarZnxAllocBytes, ScratchOwnedAlloc, ScratchOwnedBorrow, - VecZnxAddScalarInplace, VecZnxAlloc, VecZnxAllocBytes, VecZnxAutomorphismInplace, VecZnxEncodeCoeffsi64, - VecZnxSwithcDegree, ZnxView, - }, - layouts::{Backend, Module, ScratchOwned}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, - }, +use backend::hal::{ + api::{ + ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxEncodeCoeffsi64, + VecZnxSwithcDegree, ZnxView, + }, + layouts::{Backend, Module, ScratchOwned}, + oep::{ + ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, + TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, }, - implementation::cpu_spqlios::FFT64, }; use sampling::source::Source; use crate::{ GGLWEEncryptSkFamily, GGLWEExecLayoutFamily, GLWECiphertext, GLWEDecryptFamily, GLWEKeyswitchFamily, GLWEPlaintext, - GLWESecret, GLWESecretExec, Infos, LWECiphertext, LWESecret, - lwe::{ - LWEPlaintext, - keyswtich::{ - GLWEToLWESwitchingKey, GLWEToLWESwitchingKeyExec, LWESwitchingKey, LWESwitchingKeyExec, LWEToGLWESwitchingKey, - LWEToGLWESwitchingKeyExec, - }, - }, + GLWESecret, GLWESecretExec, GLWEToLWESwitchingKey, GLWEToLWESwitchingKeyExec, Infos, LWECiphertext, LWEPlaintext, LWESecret, + LWESwitchingKey, LWESwitchingKeyExec, LWEToGLWESwitchingKey, LWEToGLWESwitchingKeyExec, }; -#[test] -fn lwe_to_glwe() { - let log_n: usize = 5; - let module: Module = Module::::new(1 << log_n); - test_lwe_to_glwe(&module) -} - pub(crate) trait LWETestModuleFamily = GGLWEEncryptSkFamily + GLWEDecryptFamily + VecZnxSwithcDegree + VecZnxAddScalarInplace - + VecZnxAlloc + GGLWEExecLayoutFamily + GLWEKeyswitchFamily - + ScalarZnxAllocBytes - + VecZnxAllocBytes - + ScalarZnxAlloc + VecZnxEncodeCoeffsi64 - + MatZnxAlloc + VecZnxAutomorphismInplace; pub(crate) trait LWETestScratchFamily = TakeScalarZnxImpl @@ -62,6 +40,7 @@ where Module: LWETestModuleFamily, B: LWETestScratchFamily, { + let n: usize = module.n(); let basek: usize = 17; let sigma: f64 = 3.2; @@ -80,12 +59,12 @@ where let mut source_xe: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - LWEToGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank) - | GLWECiphertext::from_lwe_scratch_space(module, basek, k_lwe_ct, k_glwe_ct, k_ksk, rank) - | GLWECiphertext::decrypt_scratch_space(module, basek, k_glwe_ct), + LWEToGLWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, k_ksk, rank) + | GLWECiphertext::from_lwe_scratch_space(module, n, basek, k_lwe_ct, k_glwe_ct, k_ksk, rank) + | GLWECiphertext::decrypt_scratch_space(module, n, basek, k_glwe_ct), ); - let mut sk_glwe: GLWESecret> = GLWESecret::alloc(module, rank); + let mut sk_glwe: GLWESecret> = GLWESecret::alloc(n, rank); sk_glwe.fill_ternary_prob(0.5, &mut source_xs); let sk_glwe_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_glwe); @@ -108,7 +87,7 @@ where sigma, ); - let mut ksk: LWEToGLWESwitchingKey> = LWEToGLWESwitchingKey::alloc(module, basek, k_ksk, lwe_ct.size(), rank); + let mut ksk: LWEToGLWESwitchingKey> = LWEToGLWESwitchingKey::alloc(n, basek, k_ksk, lwe_ct.size(), rank); ksk.encrypt_sk( module, @@ -120,30 +99,24 @@ where scratch.borrow(), ); - let mut glwe_ct: GLWECiphertext> = GLWECiphertext::alloc(module, basek, k_glwe_ct, rank); + let mut glwe_ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_glwe_ct, rank); let ksk_exec: LWEToGLWESwitchingKeyExec, B> = LWEToGLWESwitchingKeyExec::from(module, &ksk, scratch.borrow()); glwe_ct.from_lwe(module, &lwe_ct, &ksk_exec, scratch.borrow()); - let mut glwe_pt: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k_glwe_ct); + let mut glwe_pt: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_glwe_ct); glwe_ct.decrypt(module, &mut glwe_pt, &sk_glwe_exec, scratch.borrow()); assert_eq!(glwe_pt.data.at(0, 0)[0], lwe_pt.data.at(0, 0)[0]); } -#[test] -fn glwe_to_lwe() { - let log_n: usize = 5; - let module: Module = Module::::new(1 << log_n); - test_glwe_to_lwe(&module) -} - -fn test_glwe_to_lwe(module: &Module) +pub(crate) fn test_glwe_to_lwe(module: &Module) where Module: LWETestModuleFamily, B: LWETestScratchFamily, { + let n: usize = module.n(); let basek: usize = 17; let sigma: f64 = 3.2; @@ -162,12 +135,12 @@ where let mut source_xe: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - LWEToGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank) - | LWECiphertext::from_glwe_scratch_space(module, basek, k_lwe_ct, k_glwe_ct, k_ksk, rank) - | GLWECiphertext::decrypt_scratch_space(module, basek, k_glwe_ct), + LWEToGLWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, k_ksk, rank) + | LWECiphertext::from_glwe_scratch_space(module, n, basek, k_lwe_ct, k_glwe_ct, k_ksk, rank) + | GLWECiphertext::decrypt_scratch_space(module, n, basek, k_glwe_ct), ); - let mut sk_glwe: GLWESecret> = GLWESecret::alloc(module, rank); + let mut sk_glwe: GLWESecret> = GLWESecret::alloc(n, rank); sk_glwe.fill_ternary_prob(0.5, &mut source_xs); let sk_glwe_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_glwe); @@ -176,10 +149,10 @@ where sk_lwe.fill_ternary_prob(0.5, &mut source_xs); let data: i64 = 17; - let mut glwe_pt: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k_glwe_ct); + let mut glwe_pt: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_glwe_ct); module.encode_coeff_i64(basek, &mut glwe_pt.data, 0, k_lwe_pt, 0, data, k_lwe_pt); - let mut glwe_ct = GLWECiphertext::alloc(module, basek, k_glwe_ct, rank); + let mut glwe_ct = GLWECiphertext::alloc(n, basek, k_glwe_ct, rank); glwe_ct.encrypt_sk( module, &glwe_pt, @@ -190,7 +163,7 @@ where scratch.borrow(), ); - let mut ksk: GLWEToLWESwitchingKey> = GLWEToLWESwitchingKey::alloc(module, basek, k_ksk, glwe_ct.size(), rank); + let mut ksk: GLWEToLWESwitchingKey> = GLWEToLWESwitchingKey::alloc(n, basek, k_ksk, glwe_ct.size(), rank); ksk.encrypt_sk( module, @@ -214,18 +187,12 @@ where assert_eq!(glwe_pt.data.at(0, 0)[0], lwe_pt.data.at(0, 0)[0]); } -#[test] -fn keyswitch() { - let log_n: usize = 5; - let module: Module = Module::::new(1 << log_n); - test_keyswitch(&module) -} - -fn test_keyswitch(module: &Module) +pub(crate) fn test_keyswitch(module: &Module) where Module: LWETestModuleFamily, B: LWETestScratchFamily, { + let n: usize = module.n(); let basek: usize = 17; let sigma: f64 = 3.2; @@ -241,8 +208,8 @@ where let mut source_xe: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - LWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_ksk) - | LWECiphertext::keyswitch_scratch_space(module, basek, k_lwe_ct, k_lwe_ct, k_ksk), + LWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, k_ksk) + | LWECiphertext::keyswitch_scratch_space(module, n, basek, k_lwe_ct, k_lwe_ct, k_ksk), ); let mut sk_lwe_in: LWESecret> = LWESecret::alloc(n_lwe_in); @@ -266,7 +233,7 @@ where sigma, ); - let mut ksk: LWESwitchingKey> = LWESwitchingKey::alloc(module, basek, k_ksk, lwe_ct_in.size()); + let mut ksk: LWESwitchingKey> = LWESwitchingKey::alloc(n, basek, k_ksk, lwe_ct_in.size()); ksk.encrypt_sk( module, diff --git a/core/src/lwe/tests/generic_serialization.rs b/core/src/lwe/tests/generic_serialization.rs new file mode 100644 index 0000000..24d2c35 --- /dev/null +++ b/core/src/lwe/tests/generic_serialization.rs @@ -0,0 +1,15 @@ +use backend::hal::tests::serialization::test_reader_writer_interface; + +use crate::{LWECiphertext, LWECiphertextCompressed}; + +#[test] +fn lwe_serialization() { + let original: LWECiphertext> = LWECiphertext::alloc(771, 12, 54); + test_reader_writer_interface(original); +} + +#[test] +fn lwe_serialization_compressed() { + let original: LWECiphertextCompressed> = LWECiphertextCompressed::alloc(12, 54); + test_reader_writer_interface(original); +} diff --git a/core/src/lwe/tests/mod.rs b/core/src/lwe/tests/mod.rs new file mode 100644 index 0000000..465f959 --- /dev/null +++ b/core/src/lwe/tests/mod.rs @@ -0,0 +1,4 @@ +mod generic_conversion; +mod generic_serialization; + +mod cpu_spqlios; diff --git a/core/src/scratch.rs b/core/src/scratch.rs index 43d18e4..12e42c0 100644 --- a/core/src/scratch.rs +++ b/core/src/scratch.rs @@ -1,6 +1,6 @@ use backend::hal::{ api::{TakeMatZnx, TakeScalarZnx, TakeSvpPPol, TakeVecZnx, TakeVecZnxDft, TakeVmpPMat}, - layouts::{Backend, DataRef, Module, Scratch}, + layouts::{Backend, DataRef, Scratch}, oep::{TakeMatZnxImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, TakeVmpPMatImpl}, }; @@ -16,15 +16,14 @@ pub trait TakeLike<'a, B: Backend, T> { } pub trait TakeGLWECt { - fn take_glwe_ct(&mut self, module: &Module, basek: usize, k: usize, rank: usize) - -> (GLWECiphertext<&mut [u8]>, &mut Self); + fn take_glwe_ct(&mut self, n: usize, basek: usize, k: usize, rank: usize) -> (GLWECiphertext<&mut [u8]>, &mut Self); } pub trait TakeGLWECtSlice { fn take_glwe_ct_slice( &mut self, size: usize, - module: &Module, + n: usize, basek: usize, k: usize, rank: usize, @@ -32,13 +31,13 @@ pub trait TakeGLWECtSlice { } pub trait TakeGLWEPt { - fn take_glwe_pt(&mut self, module: &Module, basek: usize, k: usize) -> (GLWEPlaintext<&mut [u8]>, &mut Self); + fn take_glwe_pt(&mut self, n: usize, basek: usize, k: usize) -> (GLWEPlaintext<&mut [u8]>, &mut Self); } pub trait TakeGGLWE { fn take_gglwe( &mut self, - module: &Module, + n: usize, basek: usize, k: usize, rows: usize, @@ -51,7 +50,7 @@ pub trait TakeGGLWE { pub trait TakeGGLWEExec { fn take_gglwe_exec( &mut self, - module: &Module, + n: usize, basek: usize, k: usize, rows: usize, @@ -64,7 +63,7 @@ pub trait TakeGGLWEExec { pub trait TakeGGSW { fn take_ggsw( &mut self, - module: &Module, + n: usize, basek: usize, k: usize, rows: usize, @@ -76,7 +75,7 @@ pub trait TakeGGSW { pub trait TakeGGSWExec { fn take_ggsw_exec( &mut self, - module: &Module, + n: usize, basek: usize, k: usize, rows: usize, @@ -86,21 +85,21 @@ pub trait TakeGGSWExec { } pub trait TakeGLWESecret { - fn take_glwe_secret(&mut self, module: &Module, rank: usize) -> (GLWESecret<&mut [u8]>, &mut Self); + fn take_glwe_secret(&mut self, n: usize, rank: usize) -> (GLWESecret<&mut [u8]>, &mut Self); } pub trait TakeGLWESecretExec { - fn take_glwe_secret_exec(&mut self, module: &Module, rank: usize) -> (GLWESecretExec<&mut [u8], B>, &mut Self); + fn take_glwe_secret_exec(&mut self, n: usize, rank: usize) -> (GLWESecretExec<&mut [u8], B>, &mut Self); } pub trait TakeGLWEPk { - fn take_glwe_pk(&mut self, module: &Module, basek: usize, k: usize, rank: usize) -> (GLWEPublicKey<&mut [u8]>, &mut Self); + fn take_glwe_pk(&mut self, n: usize, basek: usize, k: usize, rank: usize) -> (GLWEPublicKey<&mut [u8]>, &mut Self); } pub trait TakeGLWEPkExec { fn take_glwe_pk_exec( &mut self, - module: &Module, + n: usize, basek: usize, k: usize, rank: usize, @@ -110,7 +109,7 @@ pub trait TakeGLWEPkExec { pub trait TakeGLWESwitchingKey { fn take_glwe_switching_key( &mut self, - module: &Module, + n: usize, basek: usize, k: usize, rows: usize, @@ -123,7 +122,7 @@ pub trait TakeGLWESwitchingKey { pub trait TakeGLWESwitchingKeyExec { fn take_glwe_switching_key_exec( &mut self, - module: &Module, + n: usize, basek: usize, k: usize, rows: usize, @@ -136,7 +135,7 @@ pub trait TakeGLWESwitchingKeyExec { pub trait TakeTensorKey { fn take_tensor_key( &mut self, - module: &Module, + n: usize, basek: usize, k: usize, rows: usize, @@ -148,7 +147,7 @@ pub trait TakeTensorKey { pub trait TakeTensorKeyExec { fn take_tensor_key_exec( &mut self, - module: &Module, + n: usize, basek: usize, k: usize, rows: usize, @@ -160,7 +159,7 @@ pub trait TakeTensorKeyExec { pub trait TakeAutomorphismKey { fn take_automorphism_key( &mut self, - module: &Module, + n: usize, basek: usize, k: usize, rows: usize, @@ -172,7 +171,7 @@ pub trait TakeAutomorphismKey { pub trait TakeAutomorphismKeyExec { fn take_automorphism_key_exec( &mut self, - module: &Module, + n: usize, basek: usize, k: usize, rows: usize, @@ -183,16 +182,10 @@ pub trait TakeAutomorphismKeyExec { impl TakeGLWECt for Scratch where - Scratch: TakeVecZnx, + Scratch: TakeVecZnx, { - fn take_glwe_ct( - &mut self, - module: &Module, - basek: usize, - k: usize, - rank: usize, - ) -> (GLWECiphertext<&mut [u8]>, &mut Self) { - let (data, scratch) = self.take_vec_znx(module, rank + 1, k.div_ceil(basek)); + fn take_glwe_ct(&mut self, n: usize, basek: usize, k: usize, rank: usize) -> (GLWECiphertext<&mut [u8]>, &mut Self) { + let (data, scratch) = self.take_vec_znx(n, rank + 1, k.div_ceil(basek)); (GLWECiphertext { data, basek, k }, scratch) } } @@ -219,12 +212,12 @@ where impl TakeGLWECtSlice for Scratch where - Scratch: TakeVecZnx, + Scratch: TakeVecZnx, { fn take_glwe_ct_slice( &mut self, size: usize, - module: &Module, + n: usize, basek: usize, k: usize, rank: usize, @@ -232,7 +225,7 @@ where let mut scratch: &mut Scratch = self; let mut cts: Vec> = Vec::with_capacity(size); for _ in 0..size { - let (ct, new_scratch) = scratch.take_glwe_ct(module, basek, k, rank); + let (ct, new_scratch) = scratch.take_glwe_ct(n, basek, k, rank); scratch = new_scratch; cts.push(ct); } @@ -242,10 +235,10 @@ where impl TakeGLWEPt for Scratch where - Scratch: TakeVecZnx, + Scratch: TakeVecZnx, { - fn take_glwe_pt(&mut self, module: &Module, basek: usize, k: usize) -> (GLWEPlaintext<&mut [u8]>, &mut Self) { - let (data, scratch) = self.take_vec_znx(module, 1, k.div_ceil(basek)); + fn take_glwe_pt(&mut self, n: usize, basek: usize, k: usize) -> (GLWEPlaintext<&mut [u8]>, &mut Self) { + let (data, scratch) = self.take_vec_znx(n, 1, k.div_ceil(basek)); (GLWEPlaintext { data, basek, k }, scratch) } } @@ -272,11 +265,11 @@ where impl TakeGGLWE for Scratch where - Scratch: TakeMatZnx, + Scratch: TakeMatZnx, { fn take_gglwe( &mut self, - module: &Module, + n: usize, basek: usize, k: usize, rows: usize, @@ -285,7 +278,7 @@ where rank_out: usize, ) -> (GGLWECiphertext<&mut [u8]>, &mut Self) { let (data, scratch) = self.take_mat_znx( - module, + n, rows.div_ceil(digits), rank_in, rank_out + 1, @@ -337,7 +330,7 @@ where { fn take_gglwe_exec( &mut self, - module: &Module, + n: usize, basek: usize, k: usize, rows: usize, @@ -346,7 +339,7 @@ where rank_out: usize, ) -> (GGLWECiphertextExec<&mut [u8], B>, &mut Self) { let (data, scratch) = self.take_vmp_pmat( - module, + n, rows.div_ceil(digits), rank_in, rank_out + 1, @@ -394,11 +387,11 @@ where impl TakeGGSW for Scratch where - Scratch: TakeMatZnx, + Scratch: TakeMatZnx, { fn take_ggsw( &mut self, - module: &Module, + n: usize, basek: usize, k: usize, rows: usize, @@ -406,7 +399,7 @@ where rank: usize, ) -> (GGSWCiphertext<&mut [u8]>, &mut Self) { let (data, scratch) = self.take_mat_znx( - module, + n, rows.div_ceil(digits), rank + 1, rank + 1, @@ -458,7 +451,7 @@ where { fn take_ggsw_exec( &mut self, - module: &Module, + n: usize, basek: usize, k: usize, rows: usize, @@ -466,7 +459,7 @@ where rank: usize, ) -> (GGSWCiphertextExec<&mut [u8], B>, &mut Self) { let (data, scratch) = self.take_vmp_pmat( - module, + n, rows.div_ceil(digits), rank + 1, rank + 1, @@ -514,10 +507,10 @@ where impl TakeGLWEPk for Scratch where - Scratch: TakeVecZnx, + Scratch: TakeVecZnx, { - fn take_glwe_pk(&mut self, module: &Module, basek: usize, k: usize, rank: usize) -> (GLWEPublicKey<&mut [u8]>, &mut Self) { - let (data, scratch) = self.take_vec_znx(module, rank + 1, k.div_ceil(basek)); + fn take_glwe_pk(&mut self, n: usize, basek: usize, k: usize, rank: usize) -> (GLWEPublicKey<&mut [u8]>, &mut Self) { + let (data, scratch) = self.take_vec_znx(n, rank + 1, k.div_ceil(basek)); ( GLWEPublicKey { data, @@ -557,12 +550,12 @@ where { fn take_glwe_pk_exec( &mut self, - module: &Module, + n: usize, basek: usize, k: usize, rank: usize, ) -> (GLWEPublicKeyExec<&mut [u8], B>, &mut Self) { - let (data, scratch) = self.take_vec_znx_dft(module, rank + 1, k.div_ceil(basek)); + let (data, scratch) = self.take_vec_znx_dft(n, rank + 1, k.div_ceil(basek)); ( GLWEPublicKeyExec { data, @@ -598,10 +591,10 @@ where impl TakeGLWESecret for Scratch where - Scratch: TakeScalarZnx, + Scratch: TakeScalarZnx, { - fn take_glwe_secret(&mut self, module: &Module, rank: usize) -> (GLWESecret<&mut [u8]>, &mut Self) { - let (data, scratch) = self.take_scalar_znx(module, rank); + fn take_glwe_secret(&mut self, n: usize, rank: usize) -> (GLWESecret<&mut [u8]>, &mut Self) { + let (data, scratch) = self.take_scalar_znx(n, rank); ( GLWESecret { data, @@ -635,8 +628,8 @@ impl TakeGLWESecretExec for Scratch where Scratch: TakeSvpPPol, { - fn take_glwe_secret_exec(&mut self, module: &Module, rank: usize) -> (GLWESecretExec<&mut [u8], B>, &mut Self) { - let (data, scratch) = self.take_svp_ppol(module, rank); + fn take_glwe_secret_exec(&mut self, n: usize, rank: usize) -> (GLWESecretExec<&mut [u8], B>, &mut Self) { + let (data, scratch) = self.take_svp_ppol(n, rank); ( GLWESecretExec { data, @@ -668,11 +661,11 @@ where impl TakeGLWESwitchingKey for Scratch where - Scratch: TakeMatZnx, + Scratch: TakeMatZnx, { fn take_glwe_switching_key( &mut self, - module: &Module, + n: usize, basek: usize, k: usize, rows: usize, @@ -680,7 +673,7 @@ where rank_in: usize, rank_out: usize, ) -> (GLWESwitchingKey<&mut [u8]>, &mut Self) { - let (data, scratch) = self.take_gglwe(module, basek, k, rows, digits, rank_in, rank_out); + let (data, scratch) = self.take_gglwe(n, basek, k, rows, digits, rank_in, rank_out); ( GLWESwitchingKey { key: data, @@ -719,7 +712,7 @@ where { fn take_glwe_switching_key_exec( &mut self, - module: &Module, + n: usize, basek: usize, k: usize, rows: usize, @@ -727,7 +720,7 @@ where rank_in: usize, rank_out: usize, ) -> (GLWESwitchingKeyExec<&mut [u8], B>, &mut Self) { - let (data, scratch) = self.take_gglwe_exec(module, basek, k, rows, digits, rank_in, rank_out); + let (data, scratch) = self.take_gglwe_exec(n, basek, k, rows, digits, rank_in, rank_out); ( GLWESwitchingKeyExec { key: data, @@ -762,18 +755,18 @@ where impl TakeAutomorphismKey for Scratch where - Scratch: TakeMatZnx, + Scratch: TakeMatZnx, { fn take_automorphism_key( &mut self, - module: &Module, + n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize, ) -> (AutomorphismKey<&mut [u8]>, &mut Self) { - let (data, scratch) = self.take_glwe_switching_key(module, basek, k, rows, digits, rank, rank); + let (data, scratch) = self.take_glwe_switching_key(n, basek, k, rows, digits, rank, rank); (AutomorphismKey { key: data, p: 0 }, scratch) } } @@ -798,14 +791,14 @@ where { fn take_automorphism_key_exec( &mut self, - module: &Module, + n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize, ) -> (AutomorphismKeyExec<&mut [u8], B>, &mut Self) { - let (data, scratch) = self.take_glwe_switching_key_exec(module, basek, k, rows, digits, rank, rank); + let (data, scratch) = self.take_glwe_switching_key_exec(n, basek, k, rows, digits, rank, rank); (AutomorphismKeyExec { key: data, p: 0 }, scratch) } } @@ -826,11 +819,11 @@ where impl TakeTensorKey for Scratch where - Scratch: TakeMatZnx, + Scratch: TakeMatZnx, { fn take_tensor_key( &mut self, - module: &Module, + n: usize, basek: usize, k: usize, rows: usize, @@ -843,12 +836,12 @@ where let mut scratch: &mut Scratch = self; if pairs != 0 { - let (gglwe, s) = scratch.take_glwe_switching_key(module, basek, k, rows, digits, 1, rank); + let (gglwe, s) = scratch.take_glwe_switching_key(n, basek, k, rows, digits, 1, rank); scratch = s; keys.push(gglwe); } for _ in 1..pairs { - let (gglwe, s) = scratch.take_glwe_switching_key(module, basek, k, rows, digits, 1, rank); + let (gglwe, s) = scratch.take_glwe_switching_key(n, basek, k, rows, digits, 1, rank); scratch = s; keys.push(gglwe); } @@ -891,7 +884,7 @@ where { fn take_tensor_key_exec( &mut self, - module: &Module, + n: usize, basek: usize, k: usize, rows: usize, @@ -904,12 +897,12 @@ where let mut scratch: &mut Scratch = self; if pairs != 0 { - let (gglwe, s) = scratch.take_glwe_switching_key_exec(module, basek, k, rows, digits, 1, rank); + let (gglwe, s) = scratch.take_glwe_switching_key_exec(n, basek, k, rows, digits, 1, rank); scratch = s; keys.push(gglwe); } for _ in 1..pairs { - let (gglwe, s) = scratch.take_glwe_switching_key_exec(module, basek, k, rows, digits, 1, rank); + let (gglwe, s) = scratch.take_glwe_switching_key_exec(n, basek, k, rows, digits, 1, rank); scratch = s; keys.push(gglwe); }