diff --git a/poulpy-core/README.md b/poulpy-core/README.md index 07d5304..259988e 100644 --- a/poulpy-core/README.md +++ b/poulpy-core/README.md @@ -52,8 +52,8 @@ fn main() { // Scratch space let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::encrypt_sk_scratch_space(&module, n, base2k, ct.k()) - | GLWECiphertext::decrypt_scratch_space(&module, n, base2k, ct.k()), + GLWECiphertext::encrypt_sk_tmp_bytes(&module, n, base2k, ct.k()) + | GLWECiphertext::decrypt_tmp_bytes(&module, n, base2k, ct.k()), ); // Generate secret-key diff --git a/poulpy-core/benches/external_product_glwe_fft64.rs b/poulpy-core/benches/external_product_glwe_fft64.rs index 333d68f..4af5d1f 100644 --- a/poulpy-core/benches/external_product_glwe_fft64.rs +++ b/poulpy-core/benches/external_product_glwe_fft64.rs @@ -67,9 +67,9 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { let pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n.into(), 1); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSW::encrypt_sk_scratch_space(&module, &ggsw_layout) - | GLWE::encrypt_sk_scratch_space(&module, &glwe_in_layout) - | GLWE::external_product_scratch_space(&module, &glwe_out_layout, &glwe_in_layout, &ggsw_layout), + GGSW::encrypt_sk_tmp_bytes(&module, &ggsw_layout) + | GLWE::encrypt_sk_tmp_bytes(&module, &glwe_in_layout) + | GLWE::external_product_tmp_bytes(&module, &glwe_out_layout, &glwe_in_layout, &ggsw_layout), ); let mut source_xs = Source::new([0u8; 32]); @@ -167,9 +167,9 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { let pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n.into(), 1); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSW::encrypt_sk_scratch_space(&module, &ggsw_layout) - | GLWE::encrypt_sk_scratch_space(&module, &glwe_layout) - | GLWE::external_product_inplace_scratch_space(&module, &glwe_layout, &ggsw_layout), + GGSW::encrypt_sk_tmp_bytes(&module, &ggsw_layout) + | GLWE::encrypt_sk_tmp_bytes(&module, &glwe_layout) + | GLWE::external_product_inplace_tmp_bytes(&module, &glwe_layout, &ggsw_layout), ); let mut source_xs: Source = Source::new([0u8; 32]); diff --git a/poulpy-core/benches/keyswitch_glwe_fft64.rs b/poulpy-core/benches/keyswitch_glwe_fft64.rs index a1b5ea2..736b53c 100644 --- a/poulpy-core/benches/keyswitch_glwe_fft64.rs +++ b/poulpy-core/benches/keyswitch_glwe_fft64.rs @@ -67,9 +67,9 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { let mut ct_out: GLWE> = GLWE::alloc_from_infos(&glwe_out_layout); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, &gglwe_atk_layout) - | GLWE::encrypt_sk_scratch_space(&module, &glwe_in_layout) - | GLWE::keyswitch_scratch_space( + GLWESwitchingKey::encrypt_sk_tmp_bytes(&module, &gglwe_atk_layout) + | GLWE::encrypt_sk_tmp_bytes(&module, &glwe_in_layout) + | GLWE::keyswitch_tmp_bytes( &module, &glwe_out_layout, &glwe_in_layout, @@ -178,9 +178,9 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { let mut ct: GLWE> = GLWE::alloc_from_infos(&glwe_layout); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, &gglwe_layout) - | GLWE::encrypt_sk_scratch_space(&module, &glwe_layout) - | GLWE::keyswitch_inplace_scratch_space(&module, &glwe_layout, &gglwe_layout), + GLWESwitchingKey::encrypt_sk_tmp_bytes(&module, &gglwe_layout) + | GLWE::encrypt_sk_tmp_bytes(&module, &glwe_layout) + | GLWE::keyswitch_inplace_tmp_bytes(&module, &glwe_layout, &gglwe_layout), ); let mut source_xs: Source = Source::new([0u8; 32]); diff --git a/poulpy-core/examples/encryption.rs b/poulpy-core/examples/encryption.rs index f59d5a1..d4b17b9 100644 --- a/poulpy-core/examples/encryption.rs +++ b/poulpy-core/examples/encryption.rs @@ -54,7 +54,7 @@ fn main() { // Scratch space let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWE::encrypt_sk_scratch_space(&module, &glwe_ct_infos) | GLWE::decrypt_scratch_space(&module, &glwe_ct_infos), + GLWE::encrypt_sk_tmp_bytes(&module, &glwe_ct_infos) | GLWE::decrypt_tmp_bytes(&module, &glwe_ct_infos), ); // Generate secret-key diff --git a/poulpy-core/src/automorphism/gglwe_atk.rs b/poulpy-core/src/automorphism/gglwe_atk.rs index e822b63..9650aa2 100644 --- a/poulpy-core/src/automorphism/gglwe_atk.rs +++ b/poulpy-core/src/automorphism/gglwe_atk.rs @@ -10,7 +10,7 @@ use poulpy_hal::{ use crate::layouts::{AutomorphismKey, GGLWEInfos, GLWE, prepared::AutomorphismKeyPrepared}; impl AutomorphismKey> { - pub fn automorphism_scratch_space( + pub fn automorphism_tmp_bytes( module: &Module, out_infos: &OUT, in_infos: &IN, @@ -22,7 +22,7 @@ impl AutomorphismKey> { KEY: GGLWEInfos, Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - GLWE::keyswitch_scratch_space( + GLWE::keyswitch_tmp_bytes( module, &out_infos.glwe_layout(), &in_infos.glwe_layout(), @@ -30,13 +30,13 @@ impl AutomorphismKey> { ) } - pub fn automorphism_inplace_scratch_space(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize + pub fn automorphism_inplace_tmp_bytes(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize where OUT: GGLWEInfos, KEY: GGLWEInfos, Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - AutomorphismKey::automorphism_scratch_space(module, out_infos, out_infos, key_infos) + AutomorphismKey::automorphism_tmp_bytes(module, out_infos, out_infos, key_infos) } } diff --git a/poulpy-core/src/automorphism/ggsw_ct.rs b/poulpy-core/src/automorphism/ggsw_ct.rs index e99b3d5..a3cef86 100644 --- a/poulpy-core/src/automorphism/ggsw_ct.rs +++ b/poulpy-core/src/automorphism/ggsw_ct.rs @@ -14,7 +14,7 @@ use crate::layouts::{ }; impl GGSW> { - pub fn automorphism_scratch_space( + pub fn automorphism_tmp_bytes( module: &Module, out_infos: &OUT, in_infos: &IN, @@ -31,17 +31,17 @@ impl GGSW> { { let out_size: usize = out_infos.size(); let ci_dft: usize = module.bytes_of_vec_znx_dft((key_infos.rank_out() + 1).into(), out_size); - let ks_internal: usize = GLWE::keyswitch_scratch_space( + let ks_internal: usize = GLWE::keyswitch_tmp_bytes( module, &out_infos.glwe_layout(), &in_infos.glwe_layout(), key_infos, ); - let expand: usize = GGSW::expand_row_scratch_space(module, out_infos, tsk_infos); + let expand: usize = GGSW::expand_row_tmp_bytes(module, out_infos, tsk_infos); ci_dft + (ks_internal | expand) } - pub fn automorphism_inplace_scratch_space( + pub fn automorphism_inplace_tmp_bytes( module: &Module, out_infos: &OUT, key_infos: &KEY, @@ -54,7 +54,7 @@ impl GGSW> { Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigBytesOf + VecZnxNormalizeTmpBytes + VecZnxBigNormalizeTmpBytes, { - GGSW::automorphism_scratch_space(module, out_infos, out_infos, key_infos, tsk_infos) + GGSW::automorphism_tmp_bytes(module, out_infos, out_infos, key_infos, tsk_infos) } } @@ -115,7 +115,7 @@ impl GGSW { self.rank(), tensor_key.rank_out() ); - assert!(scratch.available() >= GGSW::automorphism_scratch_space(module, self, lhs, auto_key, tensor_key)) + assert!(scratch.available() >= GGSW::automorphism_tmp_bytes(module, self, lhs, auto_key, tensor_key)) }; // Keyswitch the j-th row of the col 0 diff --git a/poulpy-core/src/automorphism/glwe_ct.rs b/poulpy-core/src/automorphism/glwe_ct.rs index 5d30917..0c8b581 100644 --- a/poulpy-core/src/automorphism/glwe_ct.rs +++ b/poulpy-core/src/automorphism/glwe_ct.rs @@ -11,7 +11,7 @@ use poulpy_hal::{ use crate::layouts::{GGLWEInfos, GLWE, GLWEInfos, LWEInfos, prepared::AutomorphismKeyPrepared}; impl GLWE> { - pub fn automorphism_scratch_space( + pub fn automorphism_tmp_bytes( module: &Module, out_infos: &OUT, in_infos: &IN, @@ -23,16 +23,16 @@ impl GLWE> { KEY: GGLWEInfos, Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - Self::keyswitch_scratch_space(module, out_infos, in_infos, key_infos) + Self::keyswitch_tmp_bytes(module, out_infos, in_infos, key_infos) } - pub fn automorphism_inplace_scratch_space(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize + pub fn automorphism_inplace_tmp_bytes(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize where OUT: GLWEInfos, KEY: GGLWEInfos, Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - Self::keyswitch_inplace_scratch_space(module, out_infos, key_infos) + Self::keyswitch_inplace_tmp_bytes(module, out_infos, key_infos) } } diff --git a/poulpy-core/src/conversion/gglwe_to_ggsw.rs b/poulpy-core/src/conversion/gglwe_to_ggsw.rs new file mode 100644 index 0000000..a7b86fa --- /dev/null +++ b/poulpy-core/src/conversion/gglwe_to_ggsw.rs @@ -0,0 +1,279 @@ +use poulpy_hal::{ + api::{ + ModuleN, ScratchAvailable, ScratchTakeBasic, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxCopy, VecZnxDftAddInplace, + VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftCopy, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeTmpBytes, + VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + }, + layouts::{Backend, DataMut, Module, Scratch, VmpPMat, ZnxInfos}, +}; + +use crate::{ + ScratchTakeCore, + layouts::{ + GGLWE, GGLWEInfos, GGLWEToRef, GGSW, GGSWInfos, GGSWToMut, GLWEInfos, LWEInfos, + prepared::{TensorKeyPrepared, TensorKeyPreparedToRef}, + }, + operations::GLWEOperations, +}; + +impl GGLWE> { + pub fn from_gglw_tmp_bytes(module: &M, res_infos: &R, tsk_infos: &A) -> usize + where + M: GGSWFromGGLWE, + R: GGSWInfos, + A: GGLWEInfos, + { + module.ggsw_from_gglwe_tmp_bytes(res_infos, tsk_infos) + } +} + +impl GGSW { + pub fn from_gglwe(&mut self, module: &M, gglwe: &G, tsk: &T, scratch: &mut Scratch) + where + M: GGSWFromGGLWE, + G: GGLWEToRef, + T: TensorKeyPreparedToRef, + Scratch: ScratchTakeCore, + { + module.ggsw_from_gglwe(self, gglwe, tsk, scratch); + } +} + +impl GGSWFromGGLWE for Module where Self: GGSWExpandRows + VecZnxCopy {} + +pub trait GGSWFromGGLWE +where + Self: GGSWExpandRows + VecZnxCopy, +{ + fn ggsw_from_gglwe_tmp_bytes(&self, res_infos: &R, tsk_infos: &A) -> usize + where + R: GGSWInfos, + A: GGLWEInfos, + { + self.ggsw_expand_rows_tmp_bytes(res_infos, tsk_infos) + } + + fn ggsw_from_gglwe(&self, res: &mut R, a: &A, tsk: &T, scratch: &mut Scratch) + where + R: GGSWToMut, + A: GGLWEToRef, + T: TensorKeyPreparedToRef, + Scratch: ScratchTakeCore, + { + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + let a: &GGLWE<&[u8]> = &a.to_ref(); + let tsk: &TensorKeyPrepared<&[u8], BE> = &tsk.to_ref(); + + assert_eq!(res.rank(), a.rank_out()); + assert_eq!(res.dnum(), a.dnum()); + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.n(), self.n() as u32); + assert_eq!(tsk.n(), self.n() as u32); + + for row in 0..res.dnum().into() { + res.at_mut(row, 0).copy(self, &a.at(row, 0)); + } + + self.ggsw_expand_row(res, tsk, scratch); + } +} + +impl GGSWExpandRows for Module where + Self: Sized + + ModuleN + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VecZnxBigBytesOf + + VecZnxNormalizeTmpBytes + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VecZnxBigBytesOf + + VecZnxNormalizeTmpBytes + + VecZnxDftApply + + VecZnxDftCopy + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxDftAddInplace + + VecZnxBigNormalize + + VecZnxIdftApplyTmpA + + VecZnxNormalize +{ +} + +pub(crate) trait GGSWExpandRows +where + Self: Sized + + ModuleN + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VecZnxBigBytesOf + + VecZnxNormalizeTmpBytes + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VecZnxBigBytesOf + + VecZnxNormalizeTmpBytes + + VecZnxDftApply + + VecZnxDftCopy + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxDftAddInplace + + VecZnxBigNormalize + + VecZnxIdftApplyTmpA + + VecZnxNormalize, +{ + fn ggsw_expand_rows_tmp_bytes(&self, res_infos: &R, tsk_infos: &A) -> usize + where + R: GGSWInfos, + A: GGLWEInfos, + { + let tsk_size: usize = tsk_infos.k().div_ceil(tsk_infos.base2k()) as usize; + let size_in: usize = res_infos + .k() + .div_ceil(tsk_infos.base2k()) + .div_ceil(tsk_infos.dsize().into()) as usize; + + let tmp_dft_i: usize = self.bytes_of_vec_znx_dft((tsk_infos.rank_out() + 1).into(), tsk_size); + let tmp_a: usize = self.bytes_of_vec_znx_dft(1, size_in); + let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes( + tsk_size, + size_in, + size_in, + (tsk_infos.rank_in()).into(), // Verify if rank+1 + (tsk_infos.rank_out()).into(), // Verify if rank+1 + tsk_size, + ); + let tmp_idft: usize = self.bytes_of_vec_znx_big(1, tsk_size); + let norm: usize = self.vec_znx_normalize_tmp_bytes(); + + tmp_dft_i + ((tmp_a + vmp) | (tmp_idft + norm)) + } + + fn ggsw_expand_row(&self, res: &mut R, tsk: &T, scratch: &mut Scratch) + where + R: GGSWToMut, + T: TensorKeyPreparedToRef, + Scratch: ScratchTakeCore, + { + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + let tsk: &TensorKeyPrepared<&[u8], BE> = &tsk.to_ref(); + + let basek_in: usize = res.base2k().into(); + let basek_tsk: usize = tsk.base2k().into(); + + assert!(scratch.available() >= self.ggsw_expand_rows_tmp_bytes(res, tsk)); + + let rank: usize = res.rank().into(); + let cols: usize = rank + 1; + + let a_size: usize = (res.size() * basek_in).div_ceil(basek_tsk); + + // Keyswitch the j-th row of the col 0 + for row_i in 0..res.dnum().into() { + let a = &res.at(row_i, 0).data; + + // Pre-compute DFT of (a0, a1, a2) + let (mut ci_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, a_size); + + if basek_in == basek_tsk { + for i in 0..cols { + self.vec_znx_dft_apply(1, 0, &mut ci_dft, i, a, i); + } + } else { + let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(self, 1, a_size); + for i in 0..cols { + self.vec_znx_normalize(basek_tsk, &mut a_conv, 0, basek_in, a, i, scratch_2); + self.vec_znx_dft_apply(1, 0, &mut ci_dft, i, &a_conv, 0); + } + } + + for col_j in 1..cols { + // Example for rank 3: + // + // Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is + // actually composed of that many dnum and we focus on a specific row here + // implicitely given ci_dft. + // + // # Input + // + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) + // col 1: (0, 0, 0, 0) + // col 2: (0, 0, 0, 0) + // col 3: (0, 0, 0, 0) + // + // # Output + // + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) + // col 1: (-(b0s0 + b1s1 + b2s2) , b0 + M[i], b1 , b2 ) + // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M[i], c2 ) + // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M[i]) + + let dsize: usize = tsk.dsize().into(); + + let (mut tmp_dft_i, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, tsk.size()); + let (mut tmp_a, scratch_3) = scratch_2.take_vec_znx_dft(self, 1, ci_dft.size().div_ceil(dsize)); + + { + // Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2 + // + // # Example for col=1 + // + // a0 * (-(f0s0 + f1s1 + f1s2) + s0^2, f0, f1, f2) = (-(a0f0s0 + a0f1s1 + a0f1s2) + a0s0^2, a0f0, a0f1, a0f2) + // + + // a1 * (-(g0s0 + g1s1 + g1s2) + s0s1, g0, g1, g2) = (-(a1g0s0 + a1g1s1 + a1g1s2) + a1s0s1, a1g0, a1g1, a1g2) + // + + // a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2) + // = + // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2) + for col_i in 1..cols { + let pmat: &VmpPMat<&[u8], BE> = &tsk.at(col_i - 1, col_j - 1).key.data; // Selects Enc(s[i]s[j]) + + // Extracts a[i] and multipies with Enc(s[i]s[j]) + for di in 0..dsize { + tmp_a.set_size((ci_dft.size() + di) / dsize); + + // Small optimization for dsize > 2 + // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then + // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}. + // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. + // It is possible to further ignore the last dsize-1 limbs, but this introduce + // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same + // noise is kept with respect to the ideal functionality. + tmp_dft_i.set_size(tsk.size() - ((dsize - di) as isize - 2).max(0) as usize); + + self.vec_znx_dft_copy(dsize, dsize - 1 - di, &mut tmp_a, 0, &ci_dft, col_i); + if di == 0 && col_i == 1 { + self.vmp_apply_dft_to_dft(&mut tmp_dft_i, &tmp_a, pmat, scratch_3); + } else { + self.vmp_apply_dft_to_dft_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch_3); + } + } + } + } + + // Adds -(sum a[i] * s[i]) + m) on the i-th column of tmp_idft_i + // + // (-(x0s0 + x1s1 + x2s2) + a0s0s0 + a1s0s1 + a2s0s2, x0, x1, x2) + // + + // (0, -(a0s0 + a1s1 + a2s2) + M[i], 0, 0) + // = + // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0 -(a0s0 + a1s1 + a2s2) + M[i], x1, x2) + // = + // (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2) + self.vec_znx_dft_add_inplace(&mut tmp_dft_i, col_j, &ci_dft, 0); + let (mut tmp_idft, scratch_3) = scratch_2.take_vec_znx_big(self, 1, tsk.size()); + for i in 0..cols { + self.vec_znx_idft_apply_tmpa(&mut tmp_idft, 0, &mut tmp_dft_i, i); + self.vec_znx_big_normalize( + basek_in, + &mut res.at_mut(row_i, col_j).data, + i, + basek_tsk, + &tmp_idft, + 0, + scratch_3, + ); + } + } + } + } +} diff --git a/poulpy-core/src/conversion/glwe_to_lwe.rs b/poulpy-core/src/conversion/glwe_to_lwe.rs index c023a43..b6c6ed1 100644 --- a/poulpy-core/src/conversion/glwe_to_lwe.rs +++ b/poulpy-core/src/conversion/glwe_to_lwe.rs @@ -10,7 +10,7 @@ use poulpy_hal::{ use crate::layouts::{GGLWEInfos, GLWE, GLWEInfos, GLWELayout, LWE, LWEInfos, Rank, prepared::GLWEToLWESwitchingKeyPrepared}; impl LWE> { - pub fn from_glwe_scratch_space( + pub fn from_glwe_tmp_bytes( module: &Module, lwe_infos: &OUT, glwe_infos: &IN, @@ -34,7 +34,7 @@ impl LWE> { lwe_infos.base2k(), lwe_infos.k(), 1u32.into(), - ) + GLWE::keyswitch_scratch_space(module, &glwe_layout, glwe_infos, key_infos) + ) + GLWE::keyswitch_tmp_bytes(module, &glwe_layout, glwe_infos, key_infos) } } diff --git a/poulpy-core/src/conversion/lwe_to_glwe.rs b/poulpy-core/src/conversion/lwe_to_glwe.rs index e3ba833..c4a3b88 100644 --- a/poulpy-core/src/conversion/lwe_to_glwe.rs +++ b/poulpy-core/src/conversion/lwe_to_glwe.rs @@ -10,7 +10,7 @@ use poulpy_hal::{ use crate::layouts::{GGLWEInfos, GLWE, GLWEInfos, GLWELayout, LWE, LWEInfos, prepared::LWEToGLWESwitchingKeyPrepared}; impl GLWE> { - pub fn from_lwe_scratch_space( + pub fn from_lwe_tmp_bytes( module: &Module, glwe_infos: &OUT, lwe_infos: &IN, @@ -28,7 +28,7 @@ impl GLWE> { lwe_infos.k().max(glwe_infos.k()), 1u32.into(), ); - let ks: usize = GLWE::keyswitch_inplace_scratch_space(module, glwe_infos, key_infos); + let ks: usize = GLWE::keyswitch_inplace_tmp_bytes(module, glwe_infos, key_infos); if lwe_infos.base2k() == key_infos.base2k() { ct + ks } else { diff --git a/poulpy-core/src/conversion/mod.rs b/poulpy-core/src/conversion/mod.rs index 090208b..9771531 100644 --- a/poulpy-core/src/conversion/mod.rs +++ b/poulpy-core/src/conversion/mod.rs @@ -1,2 +1,5 @@ +mod gglwe_to_ggsw; mod glwe_to_lwe; mod lwe_to_glwe; + +pub use gglwe_to_ggsw::*; diff --git a/poulpy-core/src/decryption/glwe_ct.rs b/poulpy-core/src/decryption/glwe_ct.rs index fdc040a..4306d33 100644 --- a/poulpy-core/src/decryption/glwe_ct.rs +++ b/poulpy-core/src/decryption/glwe_ct.rs @@ -9,7 +9,7 @@ use poulpy_hal::{ use crate::layouts::{GLWE, GLWEInfos, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared}; impl GLWE> { - pub fn decrypt_scratch_space(module: &Module, infos: &A) -> usize + pub fn decrypt_tmp_bytes(module: &Module, infos: &A) -> usize where A: GLWEInfos, Module: VecZnxDftBytesOf + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf, diff --git a/poulpy-core/src/encryption/compressed/gglwe_atk.rs b/poulpy-core/src/encryption/compressed/gglwe_atk.rs index 17cd17f..f0afcae 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_atk.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_atk.rs @@ -13,13 +13,13 @@ use crate::{ }; impl AutomorphismKeyCompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize + pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGLWEInfos, Module: VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes + SvpPPolBytesOf, { assert_eq!(module.n() as u32, infos.n()); - GLWESwitchingKeyCompressed::encrypt_sk_scratch_space(module, infos) + GLWESecret::bytes_of(infos.n(), infos.rank_out()) + GLWESwitchingKeyCompressed::encrypt_sk_tmp_bytes(module, infos) + GLWESecret::bytes_of(infos.n(), infos.rank_out()) } } @@ -63,10 +63,10 @@ where assert_eq!(res.rank_out(), res.rank_in()); assert_eq!(sk.rank(), res.rank_out()); assert!( - scratch.available() >= AutomorphismKeyCompressed::encrypt_sk_scratch_space(self, res), - "scratch.available(): {} < AutomorphismKey::encrypt_sk_scratch_space: {}", + scratch.available() >= AutomorphismKeyCompressed::encrypt_sk_tmp_bytes(self, res), + "scratch.available(): {} < AutomorphismKey::encrypt_sk_tmp_bytes: {}", scratch.available(), - AutomorphismKeyCompressed::encrypt_sk_scratch_space(self, res) + AutomorphismKeyCompressed::encrypt_sk_tmp_bytes(self, res) ) } diff --git a/poulpy-core/src/encryption/compressed/gglwe_ct.rs b/poulpy-core/src/encryption/compressed/gglwe_ct.rs index 7757e95..b67dc88 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_ct.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_ct.rs @@ -34,12 +34,12 @@ impl GGLWECompressed { } impl GGLWECompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize + pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGLWEInfos, Module: VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes, { - GGLWE::encrypt_sk_scratch_space(module, infos) + GGLWE::encrypt_sk_tmp_bytes(module, infos) } } @@ -106,10 +106,10 @@ where assert_eq!(res.n(), sk.n()); assert_eq!(pt.n() as u32, sk.n()); assert!( - scratch.available() >= GGLWECompressed::encrypt_sk_scratch_space(self, res), - "scratch.available: {} < GGLWECiphertext::encrypt_sk_scratch_space: {}", + scratch.available() >= GGLWECompressed::encrypt_sk_tmp_bytes(self, res), + "scratch.available: {} < GGLWECiphertext::encrypt_sk_tmp_bytes: {}", scratch.available(), - GGLWECompressed::encrypt_sk_scratch_space(self, res) + GGLWECompressed::encrypt_sk_tmp_bytes(self, res) ); assert!( res.dnum().0 * res.dsize().0 * res.base2k().0 <= res.k().0, diff --git a/poulpy-core/src/encryption/compressed/gglwe_ksk.rs b/poulpy-core/src/encryption/compressed/gglwe_ksk.rs index 2c0266a..5c2c8c2 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_ksk.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_ksk.rs @@ -14,12 +14,12 @@ use crate::{ }; impl GLWESwitchingKeyCompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize + pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGLWEInfos, Module: VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes + SvpPPolBytesOf, { - (GGLWE::encrypt_sk_scratch_space(module, infos) | ScalarZnx::bytes_of(module.n(), 1)) + (GGLWE::encrypt_sk_tmp_bytes(module, infos) | ScalarZnx::bytes_of(module.n(), 1)) + ScalarZnx::bytes_of(module.n(), infos.rank_in().into()) + GLWESecretPrepared::bytes_of(module, infos.rank_out()) } @@ -91,10 +91,10 @@ where assert!(sk_in.n().0 <= self.n() as u32); assert!(sk_out.n().0 <= self.n() as u32); assert!( - scratch.available() >= GLWESwitchingKey::encrypt_sk_scratch_space(self, res), - "scratch.available()={} < GLWESwitchingKey::encrypt_sk_scratch_space={}", + scratch.available() >= GLWESwitchingKey::encrypt_sk_tmp_bytes(self, res), + "scratch.available()={} < GLWESwitchingKey::encrypt_sk_tmp_bytes={}", scratch.available(), - GLWESwitchingKey::encrypt_sk_scratch_space(self, res) + GLWESwitchingKey::encrypt_sk_tmp_bytes(self, res) ) } diff --git a/poulpy-core/src/encryption/compressed/gglwe_tsk.rs b/poulpy-core/src/encryption/compressed/gglwe_tsk.rs index 6115fdd..2beaa4b 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_tsk.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_tsk.rs @@ -16,12 +16,12 @@ use crate::{ }; impl TensorKeyCompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize + pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGLWEInfos, Module: SvpPPolBytesOf + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes + VecZnxBigBytesOf, { - TensorKey::encrypt_sk_scratch_space(module, infos) + TensorKey::encrypt_sk_tmp_bytes(module, infos) } } diff --git a/poulpy-core/src/encryption/compressed/ggsw_ct.rs b/poulpy-core/src/encryption/compressed/ggsw_ct.rs index 8a8949b..567f04f 100644 --- a/poulpy-core/src/encryption/compressed/ggsw_ct.rs +++ b/poulpy-core/src/encryption/compressed/ggsw_ct.rs @@ -14,12 +14,12 @@ use crate::{ }; impl GGSWCompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize + pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGSWInfos, Module: VecZnxNormalizeTmpBytes + VecZnxDftBytesOf, { - GGSW::encrypt_sk_scratch_space(module, infos) + GGSW::encrypt_sk_tmp_bytes(module, infos) } } diff --git a/poulpy-core/src/encryption/compressed/glwe_ct.rs b/poulpy-core/src/encryption/compressed/glwe_ct.rs index 1c8efe6..f04a07a 100644 --- a/poulpy-core/src/encryption/compressed/glwe_ct.rs +++ b/poulpy-core/src/encryption/compressed/glwe_ct.rs @@ -14,12 +14,12 @@ use crate::{ }; impl GLWECompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize + pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GLWEInfos, Module: VecZnxNormalizeTmpBytes + VecZnxDftBytesOf, { - GLWE::encrypt_sk_scratch_space(module, infos) + GLWE::encrypt_sk_tmp_bytes(module, infos) } } diff --git a/poulpy-core/src/encryption/gglwe_atk.rs b/poulpy-core/src/encryption/gglwe_atk.rs index 8e0ecb7..6536c7e 100644 --- a/poulpy-core/src/encryption/gglwe_atk.rs +++ b/poulpy-core/src/encryption/gglwe_atk.rs @@ -14,7 +14,7 @@ use crate::layouts::{ }; impl AutomorphismKey> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize + pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGLWEInfos, Module: SvpPPolBytesOf + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes, @@ -24,10 +24,10 @@ impl AutomorphismKey> { infos.rank_out(), "rank_in != rank_out is not supported for GGLWEAutomorphismKey" ); - GLWESwitchingKey::encrypt_sk_scratch_space(module, infos) + GLWESecret::bytes_of_from_infos(module, &infos.glwe_layout()) + GLWESwitchingKey::encrypt_sk_tmp_bytes(module, infos) + GLWESecret::bytes_of_from_infos(module, &infos.glwe_layout()) } - pub fn encrypt_pk_scratch_space(module: &Module, _infos: &A) -> usize + pub fn encrypt_pk_tmp_bytes(module: &Module, _infos: &A) -> usize where A: GGLWEInfos, { @@ -36,7 +36,7 @@ impl AutomorphismKey> { _infos.rank_out(), "rank_in != rank_out is not supported for GGLWEAutomorphismKey" ); - GLWESwitchingKey::encrypt_pk_scratch_space(module, _infos) + GLWESwitchingKey::encrypt_pk_tmp_bytes(module, _infos) } } @@ -119,10 +119,10 @@ where assert_eq!(res.rank_out(), res.rank_in()); assert_eq!(sk.rank(), res.rank_out()); assert!( - scratch.available() >= AutomorphismKey::encrypt_sk_scratch_space(self, res), - "scratch.available(): {} < AutomorphismKey::encrypt_sk_scratch_space: {:?}", + scratch.available() >= AutomorphismKey::encrypt_sk_tmp_bytes(self, res), + "scratch.available(): {} < AutomorphismKey::encrypt_sk_tmp_bytes: {:?}", scratch.available(), - AutomorphismKey::encrypt_sk_scratch_space(self, res) + AutomorphismKey::encrypt_sk_tmp_bytes(self, res) ) } diff --git a/poulpy-core/src/encryption/gglwe_ct.rs b/poulpy-core/src/encryption/gglwe_ct.rs index d99ca5f..d333892 100644 --- a/poulpy-core/src/encryption/gglwe_ct.rs +++ b/poulpy-core/src/encryption/gglwe_ct.rs @@ -13,16 +13,16 @@ use crate::{ }; impl GGLWE> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize + pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGLWEInfos, Module: VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes, { - GLWE::encrypt_sk_scratch_space(module, &infos.glwe_layout()) + GLWE::encrypt_sk_tmp_bytes(module, &infos.glwe_layout()) + (GLWEPlaintext::bytes_of_from_infos(module, &infos.glwe_layout()) | module.vec_znx_normalize_tmp_bytes()) } - pub fn encrypt_pk_scratch_space(_module: &Module, _infos: &A) -> usize + pub fn encrypt_pk_tmp_bytes(_module: &Module, _infos: &A) -> usize where A: GGLWEInfos, { @@ -88,12 +88,12 @@ where assert_eq!(res.n(), sk.n()); assert_eq!(pt.n() as u32, sk.n()); assert!( - scratch.available() >= GGLWE::encrypt_sk_scratch_space(self, res), - "scratch.available: {} < GGLWECiphertext::encrypt_sk_scratch_space(self, res.rank()={}, res.size()={}): {}", + scratch.available() >= GGLWE::encrypt_sk_tmp_bytes(self, res), + "scratch.available: {} < GGLWECiphertext::encrypt_sk_tmp_bytes(self, res.rank()={}, res.size()={}): {}", scratch.available(), res.rank_out(), res.size(), - GGLWE::encrypt_sk_scratch_space(self, res) + GGLWE::encrypt_sk_tmp_bytes(self, res) ); assert!( res.dnum().0 * res.dsize().0 * res.base2k().0 <= res.k().0, diff --git a/poulpy-core/src/encryption/gglwe_ksk.rs b/poulpy-core/src/encryption/gglwe_ksk.rs index 1210aa7..8c3a70f 100644 --- a/poulpy-core/src/encryption/gglwe_ksk.rs +++ b/poulpy-core/src/encryption/gglwe_ksk.rs @@ -13,21 +13,21 @@ use crate::layouts::{ }; impl GLWESwitchingKey> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize + pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGLWEInfos, Module: SvpPPolBytesOf + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes, { - (GGLWE::encrypt_sk_scratch_space(module, infos) | ScalarZnx::bytes_of(module.n(), 1)) + (GGLWE::encrypt_sk_tmp_bytes(module, infos) | ScalarZnx::bytes_of(module.n(), 1)) + ScalarZnx::bytes_of(module.n(), infos.rank_in().into()) + GLWESecretPrepared::bytes_of_from_infos(module, &infos.glwe_layout()) } - pub fn encrypt_pk_scratch_space(module: &Module, _infos: &A) -> usize + pub fn encrypt_pk_tmp_bytes(module: &Module, _infos: &A) -> usize where A: GGLWEInfos, { - GGLWE::encrypt_pk_scratch_space(module, _infos) + GGLWE::encrypt_pk_tmp_bytes(module, _infos) } } @@ -66,10 +66,10 @@ impl GLWESwitchingKey { assert!(sk_in.n().0 <= module.n() as u32); assert!(sk_out.n().0 <= module.n() as u32); assert!( - scratch.available() >= GLWESwitchingKey::encrypt_sk_scratch_space(module, self), - "scratch.available()={} < GLWESwitchingKey::encrypt_sk_scratch_space={}", + scratch.available() >= GLWESwitchingKey::encrypt_sk_tmp_bytes(module, self), + "scratch.available()={} < GLWESwitchingKey::encrypt_sk_tmp_bytes={}", scratch.available(), - GLWESwitchingKey::encrypt_sk_scratch_space(module, self) + GLWESwitchingKey::encrypt_sk_tmp_bytes(module, self) ) } diff --git a/poulpy-core/src/encryption/gglwe_tsk.rs b/poulpy-core/src/encryption/gglwe_tsk.rs index 125cdf5..672808b 100644 --- a/poulpy-core/src/encryption/gglwe_tsk.rs +++ b/poulpy-core/src/encryption/gglwe_tsk.rs @@ -14,7 +14,7 @@ use crate::layouts::{ }; impl TensorKey> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize + pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGLWEInfos, Module: SvpPPolBytesOf + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes + VecZnxBigBytesOf, @@ -24,7 +24,7 @@ impl TensorKey> { + module.bytes_of_vec_znx_big(1, 1) + module.bytes_of_vec_znx_dft(1, 1) + GLWESecret::bytes_of(Degree(module.n() as u32), Rank(1)) - + GLWESwitchingKey::encrypt_sk_scratch_space(module, infos) + + GLWESwitchingKey::encrypt_sk_tmp_bytes(module, infos) } } diff --git a/poulpy-core/src/encryption/ggsw_ct.rs b/poulpy-core/src/encryption/ggsw_ct.rs index 23443a0..b044ae3 100644 --- a/poulpy-core/src/encryption/ggsw_ct.rs +++ b/poulpy-core/src/encryption/ggsw_ct.rs @@ -14,13 +14,13 @@ use crate::{ }; impl GGSW> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize + pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGSWInfos, Module: VecZnxNormalizeTmpBytes + VecZnxDftBytesOf, { let size = infos.size(); - GLWE::encrypt_sk_scratch_space(module, &infos.glwe_layout()) + GLWE::encrypt_sk_tmp_bytes(module, &infos.glwe_layout()) + VecZnx::bytes_of(module.n(), (infos.rank() + 1).into(), size) + VecZnx::bytes_of(module.n(), 1, size) + module.bytes_of_vec_znx_dft((infos.rank() + 1).into(), size) diff --git a/poulpy-core/src/encryption/glwe_ct.rs b/poulpy-core/src/encryption/glwe_ct.rs index c5ed129..16bbadf 100644 --- a/poulpy-core/src/encryption/glwe_ct.rs +++ b/poulpy-core/src/encryption/glwe_ct.rs @@ -19,7 +19,7 @@ use crate::{ }; impl GLWE> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize + pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GLWEInfos, Module: VecZnxNormalizeTmpBytes + VecZnxDftBytesOf, @@ -28,7 +28,7 @@ impl GLWE> { assert_eq!(module.n() as u32, infos.n()); module.vec_znx_normalize_tmp_bytes() + 2 * VecZnx::bytes_of(module.n(), 1, size) + module.bytes_of_vec_znx_dft(1, size) } - pub fn encrypt_pk_scratch_space(module: &Module, infos: &A) -> usize + pub fn encrypt_pk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GLWEInfos, Module: VecZnxDftBytesOf + SvpPPolBytesOf + VecZnxBigBytesOf + VecZnxNormalizeTmpBytes, @@ -147,10 +147,10 @@ where assert_eq!(sk.n(), self.n() as u32); assert_eq!(pt.n(), self.n() as u32); assert!( - scratch.available() >= GLWE::encrypt_sk_scratch_space(self, &res), - "scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}", + scratch.available() >= GLWE::encrypt_sk_tmp_bytes(self, &res), + "scratch.available(): {} < GLWECiphertext::encrypt_sk_tmp_bytes: {}", scratch.available(), - GLWE::encrypt_sk_scratch_space(self, &res) + GLWE::encrypt_sk_tmp_bytes(self, &res) ) } @@ -209,10 +209,10 @@ where assert_eq!(res.n(), self.n() as u32); assert_eq!(sk.n(), self.n() as u32); assert!( - scratch.available() >= GLWE::encrypt_sk_scratch_space(self, &res), - "scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}", + scratch.available() >= GLWE::encrypt_sk_tmp_bytes(self, &res), + "scratch.available(): {} < GLWECiphertext::encrypt_sk_tmp_bytes: {}", scratch.available(), - GLWE::encrypt_sk_scratch_space(self, &res) + GLWE::encrypt_sk_tmp_bytes(self, &res) ) } diff --git a/poulpy-core/src/encryption/glwe_pk.rs b/poulpy-core/src/encryption/glwe_pk.rs index 6312b13..d89f515 100644 --- a/poulpy-core/src/encryption/glwe_pk.rs +++ b/poulpy-core/src/encryption/glwe_pk.rs @@ -45,7 +45,7 @@ where } // Its ok to allocate scratch space here since pk is usually generated only once. - let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWE::encrypt_sk_scratch_space(self, res)); + let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWE::encrypt_sk_tmp_bytes(self, res)); let mut tmp: GLWE> = GLWE::alloc_from_infos(res); diff --git a/poulpy-core/src/encryption/glwe_to_lwe_ksk.rs b/poulpy-core/src/encryption/glwe_to_lwe_ksk.rs index 6658480..971766e 100644 --- a/poulpy-core/src/encryption/glwe_to_lwe_ksk.rs +++ b/poulpy-core/src/encryption/glwe_to_lwe_ksk.rs @@ -14,13 +14,13 @@ use crate::layouts::{ }; impl GLWEToLWESwitchingKey> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize + pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGLWEInfos, Module: SvpPPolBytesOf + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes, { GLWESecretPrepared::bytes_of(module, infos.rank_in()) - + (GLWESwitchingKey::encrypt_sk_scratch_space(module, infos) | GLWESecret::bytes_of(infos.n(), infos.rank_in())) + + (GLWESwitchingKey::encrypt_sk_tmp_bytes(module, infos) | GLWESecret::bytes_of(infos.n(), infos.rank_in())) } } diff --git a/poulpy-core/src/encryption/lwe_ksk.rs b/poulpy-core/src/encryption/lwe_ksk.rs index 04b52aa..1f1820a 100644 --- a/poulpy-core/src/encryption/lwe_ksk.rs +++ b/poulpy-core/src/encryption/lwe_ksk.rs @@ -14,7 +14,7 @@ use crate::layouts::{ }; impl LWESwitchingKey> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize + pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGLWEInfos, Module: SvpPPolBytesOf + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes, @@ -36,7 +36,7 @@ impl LWESwitchingKey> { ); GLWESecret::bytes_of(Degree(module.n() as u32), Rank(1)) + GLWESecretPrepared::bytes_of(module, Rank(1)) - + GLWESwitchingKey::encrypt_sk_scratch_space(module, infos) + + GLWESwitchingKey::encrypt_sk_tmp_bytes(module, infos) } } diff --git a/poulpy-core/src/encryption/lwe_to_glwe_ksk.rs b/poulpy-core/src/encryption/lwe_to_glwe_ksk.rs index b50a2db..56a6701 100644 --- a/poulpy-core/src/encryption/lwe_to_glwe_ksk.rs +++ b/poulpy-core/src/encryption/lwe_to_glwe_ksk.rs @@ -12,7 +12,7 @@ use poulpy_hal::{ use crate::layouts::{Degree, GGLWEInfos, GLWESecret, GLWESwitchingKey, LWEInfos, LWESecret, LWEToGLWESwitchingKey, Rank}; impl LWEToGLWESwitchingKey> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize + pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGLWEInfos, Module: SvpPPolBytesOf + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes, @@ -22,8 +22,7 @@ impl LWEToGLWESwitchingKey> { Rank(1), "rank_in != 1 is not supported for LWEToGLWESwitchingKey" ); - GLWESwitchingKey::encrypt_sk_scratch_space(module, infos) - + GLWESecret::bytes_of(Degree(module.n() as u32), infos.rank_in()) + GLWESwitchingKey::encrypt_sk_tmp_bytes(module, infos) + GLWESecret::bytes_of(Degree(module.n() as u32), infos.rank_in()) } } diff --git a/poulpy-core/src/external_product/gglwe_ksk.rs b/poulpy-core/src/external_product/gglwe_ksk.rs index b7bd4af..5bb4557 100644 --- a/poulpy-core/src/external_product/gglwe_ksk.rs +++ b/poulpy-core/src/external_product/gglwe_ksk.rs @@ -18,7 +18,7 @@ where A: GGLWEInfos, B: GGSWInfos, { - self.glwe_external_product_scratch_space(res_infos, a_infos, b_infos) + self.glwe_external_product_tmp_bytes(res_infos, a_infos, b_infos) } fn gglwe_external_product(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) diff --git a/poulpy-core/src/external_product/ggsw_ct.rs b/poulpy-core/src/external_product/ggsw_ct.rs index 0b38ebf..d3a59a6 100644 --- a/poulpy-core/src/external_product/ggsw_ct.rs +++ b/poulpy-core/src/external_product/ggsw_ct.rs @@ -21,7 +21,7 @@ where A: GGSWInfos, B: GGSWInfos, { - self.glwe_external_product_scratch_space(res_infos, a_infos, b_infos) + self.glwe_external_product_tmp_bytes(res_infos, a_infos, b_infos) } fn ggsw_external_product(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) diff --git a/poulpy-core/src/external_product/glwe_ct.rs b/poulpy-core/src/external_product/glwe_ct.rs index 4735ff9..ab9968c 100644 --- a/poulpy-core/src/external_product/glwe_ct.rs +++ b/poulpy-core/src/external_product/glwe_ct.rs @@ -15,14 +15,14 @@ use crate::{ }; impl GLWE> { - pub fn external_product_scratch_space(module: &M, res_infos: &R, a_infos: &A, b_infos: &B) -> usize + pub fn external_product_tmp_bytes(module: &M, res_infos: &R, a_infos: &A, b_infos: &B) -> usize where R: GLWEInfos, A: GLWEInfos, B: GGSWInfos, M: GLWEExternalProduct, { - module.glwe_external_product_scratch_space(res_infos, a_infos, b_infos) + module.glwe_external_product_tmp_bytes(res_infos, a_infos, b_infos) } } @@ -61,7 +61,7 @@ where + VecZnxBigNormalize + VecZnxNormalize, { - fn glwe_external_product_scratch_space(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize + fn glwe_external_product_tmp_bytes(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize where R: GLWEInfos, A: GLWEInfos, @@ -111,7 +111,7 @@ where assert_eq!(rhs.rank(), res.rank()); assert_eq!(rhs.n(), res.n()); - assert!(scratch.available() >= self.glwe_external_product_scratch_space(res, res, rhs)); + assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, res, rhs)); } let cols: usize = (rhs.rank() + 1).into(); @@ -225,7 +225,7 @@ where assert_eq!(rhs.rank(), res.rank()); assert_eq!(rhs.n(), res.n()); assert_eq!(lhs.n(), res.n()); - assert!(scratch.available() >= self.glwe_external_product_scratch_space(res, lhs, rhs)); + assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, lhs, rhs)); } let cols: usize = (rhs.rank() + 1).into(); diff --git a/poulpy-core/src/glwe_packing.rs b/poulpy-core/src/glwe_packing.rs index 05a2965..7dacb97 100644 --- a/poulpy-core/src/glwe_packing.rs +++ b/poulpy-core/src/glwe_packing.rs @@ -90,13 +90,13 @@ impl GLWEPacker { } /// Number of scratch space bytes required to call [Self::add]. - pub fn scratch_space(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize + pub fn tmp_bytes(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize where OUT: GLWEInfos, KEY: GGLWEInfos, Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - pack_core_scratch_space(module, out_infos, key_infos) + pack_core_tmp_bytes(module, out_infos, key_infos) } pub fn galois_elements(module: &Module) -> Vec { @@ -111,7 +111,7 @@ impl GLWEPacker { /// of packed ciphertexts reaches N/2^log_batch is a result written. /// * `a`: ciphertext to pack. Can optionally give None to pack a 0 ciphertext. /// * `auto_keys`: a [HashMap] containing the [AutomorphismKeyExec]s. - /// * `scratch`: scratch space of size at least [Self::scratch_space]. + /// * `scratch`: scratch space of size at least [Self::tmp_bytes]. pub fn add( &mut self, module: &Module, @@ -177,13 +177,13 @@ impl GLWEPacker { } } -fn pack_core_scratch_space(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize +fn pack_core_tmp_bytes(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize where OUT: GLWEInfos, KEY: GGLWEInfos, Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - combine_scratch_space(module, out_infos, key_infos) + combine_tmp_bytes(module, out_infos, key_infos) } fn pack_core( @@ -268,14 +268,14 @@ fn pack_core( } } -fn combine_scratch_space(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize +fn combine_tmp_bytes(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize where OUT: GLWEInfos, KEY: GGLWEInfos, Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { GLWE::bytes_of_from_infos(module, out_infos) - + (GLWE::rsh_scratch_space(module.n()) | GLWE::automorphism_inplace_scratch_space(module, out_infos, key_infos)) + + (GLWE::rsh_tmp_bytes(module.n()) | GLWE::automorphism_inplace_tmp_bytes(module, out_infos, key_infos)) } /// [combine] merges two ciphertexts together. diff --git a/poulpy-core/src/glwe_trace.rs b/poulpy-core/src/glwe_trace.rs index 1d7a6e6..36aabb9 100644 --- a/poulpy-core/src/glwe_trace.rs +++ b/poulpy-core/src/glwe_trace.rs @@ -27,19 +27,14 @@ impl GLWE> { gal_els } - pub fn trace_scratch_space( - module: &Module, - out_infos: &OUT, - in_infos: &IN, - key_infos: &KEY, - ) -> usize + pub fn trace_tmp_bytes(module: &Module, out_infos: &OUT, in_infos: &IN, key_infos: &KEY) -> usize where OUT: GLWEInfos, IN: GLWEInfos, KEY: GGLWEInfos, Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - let trace: usize = Self::automorphism_inplace_scratch_space(module, out_infos, key_infos); + let trace: usize = Self::automorphism_inplace_tmp_bytes(module, out_infos, key_infos); if in_infos.base2k() != key_infos.base2k() { let glwe_conv: usize = VecZnx::bytes_of( module.n(), @@ -52,13 +47,13 @@ impl GLWE> { trace } - pub fn trace_inplace_scratch_space(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize + pub fn trace_inplace_tmp_bytes(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize where OUT: GLWEInfos, KEY: GGLWEInfos, Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - Self::trace_scratch_space(module, out_infos, out_infos, key_infos) + Self::trace_tmp_bytes(module, out_infos, out_infos, key_infos) } } diff --git a/poulpy-core/src/keyswitching/gglwe_ct.rs b/poulpy-core/src/keyswitching/gglwe_ct.rs index b1fd8f4..a65266f 100644 --- a/poulpy-core/src/keyswitching/gglwe_ct.rs +++ b/poulpy-core/src/keyswitching/gglwe_ct.rs @@ -13,7 +13,7 @@ use crate::layouts::{ }; impl AutomorphismKey> { - pub fn keyswitch_scratch_space( + pub fn keyswitch_tmp_bytes( module: &Module, out_infos: &OUT, in_infos: &IN, @@ -25,16 +25,16 @@ impl AutomorphismKey> { KEY: GGLWEInfos, Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - GLWESwitchingKey::keyswitch_scratch_space(module, out_infos, in_infos, key_infos) + GLWESwitchingKey::keyswitch_tmp_bytes(module, out_infos, in_infos, key_infos) } - pub fn keyswitch_inplace_scratch_space(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize + pub fn keyswitch_inplace_tmp_bytes(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize where OUT: GGLWEInfos, KEY: GGLWEInfos, Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - GLWESwitchingKey::keyswitch_inplace_scratch_space(module, out_infos, key_infos) + GLWESwitchingKey::keyswitch_inplace_tmp_bytes(module, out_infos, key_infos) } } @@ -86,7 +86,7 @@ impl AutomorphismKey { } impl GLWESwitchingKey> { - pub fn keyswitch_scratch_space( + pub fn keyswitch_tmp_bytes( module: &Module, out_infos: &OUT, in_infos: &IN, @@ -98,16 +98,16 @@ impl GLWESwitchingKey> { KEY: GGLWEInfos, Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - GLWE::keyswitch_scratch_space(module, out_infos, in_infos, key_apply) + GLWE::keyswitch_tmp_bytes(module, out_infos, in_infos, key_apply) } - pub fn keyswitch_inplace_scratch_space(module: &Module, out_infos: &OUT, key_apply: &KEY) -> usize + pub fn keyswitch_inplace_tmp_bytes(module: &Module, out_infos: &OUT, key_apply: &KEY) -> usize where OUT: GGLWEInfos + GLWEInfos, KEY: GGLWEInfos + GLWEInfos, Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - GLWE::keyswitch_inplace_scratch_space(module, out_infos, key_apply) + GLWE::keyswitch_inplace_tmp_bytes(module, out_infos, key_apply) } } diff --git a/poulpy-core/src/keyswitching/ggsw_ct.rs b/poulpy-core/src/keyswitching/ggsw_ct.rs index 4f89a41..078739b 100644 --- a/poulpy-core/src/keyswitching/ggsw_ct.rs +++ b/poulpy-core/src/keyswitching/ggsw_ct.rs @@ -1,359 +1,131 @@ -use poulpy_hal::{ - api::{ - ScratchAvailable, VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, - VecZnxDftAddInplace, VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, - VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, - }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx, VmpPMat, ZnxInfos}, -}; +use poulpy_hal::layouts::{Backend, DataMut, Scratch, VecZnx}; use crate::{ + GGSWExpandRows, ScratchTakeCore, + keyswitching::glwe_ct::GLWEKeySwitching, layouts::{ - GGLWE, GGLWEInfos, GGSW, GGSWInfos, GLWE, GLWEInfos, LWEInfos, - prepared::{GLWESwitchingKeyPrepared, TensorKeyPrepared}, + GGLWEInfos, GGSW, GGSWInfos, GGSWToMut, GGSWToRef, + prepared::{GLWESwitchingKeyPreparedToRef, TensorKeyPreparedToRef}, }, - operations::GLWEOperations, }; impl GGSW> { - pub(crate) fn expand_row_scratch_space(module: &Module, out_infos: &OUT, tsk_infos: &TSK) -> usize - where - OUT: GGSWInfos, - TSK: GGLWEInfos, - Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigBytesOf + VecZnxNormalizeTmpBytes, - { - let tsk_size: usize = tsk_infos.k().div_ceil(tsk_infos.base2k()) as usize; - let size_in: usize = out_infos - .k() - .div_ceil(tsk_infos.base2k()) - .div_ceil(tsk_infos.dsize().into()) as usize; - - let tmp_dft_i: usize = module.bytes_of_vec_znx_dft((tsk_infos.rank_out() + 1).into(), tsk_size); - let tmp_a: usize = module.bytes_of_vec_znx_dft(1, size_in); - let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes( - tsk_size, - size_in, - size_in, - (tsk_infos.rank_in()).into(), // Verify if rank+1 - (tsk_infos.rank_out()).into(), // Verify if rank+1 - tsk_size, - ); - let tmp_idft: usize = module.bytes_of_vec_znx_big(1, tsk_size); - let norm: usize = module.vec_znx_normalize_tmp_bytes(); - - tmp_dft_i + ((tmp_a + vmp) | (tmp_idft + norm)) - } - - #[allow(clippy::too_many_arguments)] - pub fn keyswitch_scratch_space( - module: &Module, - out_infos: &OUT, - in_infos: &IN, - apply_infos: &KEY, - tsk_infos: &TSK, + pub fn keyswitch_tmp_bytes( + module: &M, + res_infos: &R, + a_infos: &A, + key_infos: &K, + tsk_infos: &T, ) -> usize where - OUT: GGSWInfos, - IN: GGSWInfos, - KEY: GGLWEInfos, - TSK: GGLWEInfos, - Module: - VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigBytesOf + VecZnxNormalizeTmpBytes + VecZnxBigNormalizeTmpBytes, + R: GGSWInfos, + A: GGSWInfos, + K: GGLWEInfos, + T: GGLWEInfos, + M: GGSWKeySwitch, { - #[cfg(debug_assertions)] - { - assert_eq!(apply_infos.rank_in(), apply_infos.rank_out()); - assert_eq!(tsk_infos.rank_in(), tsk_infos.rank_out()); - assert_eq!(apply_infos.rank_in(), tsk_infos.rank_in()); - } + module.ggsw_keyswitch_tmp_bytes(res_infos, a_infos, key_infos, tsk_infos) + } +} - let rank: usize = apply_infos.rank_out().into(); +impl GGSW { + pub fn keyswitch(&mut self, module: &M, a: &A, key: &K, tsk: &T, scratch: &mut Scratch) + where + A: GGSWToRef, + K: GLWESwitchingKeyPreparedToRef, + T: TensorKeyPreparedToRef, + Scratch: ScratchTakeCore, + M: GGSWKeySwitch, + { + module.ggsw_keyswitch(self, a, key, tsk, scratch); + } - let size_out: usize = out_infos.k().div_ceil(out_infos.base2k()) as usize; - let res_znx: usize = VecZnx::bytes_of(module.n(), rank + 1, size_out); - let ci_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, size_out); - let ks: usize = GLWE::keyswitch_scratch_space(module, out_infos, in_infos, apply_infos); - let expand_rows: usize = GGSW::expand_row_scratch_space(module, out_infos, tsk_infos); - let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, size_out); + pub fn keyswitch_inplace(&mut self, module: &M, key: &K, tsk: &T, scratch: &mut Scratch) + where + K: GLWESwitchingKeyPreparedToRef, + T: TensorKeyPreparedToRef, + Scratch: ScratchTakeCore, + M: GGSWKeySwitch, + { + module.ggsw_keyswitch_inplace(self, key, tsk, scratch); + } +} - if in_infos.base2k() == tsk_infos.base2k() { +pub trait GGSWKeySwitch +where + Self: GLWEKeySwitching + GGSWExpandRows, +{ + fn ggsw_keyswitch_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K, tsk_infos: &T) -> usize + where + R: GGSWInfos, + A: GGSWInfos, + K: GGLWEInfos, + T: GGLWEInfos, + { + assert_eq!(key_infos.rank_in(), key_infos.rank_out()); + assert_eq!(tsk_infos.rank_in(), tsk_infos.rank_out()); + assert_eq!(key_infos.rank_in(), tsk_infos.rank_in()); + + let rank: usize = key_infos.rank_out().into(); + + let size_out: usize = res_infos.k().div_ceil(res_infos.base2k()) as usize; + let res_znx: usize = VecZnx::bytes_of(self.n(), rank + 1, size_out); + let ci_dft: usize = self.bytes_of_vec_znx_dft(rank + 1, size_out); + let ks: usize = self.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos); + let expand_rows: usize = self.ggsw_expand_rows_tmp_bytes(res_infos, tsk_infos); + let res_dft: usize = self.bytes_of_vec_znx_dft(rank + 1, size_out); + + if a_infos.base2k() == tsk_infos.base2k() { res_znx + ci_dft + (ks | expand_rows | res_dft) } else { let a_conv: usize = VecZnx::bytes_of( - module.n(), + self.n(), 1, - out_infos.k().div_ceil(tsk_infos.base2k()) as usize, - ) + module.vec_znx_normalize_tmp_bytes(); + res_infos.k().div_ceil(tsk_infos.base2k()) as usize, + ) + self.vec_znx_normalize_tmp_bytes(); res_znx + ci_dft + (a_conv | ks | expand_rows | res_dft) } } - #[allow(clippy::too_many_arguments)] - pub fn keyswitch_inplace_scratch_space( - module: &Module, - out_infos: &OUT, - apply_infos: &KEY, - tsk_infos: &TSK, - ) -> usize + fn ggsw_keyswitch(&self, res: &mut R, a: &A, key: &K, tsk: &T, scratch: &mut Scratch) where - OUT: GGSWInfos, - KEY: GGLWEInfos, - TSK: GGLWEInfos, - Module: - VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigBytesOf + VecZnxNormalizeTmpBytes + VecZnxBigNormalizeTmpBytes, + R: GGSWToMut, + A: GGSWToRef, + K: GLWESwitchingKeyPreparedToRef, + T: TensorKeyPreparedToRef, + Scratch: ScratchTakeCore, { - GGSW::keyswitch_scratch_space(module, out_infos, out_infos, apply_infos, tsk_infos) + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + let a: &GGSW<&[u8]> = &a.to_ref(); + + assert_eq!(res.ggsw_layout(), a.ggsw_layout()); + + for row in 0..a.dnum().into() { + // Key-switch column 0, i.e. + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) + self.glwe_keyswitch(&mut res.at_mut(row, 0), &a.at(row, 0), key, scratch); + } + + self.ggsw_expand_row(res, tsk, scratch); + } + + fn ggsw_keyswitch_inplace(&self, res: &mut R, key: &K, tsk: &T, scratch: &mut Scratch) + where + R: GGSWToMut, + K: GLWESwitchingKeyPreparedToRef, + T: TensorKeyPreparedToRef, + Scratch: ScratchTakeCore, + { + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + + for row in 0..res.dnum().into() { + // Key-switch column 0, i.e. + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) + self.glwe_keyswitch_inplace(&mut res.at_mut(row, 0), key, scratch); + } + + self.ggsw_expand_row(res, tsk, scratch); } } -impl GGSW { - pub fn from_gglwe( - &mut self, - module: &Module, - a: &GGLWE, - tsk: &TensorKeyPrepared, - scratch: &mut Scratch, - ) where - DataA: DataRef, - DataTsk: DataRef, - Module: VecZnxCopy - + VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigBytesOf - + VecZnxNormalizeTmpBytes - + VecZnxDftApply - + VecZnxDftCopy - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftAddInplace - + VecZnxBigNormalize - + VecZnxIdftApplyTmpA - + VecZnxNormalize, - Scratch: ScratchAvailable, - { - #[cfg(debug_assertions)] - { - use crate::layouts::{GLWEInfos, LWEInfos}; - - assert_eq!(self.rank(), a.rank_out()); - assert_eq!(self.dnum(), a.dnum()); - assert_eq!(self.n(), module.n() as u32); - assert_eq!(a.n(), module.n() as u32); - assert_eq!(tsk.n(), module.n() as u32); - } - (0..self.dnum().into()).for_each(|row_i| { - self.at_mut(row_i, 0).copy(module, &a.at(row_i, 0)); - }); - self.expand_row(module, tsk, scratch); - } - - pub fn keyswitch( - &mut self, - module: &Module, - lhs: &GGSW, - ksk: &GLWESwitchingKeyPrepared, - tsk: &TensorKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxDftBytesOf - + VecZnxBigBytesOf - + VecZnxNormalizeTmpBytes - + VecZnxDftCopy - + VecZnxDftAddInplace - + VecZnxIdftApplyTmpA - + VecZnxNormalize, - Scratch: ScratchAvailable, - { - (0..lhs.dnum().into()).for_each(|row_i| { - // Key-switch column 0, i.e. - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) - self.at_mut(row_i, 0) - .keyswitch(module, &lhs.at(row_i, 0), ksk, scratch); - }); - self.expand_row(module, tsk, scratch); - } - - pub fn keyswitch_inplace( - &mut self, - module: &Module, - ksk: &GLWESwitchingKeyPrepared, - tsk: &TensorKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxDftBytesOf - + VecZnxBigBytesOf - + VecZnxNormalizeTmpBytes - + VecZnxDftCopy - + VecZnxDftAddInplace - + VecZnxIdftApplyTmpA - + VecZnxNormalize, - Scratch: ScratchAvailable, - { - (0..self.dnum().into()).for_each(|row_i| { - // Key-switch column 0, i.e. - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) - self.at_mut(row_i, 0) - .keyswitch_inplace(module, ksk, scratch); - }); - self.expand_row(module, tsk, scratch); - } - - pub fn expand_row( - &mut self, - module: &Module, - tsk: &TensorKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigBytesOf - + VecZnxNormalizeTmpBytes - + VecZnxDftApply - + VecZnxDftCopy - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftAddInplace - + VecZnxBigNormalize - + VecZnxIdftApplyTmpA - + VecZnxNormalize, - Scratch: ScratchAvailable, - { - let basek_in: usize = self.base2k().into(); - let basek_tsk: usize = tsk.base2k().into(); - - assert!(scratch.available() >= GGSW::expand_row_scratch_space(module, self, tsk)); - - let n: usize = self.n().into(); - let rank: usize = self.rank().into(); - let cols: usize = rank + 1; - - let a_size: usize = (self.size() * basek_in).div_ceil(basek_tsk); - - // Keyswitch the j-th row of the col 0 - for row_i in 0..self.dnum().into() { - let a = &self.at(row_i, 0).data; - - // Pre-compute DFT of (a0, a1, a2) - let (mut ci_dft, scratch_1) = scratch.take_vec_znx_dft(n, cols, a_size); - - if basek_in == basek_tsk { - for i in 0..cols { - module.vec_znx_dft_apply(1, 0, &mut ci_dft, i, a, i); - } - } else { - let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(n, 1, a_size); - for i in 0..cols { - module.vec_znx_normalize(basek_tsk, &mut a_conv, 0, basek_in, a, i, scratch_2); - module.vec_znx_dft_apply(1, 0, &mut ci_dft, i, &a_conv, 0); - } - } - - for col_j in 1..cols { - // Example for rank 3: - // - // Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is - // actually composed of that many dnum and we focus on a specific row here - // implicitely given ci_dft. - // - // # Input - // - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) - // col 1: (0, 0, 0, 0) - // col 2: (0, 0, 0, 0) - // col 3: (0, 0, 0, 0) - // - // # Output - // - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) - // col 1: (-(b0s0 + b1s1 + b2s2) , b0 + M[i], b1 , b2 ) - // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M[i], c2 ) - // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M[i]) - - let dsize: usize = tsk.dsize().into(); - - let (mut tmp_dft_i, scratch_2) = scratch_1.take_vec_znx_dft(n, cols, tsk.size()); - let (mut tmp_a, scratch_3) = scratch_2.take_vec_znx_dft(n, 1, ci_dft.size().div_ceil(dsize)); - - { - // Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2 - // - // # Example for col=1 - // - // a0 * (-(f0s0 + f1s1 + f1s2) + s0^2, f0, f1, f2) = (-(a0f0s0 + a0f1s1 + a0f1s2) + a0s0^2, a0f0, a0f1, a0f2) - // + - // a1 * (-(g0s0 + g1s1 + g1s2) + s0s1, g0, g1, g2) = (-(a1g0s0 + a1g1s1 + a1g1s2) + a1s0s1, a1g0, a1g1, a1g2) - // + - // a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2) - // = - // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2) - for col_i in 1..cols { - let pmat: &VmpPMat = &tsk.at(col_i - 1, col_j - 1).key.data; // Selects Enc(s[i]s[j]) - - // Extracts a[i] and multipies with Enc(s[i]s[j]) - for di in 0..dsize { - tmp_a.set_size((ci_dft.size() + di) / dsize); - - // Small optimization for dsize > 2 - // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then - // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}. - // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last dsize-1 limbs, but this introduce - // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same - // noise is kept with respect to the ideal functionality. - tmp_dft_i.set_size(tsk.size() - ((dsize - di) as isize - 2).max(0) as usize); - - module.vec_znx_dft_copy(dsize, dsize - 1 - di, &mut tmp_a, 0, &ci_dft, col_i); - if di == 0 && col_i == 1 { - module.vmp_apply_dft_to_dft(&mut tmp_dft_i, &tmp_a, pmat, scratch_3); - } else { - module.vmp_apply_dft_to_dft_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch_3); - } - } - } - } - - // Adds -(sum a[i] * s[i]) + m) on the i-th column of tmp_idft_i - // - // (-(x0s0 + x1s1 + x2s2) + a0s0s0 + a1s0s1 + a2s0s2, x0, x1, x2) - // + - // (0, -(a0s0 + a1s1 + a2s2) + M[i], 0, 0) - // = - // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0 -(a0s0 + a1s1 + a2s2) + M[i], x1, x2) - // = - // (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2) - module.vec_znx_dft_add_inplace(&mut tmp_dft_i, col_j, &ci_dft, 0); - let (mut tmp_idft, scratch_3) = scratch_2.take_vec_znx_big(n, 1, tsk.size()); - for i in 0..cols { - module.vec_znx_idft_apply_tmpa(&mut tmp_idft, 0, &mut tmp_dft_i, i); - module.vec_znx_big_normalize( - basek_in, - &mut self.at_mut(row_i, col_j).data, - i, - basek_tsk, - &tmp_idft, - 0, - scratch_3, - ); - } - } - } - } -} +impl GGSW {} diff --git a/poulpy-core/src/keyswitching/glwe_ct.rs b/poulpy-core/src/keyswitching/glwe_ct.rs index ebd0243..ac4c7b8 100644 --- a/poulpy-core/src/keyswitching/glwe_ct.rs +++ b/poulpy-core/src/keyswitching/glwe_ct.rs @@ -1,186 +1,179 @@ use poulpy_hal::{ api::{ - ScratchAvailable, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftApply, - VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, + ModuleN, ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, ZnxInfos}, }; -use crate::layouts::{GGLWEInfos, GLWE, GLWEInfos, LWEInfos, prepared::GLWESwitchingKeyPrepared}; +use crate::{ + ScratchTakeCore, + layouts::{ + GGLWEInfos, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, + prepared::{GLWESwitchingKeyPrepared, GLWESwitchingKeyPreparedToRef}, + }, +}; impl GLWE> { - pub fn keyswitch_scratch_space( - module: &Module, - out_infos: &OUT, - in_infos: &IN, - key_apply: &KEY, - ) -> usize + pub fn keyswitch_tmp_bytes(module: &M, res_infos: &R, a_infos: &A, b_infos: &B) -> usize where - OUT: GLWEInfos, - IN: GLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + R: GLWEInfos, + A: GLWEInfos, + B: GGLWEInfos, + M: GLWEKeySwitching, { - let in_size: usize = in_infos + module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, b_infos) + } +} + +impl GLWE { + pub fn keyswitch(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch) + where + A: GLWEToRef, + B: GLWESwitchingKeyPreparedToRef, + M: GLWEKeySwitching, + Scratch: ScratchTakeCore, + { + module.glwe_keyswitch(self, a, b, scratch); + } + + pub fn keyswitch_inplace(&mut self, module: &M, a: &A, scratch: &mut Scratch) + where + A: GLWESwitchingKeyPreparedToRef, + M: GLWEKeySwitching, + Scratch: ScratchTakeCore, + { + module.glwe_keyswitch_inplace(self, a, scratch); + } +} + +impl GLWEKeySwitching for Module where + Self: Sized + + ModuleN + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VecZnxBigNormalizeTmpBytes + + VecZnxNormalizeTmpBytes + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VecZnxBigNormalizeTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxDftApply + + VecZnxIdftApplyConsume + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize + + VecZnxNormalize + + VecZnxNormalizeTmpBytes +{ +} + +pub trait GLWEKeySwitching +where + Self: Sized + + ModuleN + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VecZnxBigNormalizeTmpBytes + + VecZnxNormalizeTmpBytes + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VecZnxBigNormalizeTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxDftApply + + VecZnxIdftApplyConsume + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize + + VecZnxNormalize + + VecZnxNormalizeTmpBytes, +{ + fn glwe_keyswitch_tmp_bytes(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize + where + R: GLWEInfos, + A: GLWEInfos, + B: GGLWEInfos, + { + let in_size: usize = a_infos .k() - .div_ceil(key_apply.base2k()) - .div_ceil(key_apply.dsize().into()) as usize; - let out_size: usize = out_infos.size(); - let ksk_size: usize = key_apply.size(); - let res_dft: usize = module.bytes_of_vec_znx_dft((key_apply.rank_out() + 1).into(), ksk_size); // TODO OPTIMIZE - let ai_dft: usize = module.bytes_of_vec_znx_dft((key_apply.rank_in()).into(), in_size); - let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes( + .div_ceil(b_infos.base2k()) + .div_ceil(b_infos.dsize().into()) as usize; + let out_size: usize = res_infos.size(); + let ksk_size: usize = b_infos.size(); + let res_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank_out() + 1).into(), ksk_size); // TODO OPTIMIZE + let ai_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank_in()).into(), in_size); + let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes( out_size, in_size, in_size, - (key_apply.rank_in()).into(), - (key_apply.rank_out() + 1).into(), + (b_infos.rank_in()).into(), + (b_infos.rank_out() + 1).into(), ksk_size, - ) + module.bytes_of_vec_znx_dft((key_apply.rank_in()).into(), in_size); - let normalize_big: usize = module.vec_znx_big_normalize_tmp_bytes(); - if in_infos.base2k() == key_apply.base2k() { + ) + self.bytes_of_vec_znx_dft((b_infos.rank_in()).into(), in_size); + let normalize_big: usize = self.vec_znx_big_normalize_tmp_bytes(); + if a_infos.base2k() == b_infos.base2k() { res_dft + ((ai_dft + vmp) | normalize_big) - } else if key_apply.dsize() == 1 { + } else if b_infos.dsize() == 1 { // In this case, we only need one column, temporary, that we can drop once a_dft is computed. - let normalize_conv: usize = VecZnx::bytes_of(module.n(), 1, in_size) + module.vec_znx_normalize_tmp_bytes(); + let normalize_conv: usize = VecZnx::bytes_of(self.n(), 1, in_size) + self.vec_znx_normalize_tmp_bytes(); res_dft + (((ai_dft + normalize_conv) | vmp) | normalize_big) } else { // Since we stride over a to get a_dft when dsize > 1, we need to store the full columns of a with in the base conversion. - let normalize_conv: usize = VecZnx::bytes_of(module.n(), (key_apply.rank_in()).into(), in_size); - res_dft + ((ai_dft + normalize_conv + (module.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big) + let normalize_conv: usize = VecZnx::bytes_of(self.n(), (b_infos.rank_in()).into(), in_size); + res_dft + ((ai_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big) } } - pub fn keyswitch_inplace_scratch_space(module: &Module, out_infos: &OUT, key_apply: &KEY) -> usize + fn glwe_keyswitch(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) where - OUT: GLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + R: GLWEToMut, + A: GLWEToRef, + B: GLWESwitchingKeyPreparedToRef, + Scratch: ScratchTakeCore, { - Self::keyswitch_scratch_space(module, out_infos, out_infos, key_apply) - } -} + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); + let b: &GLWESwitchingKeyPrepared<&[u8], BE> = &b.to_ref(); -impl GLWE { - #[allow(dead_code)] - pub(crate) fn assert_keyswitch( - &self, - module: &Module, - lhs: &GLWE, - rhs: &GLWESwitchingKeyPrepared, - scratch: &Scratch, - ) where - DataLhs: DataRef, - DataRhs: DataRef, - Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable, - { assert_eq!( - lhs.rank(), - rhs.rank_in(), - "lhs.rank(): {} != rhs.rank_in(): {}", - lhs.rank(), - rhs.rank_in() + a.rank(), + b.rank_in(), + "a.rank(): {} != b.rank_in(): {}", + a.rank(), + b.rank_in() ); assert_eq!( - self.rank(), - rhs.rank_out(), - "self.rank(): {} != rhs.rank_out(): {}", - self.rank(), - rhs.rank_out() + res.rank(), + b.rank_out(), + "res.rank(): {} != b.rank_out(): {}", + res.rank(), + b.rank_out() ); - assert_eq!(rhs.n(), self.n()); - assert_eq!(lhs.n(), self.n()); - let scrach_needed: usize = GLWE::keyswitch_scratch_space(module, self, lhs, rhs); + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.n(), self.n() as u32); + assert_eq!(b.n(), self.n() as u32); + + let scrach_needed: usize = self.glwe_keyswitch_tmp_bytes(res, a, b); assert!( scratch.available() >= scrach_needed, - "scratch.available()={} < GLWECiphertext::keyswitch_scratch_space( - module, - self.base2k(), - self.k(), - lhs.base2k(), - lhs.k(), - rhs.base2k(), - rhs.k(), - rhs.dsize(), - rhs.rank_in(), - rhs.rank_out(), - )={scrach_needed}", + "scratch.available()={} < glwe_keyswitch_tmp_bytes={scrach_needed}", scratch.available(), ); - } - #[allow(dead_code)] - pub(crate) fn assert_keyswitch_inplace( - &self, - module: &Module, - rhs: &GLWESwitchingKeyPrepared, - scratch: &Scratch, - ) where - DataRhs: DataRef, - Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable, - { - assert_eq!( - self.rank(), - rhs.rank_out(), - "self.rank(): {} != rhs.rank_out(): {}", - self.rank(), - rhs.rank_out() - ); + let basek_out: usize = res.base2k().into(); + let base2k_out: usize = b.base2k().into(); - assert_eq!(rhs.n(), self.n()); - - let scrach_needed: usize = GLWE::keyswitch_inplace_scratch_space(module, self, rhs); - - assert!( - scratch.available() >= scrach_needed, - "scratch.available()={} < GLWECiphertext::keyswitch_scratch_space()={scrach_needed}", - scratch.available(), - ); - } -} - -impl GLWE { - pub fn keyswitch( - &mut self, - module: &Module, - glwe_in: &GLWE, - rhs: &GLWESwitchingKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable, - { - #[cfg(debug_assertions)] - { - self.assert_keyswitch(module, glwe_in, rhs, scratch); - } - - let basek_out: usize = self.base2k().into(); - let basek_ksk: usize = rhs.base2k().into(); - - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // Todo optimise - let res_big: VecZnxBig<_, B> = glwe_in.keyswitch_internal(module, res_dft, rhs, scratch_1); - (0..(self.rank() + 1).into()).for_each(|i| { - module.vec_znx_big_normalize( + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), b.size()); // Todo optimise + let res_big: VecZnxBig<&mut [u8], BE> = keyswitch_internal(self, res_dft, a, b, scratch_1); + (0..(res.rank() + 1).into()).for_each(|i| { + self.vec_znx_big_normalize( basek_out, - &mut self.data, + &mut res.data, i, - basek_ksk, + base2k_out, &res_big, i, scratch_1, @@ -188,227 +181,190 @@ impl GLWE { }) } - pub fn keyswitch_inplace( - &mut self, - module: &Module, - rhs: &GLWESwitchingKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable, - { - #[cfg(debug_assertions)] - { - self.assert_keyswitch_inplace(module, rhs, scratch); - } - - let basek_in: usize = self.base2k().into(); - let basek_ksk: usize = rhs.base2k().into(); - - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // Todo optimise - let res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, rhs, scratch_1); - (0..(self.rank() + 1).into()).for_each(|i| { - module.vec_znx_big_normalize( - basek_in, - &mut self.data, - i, - basek_ksk, - &res_big, - i, - scratch_1, - ); - }) - } -} - -impl GLWE { - pub(crate) fn keyswitch_internal( - &self, - module: &Module, - res_dft: VecZnxDft, - rhs: &GLWESwitchingKeyPrepared, - scratch: &mut Scratch, - ) -> VecZnxBig + fn glwe_keyswitch_inplace(&self, res: &mut R, a: &A, scratch: &mut Scratch) where - DataRes: DataMut, - DataKey: DataRef, - Module: VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalize, - Scratch:, + R: GLWEToMut, + A: GLWESwitchingKeyPreparedToRef, + Scratch: ScratchTakeCore, { - if rhs.dsize() == 1 { - return keyswitch_vmp_one_digit( - module, - self.base2k().into(), - rhs.base2k().into(), - res_dft, - &self.data, - &rhs.key.data, - scratch, + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWESwitchingKeyPrepared<&[u8], BE> = &a.to_ref(); + + assert_eq!( + res.rank(), + a.rank_in(), + "res.rank(): {} != a.rank_in(): {}", + res.rank(), + a.rank_in() + ); + assert_eq!( + res.rank(), + a.rank_out(), + "res.rank(): {} != b.rank_out(): {}", + res.rank(), + a.rank_out() + ); + + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.n(), self.n() as u32); + + let scrach_needed: usize = self.glwe_keyswitch_tmp_bytes(res, res, a); + + assert!( + scratch.available() >= scrach_needed, + "scratch.available()={} < glwe_keyswitch_tmp_bytes={scrach_needed}", + scratch.available(), + ); + + let base2k_in: usize = res.base2k().into(); + let base2k_out: usize = a.base2k().into(); + + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), a.size()); // Todo optimise + let res_big: VecZnxBig<&mut [u8], BE> = keyswitch_internal(self, res_dft, res, a, scratch_1); + (0..(res.rank() + 1).into()).for_each(|i| { + self.vec_znx_big_normalize( + base2k_in, + &mut res.data, + i, + base2k_out, + &res_big, + i, + scratch_1, ); - } - - keyswitch_vmp_multiple_digits( - module, - self.base2k().into(), - rhs.base2k().into(), - res_dft, - &self.data, - &rhs.key.data, - rhs.dsize().into(), - scratch, - ) + }) } } -fn keyswitch_vmp_one_digit( - module: &Module, - basek_in: usize, - basek_ksk: usize, - mut res_dft: VecZnxDft, - a: &VecZnx, - mat: &VmpPMat, - scratch: &mut Scratch, -) -> VecZnxBig +impl GLWE> {} + +impl GLWE {} + +fn keyswitch_internal( + module: &M, + mut res: VecZnxDft, + a: &GLWE, + b: &GLWESwitchingKeyPrepared, + scratch: &mut Scratch, +) -> VecZnxBig where - DataRes: DataMut, - DataIn: DataRef, - DataVmp: DataRef, - Module: VecZnxDftBytesOf - + VecZnxDftApply - + VmpApplyDftToDft - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxNormalize, - Scratch:, + DR: DataMut, + DA: DataRef, + DB: DataRef, + M: ModuleN + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VecZnxBigNormalizeTmpBytes + + VmpApplyDftToDftTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxDftApply + + VecZnxIdftApplyConsume + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize + + VecZnxNormalize, + Scratch: ScratchTakeCore, { - let cols: usize = a.cols(); + let base2k_in: usize = a.base2k().into(); + let base2k_out: usize = b.base2k().into(); + let cols: usize = (a.rank() + 1).into(); + let a_size: usize = (a.size() * base2k_in).div_ceil(base2k_out); + let pmat: &VmpPMat = &b.key.data; - let a_size: usize = (a.size() * basek_in).div_ceil(basek_ksk); - let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(a.n(), cols - 1, a.size()); + if b.dsize() == 1 { + let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols - 1, a.size()); - if basek_in == basek_ksk { - (0..cols - 1).for_each(|col_i| { - module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, a, col_i + 1); - }); + if base2k_in == base2k_out { + (0..cols - 1).for_each(|col_i| { + module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, a.data(), col_i + 1); + }); + } else { + let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(module, 1, a_size); + (0..cols - 1).for_each(|col_i| { + module.vec_znx_normalize( + base2k_out, + &mut a_conv, + 0, + base2k_in, + a.data(), + col_i + 1, + scratch_2, + ); + module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, &a_conv, 0); + }); + } + + module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_1); } else { - let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(a.n(), 1, a_size); - (0..cols - 1).for_each(|col_i| { - module.vec_znx_normalize(basek_ksk, &mut a_conv, 0, basek_in, a, col_i + 1, scratch_2); - module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, &a_conv, 0); - }); + let dsize: usize = b.dsize().into(); + + let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols - 1, a_size.div_ceil(dsize)); + ai_dft.data_mut().fill(0); + + if base2k_in == base2k_out { + for di in 0..dsize { + ai_dft.set_size((a_size + di) / dsize); + + // Small optimization for dsize > 2 + // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then + // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}. + // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. + // It is possible to further ignore the last dsize-1 limbs, but this introduce + // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same + // noise is kept with respect to the ideal functionality. + res.set_size(pmat.size() - ((dsize - di) as isize - 2).max(0) as usize); + + for j in 0..cols - 1 { + module.vec_znx_dft_apply(dsize, dsize - di - 1, &mut ai_dft, j, a.data(), j + 1); + } + + if di == 0 { + module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_1); + } else { + module.vmp_apply_dft_to_dft_add(&mut res, &ai_dft, pmat, di, scratch_1); + } + } + } else { + let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(module, cols - 1, a_size); + for j in 0..cols - 1 { + module.vec_znx_normalize( + base2k_out, + &mut a_conv, + j, + base2k_in, + a.data(), + j + 1, + scratch_2, + ); + } + + for di in 0..dsize { + ai_dft.set_size((a_size + di) / dsize); + + // Small optimization for dsize > 2 + // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then + // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}. + // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. + // It is possible to further ignore the last dsize-1 limbs, but this introduce + // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same + // noise is kept with respect to the ideal functionality. + res.set_size(pmat.size() - ((dsize - di) as isize - 2).max(0) as usize); + + for j in 0..cols - 1 { + module.vec_znx_dft_apply(dsize, dsize - di - 1, &mut ai_dft, j, &a_conv, j); + } + + if di == 0 { + module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_2); + } else { + module.vmp_apply_dft_to_dft_add(&mut res, &ai_dft, pmat, di, scratch_2); + } + } + } + + res.set_size(res.max_size()); } - module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch_1); - let mut res_big: VecZnxBig = module.vec_znx_idft_apply_consume(res_dft); - module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0); - res_big -} - -#[allow(clippy::too_many_arguments)] -fn keyswitch_vmp_multiple_digits( - module: &Module, - basek_in: usize, - basek_ksk: usize, - mut res_dft: VecZnxDft, - a: &VecZnx, - mat: &VmpPMat, - dsize: usize, - scratch: &mut Scratch, -) -> VecZnxBig -where - DataRes: DataMut, - DataIn: DataRef, - DataVmp: DataRef, - Module: VecZnxDftBytesOf - + VecZnxDftApply - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxNormalize, - Scratch:, -{ - let cols: usize = a.cols(); - let a_size: usize = (a.size() * basek_in).div_ceil(basek_ksk); - let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(a.n(), cols - 1, a_size.div_ceil(dsize)); - ai_dft.data_mut().fill(0); - - if basek_in == basek_ksk { - for di in 0..dsize { - ai_dft.set_size((a_size + di) / dsize); - - // Small optimization for dsize > 2 - // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then - // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}. - // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last dsize-1 limbs, but this introduce - // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same - // noise is kept with respect to the ideal functionality. - res_dft.set_size(mat.size() - ((dsize - di) as isize - 2).max(0) as usize); - - for j in 0..cols - 1 { - module.vec_znx_dft_apply(dsize, dsize - di - 1, &mut ai_dft, j, a, j + 1); - } - - if di == 0 { - module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch_1); - } else { - module.vmp_apply_dft_to_dft_add(&mut res_dft, &ai_dft, mat, di, scratch_1); - } - } - } else { - let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(a.n(), cols - 1, a_size); - for j in 0..cols - 1 { - module.vec_znx_normalize(basek_ksk, &mut a_conv, j, basek_in, a, j + 1, scratch_2); - } - - for di in 0..dsize { - ai_dft.set_size((a_size + di) / dsize); - - // Small optimization for dsize > 2 - // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then - // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}. - // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last dsize-1 limbs, but this introduce - // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same - // noise is kept with respect to the ideal functionality. - res_dft.set_size(mat.size() - ((dsize - di) as isize - 2).max(0) as usize); - - for j in 0..cols - 1 { - module.vec_znx_dft_apply(dsize, dsize - di - 1, &mut ai_dft, j, &a_conv, j); - } - - if di == 0 { - module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch_2); - } else { - module.vmp_apply_dft_to_dft_add(&mut res_dft, &ai_dft, mat, di, scratch_2); - } - } - } - - res_dft.set_size(res_dft.max_size()); - let mut res_big: VecZnxBig = module.vec_znx_idft_apply_consume(res_dft); - module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0); + let mut res_big: VecZnxBig = module.vec_znx_idft_apply_consume(res); + module.vec_znx_big_add_small_inplace(&mut res_big, 0, a.data(), 0); res_big } diff --git a/poulpy-core/src/keyswitching/lwe_ct.rs b/poulpy-core/src/keyswitching/lwe_ct.rs index 938c732..e39259f 100644 --- a/poulpy-core/src/keyswitching/lwe_ct.rs +++ b/poulpy-core/src/keyswitching/lwe_ct.rs @@ -10,7 +10,7 @@ use poulpy_hal::{ use crate::layouts::{GGLWEInfos, GLWE, GLWELayout, LWE, LWEInfos, Rank, TorusPrecision, prepared::LWESwitchingKeyPrepared}; impl LWE> { - pub fn keyswitch_scratch_space( + pub fn keyswitch_tmp_bytes( module: &Module, out_infos: &OUT, in_infos: &IN, @@ -50,7 +50,7 @@ impl LWE> { let glwe_in: usize = GLWE::bytes_of_from_infos(module, &glwe_in_infos); let glwe_out: usize = GLWE::bytes_of_from_infos(module, &glwe_out_infos); - let ks: usize = GLWE::keyswitch_scratch_space(module, &glwe_out_infos, &glwe_in_infos, key_infos); + let ks: usize = GLWE::keyswitch_tmp_bytes(module, &glwe_out_infos, &glwe_in_infos, key_infos); glwe_in + glwe_out + ks } @@ -84,7 +84,7 @@ impl LWE { { assert!(self.n() <= module.n() as u32); assert!(a.n() <= module.n() as u32); - assert!(scratch.available() >= LWE::keyswitch_scratch_space(module, self, a, ksk)); + assert!(scratch.available() >= LWE::keyswitch_tmp_bytes(module, self, a, ksk)); } let max_k: TorusPrecision = self.k().max(a.k()); diff --git a/poulpy-core/src/lib.rs b/poulpy-core/src/lib.rs index b0c5a38..15e6c76 100644 --- a/poulpy-core/src/lib.rs +++ b/poulpy-core/src/lib.rs @@ -14,6 +14,7 @@ mod utils; pub use operations::*; pub mod layouts; +pub use conversion::*; pub use dist::*; pub use external_product::*; pub use glwe_packing::*; diff --git a/poulpy-core/src/noise/gglwe_ct.rs b/poulpy-core/src/noise/gglwe_ct.rs index c924192..516a210 100644 --- a/poulpy-core/src/noise/gglwe_ct.rs +++ b/poulpy-core/src/noise/gglwe_ct.rs @@ -35,7 +35,7 @@ impl GGLWE { let dsize: usize = self.dsize().into(); let base2k: usize = self.base2k().into(); - let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWE::decrypt_scratch_space(module, self)); + let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWE::decrypt_tmp_bytes(module, self)); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(module, self); (0..self.rank_in().into()).for_each(|col_i| { diff --git a/poulpy-core/src/noise/ggsw_ct.rs b/poulpy-core/src/noise/ggsw_ct.rs index 4f4c4f4..92f806b 100644 --- a/poulpy-core/src/noise/ggsw_ct.rs +++ b/poulpy-core/src/noise/ggsw_ct.rs @@ -48,7 +48,7 @@ impl GGSW { let mut pt_big: VecZnxBig, B> = module.vec_znx_big_alloc(1, self.size()); let mut scratch: ScratchOwned = - ScratchOwned::alloc(GLWE::decrypt_scratch_space(module, self) | module.vec_znx_normalize_tmp_bytes()); + ScratchOwned::alloc(GLWE::decrypt_tmp_bytes(module, self) | module.vec_znx_normalize_tmp_bytes()); (0..(self.rank() + 1).into()).for_each(|col_j| { (0..self.dnum().into()).for_each(|row_i| { @@ -120,7 +120,7 @@ impl GGSW { let mut pt_big: VecZnxBig, B> = module.vec_znx_big_alloc(1, self.size()); let mut scratch: ScratchOwned = - ScratchOwned::alloc(GLWE::decrypt_scratch_space(module, self) | module.vec_znx_normalize_tmp_bytes()); + ScratchOwned::alloc(GLWE::decrypt_tmp_bytes(module, self) | module.vec_znx_normalize_tmp_bytes()); (0..(self.rank() + 1).into()).for_each(|col_j| { (0..self.dnum().into()).for_each(|row_i| { diff --git a/poulpy-core/src/noise/glwe_ct.rs b/poulpy-core/src/noise/glwe_ct.rs index 3242d91..40b86d9 100644 --- a/poulpy-core/src/noise/glwe_ct.rs +++ b/poulpy-core/src/noise/glwe_ct.rs @@ -61,7 +61,7 @@ impl GLWE { + VecZnxNormalizeInplace, B: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, { - let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWE::decrypt_scratch_space(module, self)); + let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWE::decrypt_tmp_bytes(module, self)); let noise_have: f64 = self.noise(module, sk_prepared, pt_want, scratch.borrow()); assert!(noise_have <= max_noise, "{noise_have} {max_noise}"); } diff --git a/poulpy-core/src/operations/glwe.rs b/poulpy-core/src/operations/glwe.rs index ea4cf9d..b8b32ce 100644 --- a/poulpy-core/src/operations/glwe.rs +++ b/poulpy-core/src/operations/glwe.rs @@ -246,10 +246,10 @@ pub trait GLWEOperations: GLWEToMut + GLWEInfos + SetGLWEInfos + Sized { }); } - fn copy(&mut self, module: &Module, a: &A) + fn copy(&mut self, module: &M, a: &A) where A: GLWEToRef + GLWEInfos, - Module: VecZnxCopy, + M: VecZnxCopy, { #[cfg(debug_assertions)] { @@ -319,8 +319,8 @@ pub trait GLWEOperations: GLWEToMut + GLWEInfos + SetGLWEInfos + Sized { } impl GLWE> { - pub fn rsh_scratch_space(n: usize) -> usize { - VecZnx::rsh_scratch_space(n) + pub fn rsh_tmp_bytes(n: usize) -> usize { + VecZnx::rsh_tmp_bytes(n) } } diff --git a/poulpy-core/src/tests/test_suite/automorphism/gglwe_atk.rs b/poulpy-core/src/tests/test_suite/automorphism/gglwe_atk.rs index b98bfaf..37b5403 100644 --- a/poulpy-core/src/tests/test_suite/automorphism/gglwe_atk.rs +++ b/poulpy-core/src/tests/test_suite/automorphism/gglwe_atk.rs @@ -120,9 +120,9 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - AutomorphismKey::encrypt_sk_scratch_space(module, &auto_key_in_infos) - | AutomorphismKey::encrypt_sk_scratch_space(module, &auto_key_apply_infos) - | AutomorphismKey::automorphism_scratch_space( + AutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key_in_infos) + | AutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key_apply_infos) + | AutomorphismKey::automorphism_tmp_bytes( module, &auto_key_out_infos, &auto_key_in_infos, @@ -319,9 +319,9 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - AutomorphismKey::encrypt_sk_scratch_space(module, &auto_key) - | AutomorphismKey::encrypt_sk_scratch_space(module, &auto_key_apply) - | AutomorphismKey::automorphism_inplace_scratch_space(module, &auto_key, &auto_key_apply), + AutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key) + | AutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key_apply) + | AutomorphismKey::automorphism_inplace_tmp_bytes(module, &auto_key, &auto_key_apply), ); let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&auto_key); diff --git a/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs index 17f182f..7c2a427 100644 --- a/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs +++ b/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs @@ -139,10 +139,10 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSW::encrypt_sk_scratch_space(module, &ct_in) - | AutomorphismKey::encrypt_sk_scratch_space(module, &auto_key) - | TensorKey::encrypt_sk_scratch_space(module, &tensor_key) - | GGSW::automorphism_scratch_space(module, &ct_out, &ct_in, &auto_key, &tensor_key), + GGSW::encrypt_sk_tmp_bytes(module, &ct_in) + | AutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key) + | TensorKey::encrypt_sk_tmp_bytes(module, &tensor_key) + | GGSW::automorphism_tmp_bytes(module, &ct_out, &ct_in, &auto_key, &tensor_key), ); let var_xs: f64 = 0.5; @@ -319,10 +319,10 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSW::encrypt_sk_scratch_space(module, &ct) - | AutomorphismKey::encrypt_sk_scratch_space(module, &auto_key) - | TensorKey::encrypt_sk_scratch_space(module, &tensor_key) - | GGSW::automorphism_inplace_scratch_space(module, &ct, &auto_key, &tensor_key), + GGSW::encrypt_sk_tmp_bytes(module, &ct) + | AutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key) + | TensorKey::encrypt_sk_tmp_bytes(module, &tensor_key) + | GGSW::automorphism_inplace_tmp_bytes(module, &ct, &auto_key, &tensor_key), ); let var_xs: f64 = 0.5; diff --git a/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs b/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs index 49643fe..02afcb8 100644 --- a/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs @@ -112,10 +112,10 @@ where module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::alloc( - AutomorphismKey::encrypt_sk_scratch_space(module, &autokey) - | GLWE::decrypt_scratch_space(module, &ct_out) - | GLWE::encrypt_sk_scratch_space(module, &ct_in) - | GLWE::automorphism_scratch_space(module, &ct_out, &ct_in, &autokey), + AutomorphismKey::encrypt_sk_tmp_bytes(module, &autokey) + | GLWE::decrypt_tmp_bytes(module, &ct_out) + | GLWE::encrypt_sk_tmp_bytes(module, &ct_in) + | GLWE::automorphism_tmp_bytes(module, &ct_out, &ct_in, &autokey), ); let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&ct_out); @@ -246,10 +246,10 @@ where module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::alloc( - AutomorphismKey::encrypt_sk_scratch_space(module, &autokey) - | GLWE::decrypt_scratch_space(module, &ct) - | GLWE::encrypt_sk_scratch_space(module, &ct) - | GLWE::automorphism_inplace_scratch_space(module, &ct, &autokey), + AutomorphismKey::encrypt_sk_tmp_bytes(module, &autokey) + | GLWE::decrypt_tmp_bytes(module, &ct) + | GLWE::encrypt_sk_tmp_bytes(module, &ct) + | GLWE::automorphism_inplace_tmp_bytes(module, &ct, &autokey), ); let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&ct); diff --git a/poulpy-core/src/tests/test_suite/conversion.rs b/poulpy-core/src/tests/test_suite/conversion.rs index 94d922e..972e0c6 100644 --- a/poulpy-core/src/tests/test_suite/conversion.rs +++ b/poulpy-core/src/tests/test_suite/conversion.rs @@ -96,9 +96,9 @@ where }; let mut scratch: ScratchOwned = ScratchOwned::alloc( - LWEToGLWESwitchingKey::encrypt_sk_scratch_space(module, &lwe_to_glwe_infos) - | GLWE::from_lwe_scratch_space(module, &glwe_infos, &lwe_infos, &lwe_to_glwe_infos) - | GLWE::decrypt_scratch_space(module, &glwe_infos), + LWEToGLWESwitchingKey::encrypt_sk_tmp_bytes(module, &lwe_to_glwe_infos) + | GLWE::from_lwe_tmp_bytes(module, &glwe_infos, &lwe_infos, &lwe_to_glwe_infos) + | GLWE::decrypt_tmp_bytes(module, &glwe_infos), ); let mut sk_glwe: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_infos); @@ -213,9 +213,9 @@ where let mut source_xe: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWEToLWESwitchingKey::encrypt_sk_scratch_space(module, &glwe_to_lwe_infos) - | LWE::from_glwe_scratch_space(module, &lwe_infos, &glwe_infos, &glwe_to_lwe_infos) - | GLWE::decrypt_scratch_space(module, &glwe_infos), + GLWEToLWESwitchingKey::encrypt_sk_tmp_bytes(module, &glwe_to_lwe_infos) + | LWE::from_glwe_tmp_bytes(module, &lwe_infos, &glwe_infos, &glwe_to_lwe_infos) + | GLWE::decrypt_tmp_bytes(module, &glwe_infos), ); let mut sk_glwe: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_infos); diff --git a/poulpy-core/src/tests/test_suite/encryption/gglwe_atk.rs b/poulpy-core/src/tests/test_suite/encryption/gglwe_atk.rs index 68fca13..bd164dc 100644 --- a/poulpy-core/src/tests/test_suite/encryption/gglwe_atk.rs +++ b/poulpy-core/src/tests/test_suite/encryption/gglwe_atk.rs @@ -90,7 +90,7 @@ where let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc(AutomorphismKey::encrypt_sk_scratch_space( + let mut scratch: ScratchOwned = ScratchOwned::alloc(AutomorphismKey::encrypt_sk_tmp_bytes( module, &atk_infos, )); @@ -192,7 +192,7 @@ where let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc(AutomorphismKey::encrypt_sk_scratch_space( + let mut scratch: ScratchOwned = ScratchOwned::alloc(AutomorphismKey::encrypt_sk_tmp_bytes( module, &atk_infos, )); diff --git a/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs b/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs index e3450f9..5e09781 100644 --- a/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs @@ -87,7 +87,7 @@ where let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWESwitchingKey::encrypt_sk_scratch_space( + let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWESwitchingKey::encrypt_sk_tmp_bytes( module, &gglwe_infos, )); @@ -179,7 +179,7 @@ where let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWESwitchingKeyCompressed::encrypt_sk_scratch_space( + let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWESwitchingKeyCompressed::encrypt_sk_tmp_bytes( module, &gglwe_infos, )); diff --git a/poulpy-core/src/tests/test_suite/encryption/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/encryption/ggsw_ct.rs index c419bb4..7620709 100644 --- a/poulpy-core/src/tests/test_suite/encryption/ggsw_ct.rs +++ b/poulpy-core/src/tests/test_suite/encryption/ggsw_ct.rs @@ -55,7 +55,7 @@ where pt_scalar.fill_ternary_hw(0, n, &mut source_xs); - let mut scratch: ScratchOwned = ScratchOwned::alloc(GGSW::encrypt_sk_scratch_space(module, &ggsw_infos)); + let mut scratch: ScratchOwned = ScratchOwned::alloc(GGSW::encrypt_sk_tmp_bytes(module, &ggsw_infos)); let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&ggsw_infos); sk.fill_ternary_prob(0.5, &mut source_xs); @@ -144,7 +144,7 @@ where pt_scalar.fill_ternary_hw(0, n, &mut source_xs); - let mut scratch: ScratchOwned = ScratchOwned::alloc(GGSWCompressed::encrypt_sk_scratch_space( + let mut scratch: ScratchOwned = ScratchOwned::alloc(GGSWCompressed::encrypt_sk_tmp_bytes( module, &ggsw_infos, )); diff --git a/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs b/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs index 9094c33..ed2c1e3 100644 --- a/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs @@ -85,7 +85,7 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWE::encrypt_sk_scratch_space(module, &glwe_infos) | GLWE::decrypt_scratch_space(module, &glwe_infos), + GLWE::encrypt_sk_tmp_bytes(module, &glwe_infos) | GLWE::decrypt_tmp_bytes(module, &glwe_infos), ); let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_infos); @@ -178,7 +178,7 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECompressed::encrypt_sk_scratch_space(module, &glwe_infos) | GLWE::decrypt_scratch_space(module, &glwe_infos), + GLWECompressed::encrypt_sk_tmp_bytes(module, &glwe_infos) | GLWE::decrypt_tmp_bytes(module, &glwe_infos), ); let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_infos); @@ -269,7 +269,7 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWE::decrypt_scratch_space(module, &glwe_infos) | GLWE::encrypt_sk_scratch_space(module, &glwe_infos), + GLWE::decrypt_tmp_bytes(module, &glwe_infos) | GLWE::encrypt_sk_tmp_bytes(module, &glwe_infos), ); let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_infos); @@ -349,9 +349,9 @@ where let mut source_xu: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWE::encrypt_sk_scratch_space(module, &glwe_infos) - | GLWE::decrypt_scratch_space(module, &glwe_infos) - | GLWE::encrypt_pk_scratch_space(module, &glwe_infos), + GLWE::encrypt_sk_tmp_bytes(module, &glwe_infos) + | GLWE::decrypt_tmp_bytes(module, &glwe_infos) + | GLWE::encrypt_pk_tmp_bytes(module, &glwe_infos), ); let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_infos); diff --git a/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs b/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs index 1f369db..da05ba1 100644 --- a/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs +++ b/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs @@ -86,7 +86,7 @@ where let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc(TensorKey::encrypt_sk_scratch_space( + let mut scratch: ScratchOwned = ScratchOwned::alloc(TensorKey::encrypt_sk_tmp_bytes( module, &tensor_key_infos, )); @@ -204,7 +204,7 @@ where let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc(TensorKeyCompressed::encrypt_sk_scratch_space( + let mut scratch: ScratchOwned = ScratchOwned::alloc(TensorKeyCompressed::encrypt_sk_tmp_bytes( module, &tensor_key_infos, )); diff --git a/poulpy-core/src/tests/test_suite/external_product/gglwe_ksk.rs b/poulpy-core/src/tests/test_suite/external_product/gglwe_ksk.rs index f10643b..a369978 100644 --- a/poulpy-core/src/tests/test_suite/external_product/gglwe_ksk.rs +++ b/poulpy-core/src/tests/test_suite/external_product/gglwe_ksk.rs @@ -121,14 +121,14 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWESwitchingKey::encrypt_sk_scratch_space(module, &gglwe_in_infos) - | GLWESwitchingKey::external_product_scratch_space( + GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &gglwe_in_infos) + | GLWESwitchingKey::external_product_tmp_bytes( module, &gglwe_out_infos, &gglwe_in_infos, &ggsw_infos, ) - | GGSW::encrypt_sk_scratch_space(module, &ggsw_infos), + | GGSW::encrypt_sk_tmp_bytes(module, &ggsw_infos), ); let r: usize = 1; @@ -292,9 +292,9 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWESwitchingKey::encrypt_sk_scratch_space(module, &gglwe_out_infos) - | GLWESwitchingKey::external_product_inplace_scratch_space(module, &gglwe_out_infos, &ggsw_infos) - | GGSW::encrypt_sk_scratch_space(module, &ggsw_infos), + GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &gglwe_out_infos) + | GLWESwitchingKey::external_product_inplace_tmp_bytes(module, &gglwe_out_infos, &ggsw_infos) + | GGSW::encrypt_sk_tmp_bytes(module, &ggsw_infos), ); let r: usize = 1; diff --git a/poulpy-core/src/tests/test_suite/external_product/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/external_product/ggsw_ct.rs index 8d62b19..84a2f68 100644 --- a/poulpy-core/src/tests/test_suite/external_product/ggsw_ct.rs +++ b/poulpy-core/src/tests/test_suite/external_product/ggsw_ct.rs @@ -128,9 +128,9 @@ where pt_apply.to_mut().raw_mut()[k] = 1; //X^{k} let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSW::encrypt_sk_scratch_space(module, &ggsw_apply_infos) - | GGSW::encrypt_sk_scratch_space(module, &ggsw_in_infos) - | GGSW::external_product_scratch_space(module, &ggsw_out_infos, &ggsw_in_infos, &ggsw_apply_infos), + GGSW::encrypt_sk_tmp_bytes(module, &ggsw_apply_infos) + | GGSW::encrypt_sk_tmp_bytes(module, &ggsw_in_infos) + | GGSW::external_product_tmp_bytes(module, &ggsw_out_infos, &ggsw_in_infos, &ggsw_apply_infos), ); let mut sk: GLWESecret> = GLWESecret::alloc(n.into(), rank.into()); @@ -282,9 +282,9 @@ where pt_apply.to_mut().raw_mut()[k] = 1; //X^{k} let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSW::encrypt_sk_scratch_space(module, &ggsw_apply_infos) - | GGSW::encrypt_sk_scratch_space(module, &ggsw_out_infos) - | GGSW::external_product_inplace_scratch_space(module, &ggsw_out_infos, &ggsw_apply_infos), + GGSW::encrypt_sk_tmp_bytes(module, &ggsw_apply_infos) + | GGSW::encrypt_sk_tmp_bytes(module, &ggsw_out_infos) + | GGSW::external_product_inplace_tmp_bytes(module, &ggsw_out_infos, &ggsw_apply_infos), ); let mut sk: GLWESecret> = GLWESecret::alloc(n.into(), rank.into()); diff --git a/poulpy-core/src/tests/test_suite/external_product/glwe_ct.rs b/poulpy-core/src/tests/test_suite/external_product/glwe_ct.rs index 08f9695..60026b8 100644 --- a/poulpy-core/src/tests/test_suite/external_product/glwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/external_product/glwe_ct.rs @@ -116,9 +116,9 @@ where pt_ggsw.raw_mut()[k] = 1; // X^{k} let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSW::encrypt_sk_scratch_space(module, &ggsw_apply_infos) - | GLWE::encrypt_sk_scratch_space(module, &glwe_in_infos) - | GLWE::external_product_scratch_space(module, &glwe_out_infos, &glwe_in_infos, &ggsw_apply_infos), + GGSW::encrypt_sk_tmp_bytes(module, &ggsw_apply_infos) + | GLWE::encrypt_sk_tmp_bytes(module, &glwe_in_infos) + | GLWE::external_product_tmp_bytes(module, &glwe_out_infos, &glwe_in_infos, &ggsw_apply_infos), ); let mut sk: GLWESecret> = GLWESecret::alloc(n.into(), rank.into()); @@ -259,9 +259,9 @@ where pt_ggsw.raw_mut()[k] = 1; // X^{k} let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSW::encrypt_sk_scratch_space(module, &ggsw_apply_infos) - | GLWE::encrypt_sk_scratch_space(module, &glwe_out_infos) - | GLWE::external_product_inplace_scratch_space(module, &glwe_out_infos, &ggsw_apply_infos), + GGSW::encrypt_sk_tmp_bytes(module, &ggsw_apply_infos) + | GLWE::encrypt_sk_tmp_bytes(module, &glwe_out_infos) + | GLWE::external_product_inplace_tmp_bytes(module, &glwe_out_infos, &ggsw_apply_infos), ); let mut sk: GLWESecret> = GLWESecret::alloc(n.into(), rank.into()); diff --git a/poulpy-core/src/tests/test_suite/keyswitch/gglwe_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/gglwe_ct.rs index 90c179b..c667848 100644 --- a/poulpy-core/src/tests/test_suite/keyswitch/gglwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/keyswitch/gglwe_ct.rs @@ -119,11 +119,11 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch_enc: ScratchOwned = ScratchOwned::alloc( - GLWESwitchingKey::encrypt_sk_scratch_space(module, &gglwe_s0s1_infos) - | GLWESwitchingKey::encrypt_sk_scratch_space(module, &gglwe_s1s2_infos) - | GLWESwitchingKey::encrypt_sk_scratch_space(module, &gglwe_s0s2_infos), + GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &gglwe_s0s1_infos) + | GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &gglwe_s1s2_infos) + | GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &gglwe_s0s2_infos), ); - let mut scratch_apply: ScratchOwned = ScratchOwned::alloc(GLWESwitchingKey::keyswitch_scratch_space( + let mut scratch_apply: ScratchOwned = ScratchOwned::alloc(GLWESwitchingKey::keyswitch_tmp_bytes( module, &gglwe_s0s1_infos, &gglwe_s0s2_infos, @@ -274,10 +274,10 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch_enc: ScratchOwned = ScratchOwned::alloc( - GLWESwitchingKey::encrypt_sk_scratch_space(module, &gglwe_s0s1_infos) - | GLWESwitchingKey::encrypt_sk_scratch_space(module, &gglwe_s1s2_infos), + GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &gglwe_s0s1_infos) + | GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &gglwe_s1s2_infos), ); - let mut scratch_apply: ScratchOwned = ScratchOwned::alloc(GLWESwitchingKey::keyswitch_inplace_scratch_space( + let mut scratch_apply: ScratchOwned = ScratchOwned::alloc(GLWESwitchingKey::keyswitch_inplace_tmp_bytes( module, &gglwe_s0s1_infos, &gglwe_s1s2_infos, diff --git a/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs index 6d74811..4d3d556 100644 --- a/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs +++ b/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs @@ -132,10 +132,10 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSW::encrypt_sk_scratch_space(module, &ggsw_in_infos) - | GLWESwitchingKey::encrypt_sk_scratch_space(module, &ksk_apply_infos) - | TensorKey::encrypt_sk_scratch_space(module, &tsk_infos) - | GGSW::keyswitch_scratch_space( + GGSW::encrypt_sk_tmp_bytes(module, &ggsw_in_infos) + | GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &ksk_apply_infos) + | TensorKey::encrypt_sk_tmp_bytes(module, &tsk_infos) + | GGSW::keyswitch_tmp_bytes( module, &ggsw_out_infos, &ggsw_in_infos, @@ -310,10 +310,10 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSW::encrypt_sk_scratch_space(module, &ggsw_out_infos) - | GLWESwitchingKey::encrypt_sk_scratch_space(module, &ksk_apply_infos) - | TensorKey::encrypt_sk_scratch_space(module, &tsk_infos) - | GGSW::keyswitch_inplace_scratch_space(module, &ggsw_out_infos, &ksk_apply_infos, &tsk_infos), + GGSW::encrypt_sk_tmp_bytes(module, &ggsw_out_infos) + | GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &ksk_apply_infos) + | TensorKey::encrypt_sk_tmp_bytes(module, &tsk_infos) + | GGSW::keyswitch_inplace_tmp_bytes(module, &ggsw_out_infos, &ksk_apply_infos, &tsk_infos), ); let var_xs: f64 = 0.5; diff --git a/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs index 7f629b8..2ea6e75 100644 --- a/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs @@ -112,9 +112,9 @@ where module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWESwitchingKey::encrypt_sk_scratch_space(module, &key_apply) - | GLWE::encrypt_sk_scratch_space(module, &glwe_in_infos) - | GLWE::keyswitch_scratch_space(module, &glwe_out_infos, &glwe_in_infos, &key_apply), + GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &key_apply) + | GLWE::encrypt_sk_tmp_bytes(module, &glwe_in_infos) + | GLWE::keyswitch_tmp_bytes(module, &glwe_out_infos, &glwe_in_infos, &key_apply), ); let mut sk_in: GLWESecret> = GLWESecret::alloc(n.into(), rank_in.into()); @@ -244,9 +244,9 @@ where module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWESwitchingKey::encrypt_sk_scratch_space(module, &key_apply_infos) - | GLWE::encrypt_sk_scratch_space(module, &glwe_out_infos) - | GLWE::keyswitch_inplace_scratch_space(module, &glwe_out_infos, &key_apply_infos), + GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &key_apply_infos) + | GLWE::encrypt_sk_tmp_bytes(module, &glwe_out_infos) + | GLWE::keyswitch_inplace_tmp_bytes(module, &glwe_out_infos, &key_apply_infos), ); let mut sk_in: GLWESecret> = GLWESecret::alloc(n.into(), rank.into()); diff --git a/poulpy-core/src/tests/test_suite/keyswitch/lwe_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/lwe_ct.rs index b833209..130c356 100644 --- a/poulpy-core/src/tests/test_suite/keyswitch/lwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/keyswitch/lwe_ct.rs @@ -99,8 +99,8 @@ where }; let mut scratch: ScratchOwned = ScratchOwned::alloc( - LWESwitchingKey::encrypt_sk_scratch_space(module, &key_apply_infos) - | LWE::keyswitch_scratch_space(module, &lwe_out_infos, &lwe_in_infos, &key_apply_infos), + LWESwitchingKey::encrypt_sk_tmp_bytes(module, &key_apply_infos) + | LWE::keyswitch_tmp_bytes(module, &lwe_out_infos, &lwe_in_infos, &key_apply_infos), ); let mut sk_lwe_in: LWESecret> = LWESecret::alloc(n_lwe_in.into()); diff --git a/poulpy-core/src/tests/test_suite/packing.rs b/poulpy-core/src/tests/test_suite/packing.rs index ba6d250..fdfbd57 100644 --- a/poulpy-core/src/tests/test_suite/packing.rs +++ b/poulpy-core/src/tests/test_suite/packing.rs @@ -105,9 +105,9 @@ where }; let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWE::encrypt_sk_scratch_space(module, &glwe_out_infos) - | AutomorphismKey::encrypt_sk_scratch_space(module, &key_infos) - | GLWEPacker::scratch_space(module, &glwe_out_infos, &key_infos), + GLWE::encrypt_sk_tmp_bytes(module, &glwe_out_infos) + | AutomorphismKey::encrypt_sk_tmp_bytes(module, &key_infos) + | GLWEPacker::tmp_bytes(module, &glwe_out_infos, &key_infos), ); let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_out_infos); diff --git a/poulpy-core/src/tests/test_suite/trace.rs b/poulpy-core/src/tests/test_suite/trace.rs index 2f5b5a5..932b401 100644 --- a/poulpy-core/src/tests/test_suite/trace.rs +++ b/poulpy-core/src/tests/test_suite/trace.rs @@ -107,10 +107,10 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWE::encrypt_sk_scratch_space(module, &glwe_out_infos) - | GLWE::decrypt_scratch_space(module, &glwe_out_infos) - | AutomorphismKey::encrypt_sk_scratch_space(module, &key_infos) - | GLWE::trace_inplace_scratch_space(module, &glwe_out_infos, &key_infos), + GLWE::encrypt_sk_tmp_bytes(module, &glwe_out_infos) + | GLWE::decrypt_tmp_bytes(module, &glwe_out_infos) + | AutomorphismKey::encrypt_sk_tmp_bytes(module, &key_infos) + | GLWE::trace_inplace_tmp_bytes(module, &glwe_out_infos, &key_infos), ); let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_out_infos); diff --git a/poulpy-hal/src/layouts/vec_znx.rs b/poulpy-hal/src/layouts/vec_znx.rs index c96de28..c084934 100644 --- a/poulpy-hal/src/layouts/vec_znx.rs +++ b/poulpy-hal/src/layouts/vec_znx.rs @@ -110,7 +110,7 @@ impl ZnxView for VecZnx { } impl VecZnx> { - pub fn rsh_scratch_space(n: usize) -> usize { + pub fn rsh_tmp_bytes(n: usize) -> usize { n * std::mem::size_of::() } } diff --git a/poulpy-hal/src/layouts/zn.rs b/poulpy-hal/src/layouts/zn.rs index eaba4b4..00f8067 100644 --- a/poulpy-hal/src/layouts/zn.rs +++ b/poulpy-hal/src/layouts/zn.rs @@ -98,7 +98,7 @@ impl ZnxView for Zn { } impl Zn> { - pub fn rsh_scratch_space(n: usize) -> usize { + pub fn rsh_tmp_bytes(n: usize) -> usize { n * std::mem::size_of::() } } diff --git a/poulpy-schemes/src/tfhe/blind_rotation/cggi_algo.rs b/poulpy-schemes/src/tfhe/blind_rotation/cggi_algo.rs index 1c2ebff..9a33e5e 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/cggi_algo.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/cggi_algo.rs @@ -21,7 +21,7 @@ use crate::tfhe::blind_rotation::{ }; #[allow(clippy::too_many_arguments)] -pub fn cggi_blind_rotate_scratch_space( +pub fn cggi_blind_rotate_tmp_bytes( module: &Module, block_size: usize, extension_factor: usize, @@ -61,7 +61,7 @@ where + vmp_xai + (vmp | (acc_big + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_apply_tmp_bytes()))) } else { - GLWE::bytes_of(glwe_infos) + GLWE::external_product_inplace_scratch_space(module, glwe_infos, brk_infos) + GLWE::bytes_of(glwe_infos) + GLWE::external_product_inplace_tmp_bytes(module, glwe_infos, brk_infos) } } diff --git a/poulpy-schemes/src/tfhe/blind_rotation/cggi_key.rs b/poulpy-schemes/src/tfhe/blind_rotation/cggi_key.rs index 26c9721..3114ea1 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/cggi_key.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/cggi_key.rs @@ -43,12 +43,12 @@ impl BlindRotationKeyAlloc for BlindRotationKey, CGGI> { } impl BlindRotationKey, CGGI> { - pub fn generate_from_sk_scratch_space(module: &Module, infos: &A) -> usize + pub fn generate_from_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGSWInfos, Module: VecZnxNormalizeTmpBytes + VecZnxDftBytesOf, { - GGSW::encrypt_sk_scratch_space(module, infos) + GGSW::encrypt_sk_tmp_bytes(module, infos) } } @@ -145,12 +145,12 @@ impl BlindRotationKeyCompressed, CGGI> { } } - pub fn generate_from_sk_scratch_space(module: &Module, infos: &A) -> usize + pub fn generate_from_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGSWInfos, Module: VecZnxNormalizeTmpBytes + VecZnxDftBytesOf, { - GGSWCompressed::encrypt_sk_scratch_space(module, infos) + GGSWCompressed::encrypt_sk_tmp_bytes(module, infos) } } diff --git a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs index 9148df3..9cae713 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs @@ -19,7 +19,7 @@ use poulpy_hal::{ use crate::tfhe::blind_rotation::{ BlincRotationExecute, BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyEncryptSk, BlindRotationKeyLayout, - BlindRotationKeyPrepared, CGGI, LookUpTable, cggi_blind_rotate_scratch_space, mod_switch_2n, + BlindRotationKeyPrepared, CGGI, LookUpTable, cggi_blind_rotate_tmp_bytes, mod_switch_2n, }; use poulpy_core::layouts::{ @@ -123,7 +123,7 @@ where base2k: base2k.into(), }; - let mut scratch: ScratchOwned = ScratchOwned::::alloc(BlindRotationKey::generate_from_sk_scratch_space( + let mut scratch: ScratchOwned = ScratchOwned::::alloc(BlindRotationKey::generate_from_sk_tmp_bytes( module, &brk_infos, )); @@ -134,7 +134,7 @@ where let mut sk_lwe: LWESecret> = LWESecret::alloc(n_lwe.into()); sk_lwe.fill_binary_block(block_size, &mut source_xs); - let mut scratch_br: ScratchOwned = ScratchOwned::::alloc(cggi_blind_rotate_scratch_space( + let mut scratch_br: ScratchOwned = ScratchOwned::::alloc(cggi_blind_rotate_tmp_bytes( module, block_size, extension_factor,