diff --git a/poulpy-core/src/encryption/compressed/gglwe_atk.rs b/poulpy-core/src/encryption/compressed/gglwe_atk.rs index f0afcae..1ee58a6 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_atk.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_atk.rs @@ -1,10 +1,11 @@ use poulpy_hal::{ - api::{ScratchAvailable, SvpPPolBytesOf, VecZnxAutomorphism, VecZnxDftBytesOf, VecZnxNormalizeTmpBytes}, + api::{ModuleN, ScratchAvailable, ScratchTakeBasic, SvpPPolAlloc, SvpPPolBytesOf, VecZnxAutomorphism, VecZnxDftBytesOf, VecZnxNormalizeTmpBytes}, layouts::{Backend, DataMut, DataRef, Module, Scratch}, source::Source, }; use crate::{ + ScratchTakeCore, encryption::compressed::gglwe_ksk::GGLWEKeyCompressedEncryptSk, layouts::{ GGLWEInfos, GLWEInfos, GLWESecret, GLWESecretToRef, LWEInfos, @@ -16,10 +17,10 @@ impl AutomorphismKeyCompressed> { pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGLWEInfos, - Module: VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes + SvpPPolBytesOf, + Module: ModuleN + SvpPPolAlloc + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes + SvpPPolBytesOf, { assert_eq!(module.n() as u32, infos.n()); - GLWESwitchingKeyCompressed::encrypt_sk_tmp_bytes(module, infos) + GLWESecret::bytes_of(infos.n(), infos.rank_out()) + GLWESwitchingKeyCompressed::encrypt_sk_tmp_bytes(module, infos) + GLWESecret::bytes_of(module, infos.rank_out()) } } @@ -39,8 +40,14 @@ pub trait GGLWEAutomorphismKeyCompressedEncryptSk { impl GGLWEAutomorphismKeyCompressedEncryptSk for Module where - Module: GGLWEKeyCompressedEncryptSk + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + SvpPPolBytesOf + VecZnxAutomorphism, - Scratch: ScratchAvailable, + Module: ModuleN + + GGLWEKeyCompressedEncryptSk + + VecZnxNormalizeTmpBytes + + VecZnxDftBytesOf + + SvpPPolBytesOf + + VecZnxAutomorphism + + SvpPPolAlloc, + Scratch: ScratchAvailable + ScratchTakeBasic + ScratchTakeCore, { fn gglwe_automorphism_key_compressed_encrypt_sk( &self, @@ -70,7 +77,7 @@ where ) } - let (mut sk_out, scratch_1) = scratch.take_glwe_secret(sk.n(), sk.rank()); + let (mut sk_out, scratch_1) = scratch.take_glwe_secret(self, sk.rank()); { (0..res.rank_out().into()).for_each(|i| { diff --git a/poulpy-core/src/encryption/compressed/gglwe_ct.rs b/poulpy-core/src/encryption/compressed/gglwe_ct.rs index b67dc88..cca081f 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_ct.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_ct.rs @@ -1,6 +1,6 @@ use poulpy_hal::{ api::{ - ScratchAvailable, VecZnxAddScalarInplace, VecZnxDftBytesOf, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, + ModuleN, ScratchAvailable, VecZnxAddScalarInplace, VecZnxDftBytesOf, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, ZnNormalizeInplace, }, layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxZero}, @@ -8,6 +8,7 @@ use poulpy_hal::{ }; use crate::{ + ScratchTakeCore, encryption::{SIGMA, glwe_ct::GLWEEncryptSkInternal}, layouts::{ GGLWE, GGLWEInfos, LWEInfos, @@ -60,13 +61,14 @@ pub trait GGLWECompressedEncryptSk { impl GGLWECompressedEncryptSk for Module where - Module: GLWEEncryptSkInternal + Module: ModuleN + + GLWEEncryptSkInternal + VecZnxNormalizeInplace + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxAddScalarInplace + ZnNormalizeInplace, - Scratch: ScratchAvailable, + Scratch: ScratchAvailable + ScratchTakeCore, { fn gglwe_compressed_encrypt_sk( &self, @@ -130,7 +132,7 @@ where let mut source_xa = Source::new(seed); - let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(res); + let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(self, res); (0..rank_in).for_each(|col_i| { (0..dnum).for_each(|d_i| { // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt diff --git a/poulpy-core/src/encryption/compressed/gglwe_ksk.rs b/poulpy-core/src/encryption/compressed/gglwe_ksk.rs index 93519b9..3dca15f 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_ksk.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_ksk.rs @@ -1,10 +1,11 @@ use poulpy_hal::{ - api::{ScratchAvailable, SvpPPolBytesOf, SvpPrepare, VecZnxDftBytesOf, VecZnxNormalizeTmpBytes, VecZnxSwitchRing}, + api::{ModuleN, ScratchAvailable, ScratchTakeBasic, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxDftBytesOf, VecZnxNormalizeTmpBytes, VecZnxSwitchRing}, layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch}, source::Source, }; use crate::{ + ScratchTakeCore, encryption::compressed::gglwe_ct::GGLWECompressedEncryptSk, layouts::{ GGLWE, GGLWEInfos, GLWEInfos, GLWESecret, GLWESecretToRef, LWEInfos, RingDegree, @@ -17,7 +18,7 @@ impl GLWESwitchingKeyCompressed> { pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGLWEInfos, - Module: VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes + SvpPPolBytesOf, + Module: ModuleN + SvpPPolAlloc + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes + SvpPPolBytesOf, { (GGLWE::encrypt_sk_tmp_bytes(module, infos) | ScalarZnx::bytes_of(module.n(), 1)) + ScalarZnx::bytes_of(module.n(), infos.rank_in().into()) @@ -59,13 +60,15 @@ pub trait GGLWEKeyCompressedEncryptSk { impl GGLWEKeyCompressedEncryptSk for Module where - Module: GGLWECompressedEncryptSk + Module: ModuleN + + GGLWECompressedEncryptSk + SvpPPolBytesOf + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxSwitchRing - + SvpPrepare, - Scratch: ScratchAvailable, + + SvpPrepare + + SvpPPolAlloc, + Scratch: ScratchAvailable + ScratchTakeBasic + ScratchTakeCore, { fn gglwe_key_compressed_encrypt_sk( &self, @@ -100,7 +103,7 @@ where let n: usize = sk_in.n().max(sk_out.n()).into(); - let (mut sk_in_tmp, scratch_1) = scratch.take_scalar_znx(n, sk_in.rank().into()); + let (mut sk_in_tmp, scratch_1) = scratch.take_scalar_znx(self, sk_in.rank().into()); (0..sk_in.rank().into()).for_each(|i| { self.vec_znx_switch_ring( &mut sk_in_tmp.as_vec_znx_mut(), @@ -110,9 +113,9 @@ where ); }); - let (mut sk_out_tmp, scratch_2) = scratch_1.take_glwe_secret_prepared(RingDegree(n as u32), sk_out.rank()); + let (mut sk_out_tmp, scratch_2) = scratch_1.take_glwe_secret_prepared(self, sk_out.rank()); { - let (mut tmp, _) = scratch_2.take_scalar_znx(n, 1); + let (mut tmp, _) = scratch_2.take_scalar_znx(self, 1); (0..sk_out.rank().into()).for_each(|i| { self.vec_znx_switch_ring(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i); self.svp_prepare(&mut sk_out_tmp.data, i, &tmp, 0); diff --git a/poulpy-core/src/encryption/compressed/gglwe_tsk.rs b/poulpy-core/src/encryption/compressed/gglwe_tsk.rs index 2beaa4b..f3086df 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_tsk.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_tsk.rs @@ -1,16 +1,18 @@ use poulpy_hal::{ api::{ - SvpApplyDftToDft, SvpPPolBytesOf, SvpPrepare, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, + ModuleN, ScratchTakeBasic, SvpApplyDftToDft, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyTmpA, VecZnxNormalizeTmpBytes, }, + oep::{SvpPPolAllocBytesImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl}, layouts::{Backend, DataMut, DataRef, Module, Scratch}, source::Source, }; use crate::{ + ScratchTakeCore, encryption::compressed::gglwe_ksk::GGLWEKeyCompressedEncryptSk, layouts::{ - GGLWEInfos, GLWEInfos, GLWESecret, GLWESecretToRef, LWEInfos, Rank, TensorKey, + GetDist, GGLWEInfos, GLWEInfos, GLWESecret, GLWESecretToRef, LWEInfos, Rank, TensorKey, compressed::{TensorKeyCompressed, TensorKeyCompressedToMut}, }, }; @@ -19,7 +21,7 @@ impl TensorKeyCompressed> { pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGLWEInfos, - Module: SvpPPolBytesOf + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes + VecZnxBigBytesOf, + Module: ModuleN + SvpPPolBytesOf + SvpPPolAlloc + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes + VecZnxBigBytesOf, { TensorKey::encrypt_sk_tmp_bytes(module, infos) } @@ -35,18 +37,25 @@ pub trait GGLWETensorKeyCompressedEncryptSk { scratch: &mut Scratch, ) where R: TensorKeyCompressedToMut, - S: GLWESecretToRef; + S: GLWESecretToRef + GetDist; } impl GGLWETensorKeyCompressedEncryptSk for Module where - Module: GGLWEKeyCompressedEncryptSk + Module: ModuleN + + GGLWEKeyCompressedEncryptSk + VecZnxDftApply + SvpApplyDftToDft + VecZnxIdftApplyTmpA + VecZnxBigNormalize - + SvpPrepare, - Scratch:, + + SvpPrepare + + SvpPPolAllocBytesImpl + + SvpPPolBytesOf + + VecZnxDftAllocBytesImpl + + VecZnxBigAllocBytesImpl + + VecZnxDftBytesOf + + VecZnxBigBytesOf, + Scratch: ScratchTakeBasic + ScratchTakeCore, { fn gglwe_tensor_key_encrypt_sk( &self, @@ -57,9 +66,13 @@ where scratch: &mut Scratch, ) where R: TensorKeyCompressedToMut, - S: GLWESecretToRef, + S: GLWESecretToRef + GetDist, { let res: &mut TensorKeyCompressed<&mut [u8]> = &mut res.to_mut(); + + let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(self, res.rank_out()); + sk_dft_prep.prepare(self, sk); + let sk: &GLWESecret<&[u8]> = &sk.to_ref(); #[cfg(debug_assertions)] @@ -68,21 +81,18 @@ where assert_eq!(res.n(), sk.n()); } - let n: usize = sk.n().into(); + // let n: usize = sk.n().into(); let rank: usize = res.rank_out().into(); - let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(sk.n(), res.rank_out()); - sk_dft_prep.prepare(self, sk, scratch_1); - - let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(n, rank, 1); + let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, rank, 1); for i in 0..rank { self.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); } - let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(n, 1, 1); - let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(sk.n(), Rank(1)); - let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(n, 1, 1); + let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(self, 1, 1); + let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(self, Rank(1)); + let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(self, 1, 1); let mut source_xa: Source = Source::new(seed_xa); @@ -125,6 +135,7 @@ impl TensorKeyCompressed { source_xe: &mut Source, scratch: &mut Scratch, ) where + GLWESecret: GetDist, Module: GGLWETensorKeyCompressedEncryptSk, { module.gglwe_tensor_key_encrypt_sk(self, sk, seed_xa, source_xe, scratch); diff --git a/poulpy-core/src/encryption/compressed/ggsw_ct.rs b/poulpy-core/src/encryption/compressed/ggsw_ct.rs index 567f04f..6595e15 100644 --- a/poulpy-core/src/encryption/compressed/ggsw_ct.rs +++ b/poulpy-core/src/encryption/compressed/ggsw_ct.rs @@ -1,10 +1,11 @@ use poulpy_hal::{ - api::{VecZnxAddScalarInplace, VecZnxDftBytesOf, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes}, + api::{ModuleN, VecZnxAddScalarInplace, VecZnxDftBytesOf, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes}, layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxZero}, source::Source, }; use crate::{ + ScratchTakeCore, encryption::{SIGMA, glwe_ct::GLWEEncryptSkInternal}, layouts::{ GGSW, GGSWInfos, GLWEInfos, LWEInfos, @@ -40,8 +41,8 @@ pub trait GGSWCompressedEncryptSk { impl GGSWCompressedEncryptSk for Module where - Module: GLWEEncryptSkInternal + VecZnxAddScalarInplace + VecZnxNormalizeInplace, - Scratch:, + Module: ModuleN + GLWEEncryptSkInternal + VecZnxAddScalarInplace + VecZnxNormalizeInplace, + Scratch: ScratchTakeCore, { fn ggsw_compressed_encrypt_sk( &self, @@ -74,7 +75,7 @@ where let cols: usize = rank + 1; let dsize: usize = res.dsize().into(); - let (mut tmp_pt, scratch_1) = scratch.take_glwe_pt(&res.glwe_layout()); + let (mut tmp_pt, scratch_1) = scratch.take_glwe_pt(self, &res.glwe_layout()); let mut source = Source::new(seed_xa); diff --git a/poulpy-core/src/encryption/gglwe_atk.rs b/poulpy-core/src/encryption/gglwe_atk.rs index 6536c7e..972e06c 100644 --- a/poulpy-core/src/encryption/gglwe_atk.rs +++ b/poulpy-core/src/encryption/gglwe_atk.rs @@ -1,6 +1,6 @@ use poulpy_hal::{ api::{ - ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, + ModuleN, ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, @@ -9,15 +9,18 @@ use poulpy_hal::{ source::Source, }; -use crate::layouts::{ - AutomorphismKey, AutomorphismKeyToMut, GGLWEInfos, GLWEInfos, GLWESecret, GLWESecretToRef, GLWESwitchingKey, LWEInfos, +use crate::{ + ScratchTakeCore, + layouts::{ + AutomorphismKey, AutomorphismKeyToMut, GGLWEInfos, GLWEInfos, GLWESecret, GLWESecretToRef, GLWESwitchingKey, LWEInfos, + }, }; impl AutomorphismKey> { pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGLWEInfos, - Module: SvpPPolBytesOf + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes, + Module: ModuleN + SvpPPolBytesOf + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes + SvpPPolAlloc, { assert_eq!( infos.rank_in(), @@ -76,7 +79,8 @@ where impl GGLWEAutomorphismKeyEncryptSk for Module where - Module: VecZnxAddScalarInplace + Module: ModuleN + + VecZnxAddScalarInplace + VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply @@ -93,8 +97,10 @@ where + SvpPrepare + VecZnxSwitchRing + SvpPPolBytesOf - + VecZnxAutomorphism, - Scratch: ScratchAvailable, + + VecZnxAutomorphism + + SvpPPolAlloc + + SvpPPolBytesOf, + Scratch: ScratchAvailable + ScratchTakeCore, { fn gglwe_automorphism_key_encrypt_sk( &self, @@ -126,7 +132,7 @@ where ) } - let (mut sk_out, scratch_1) = scratch.take_glwe_secret(sk.n(), sk.rank()); + let (mut sk_out, scratch_1) = scratch.take_glwe_secret(self, sk.rank()); { (0..res.rank_out().into()).for_each(|i| { diff --git a/poulpy-core/src/encryption/gglwe_ct.rs b/poulpy-core/src/encryption/gglwe_ct.rs index d333892..d0ab3a0 100644 --- a/poulpy-core/src/encryption/gglwe_ct.rs +++ b/poulpy-core/src/encryption/gglwe_ct.rs @@ -1,11 +1,11 @@ use poulpy_hal::{ - api::{ScratchAvailable, VecZnxAddScalarInplace, VecZnxDftBytesOf, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes}, + api::{ModuleN, ScratchAvailable, VecZnxAddScalarInplace, VecZnxDftBytesOf, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes}, layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxZero}, source::Source, }; use crate::{ - encryption::glwe_ct::GLWEEncryptSk, + encryption::glwe_ct::GLWEEncryptSk, layouts::GLWEInfos, ScratchTakeCore, layouts::{ GGLWE, GGLWEInfos, GGLWEToMut, GLWE, GLWEPlaintext, LWEInfos, prepared::{GLWESecretPrepared, GLWESecretPreparedToRef}, @@ -47,8 +47,8 @@ pub trait GGLWEEncryptSk { impl GGLWEEncryptSk for Module where - Module: GLWEEncryptSk + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxAddScalarInplace + VecZnxNormalizeInplace, - Scratch: ScratchAvailable, + Module: ModuleN + GLWEEncryptSk + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxAddScalarInplace + VecZnxNormalizeInplace, + Scratch: ScratchAvailable + ScratchTakeCore, { fn gglwe_encrypt_sk( &self, @@ -111,7 +111,7 @@ where let base2k: usize = res.base2k().into(); let rank_in: usize = res.rank_in().into(); - let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(res); + let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(self, &res.glwe_layout()); // For each input column (i.e. rank) produces a GGLWE ciphertext of rank_out+1 columns // // Example for ksk rank 2 to rank 3: diff --git a/poulpy-core/src/encryption/gglwe_ksk.rs b/poulpy-core/src/encryption/gglwe_ksk.rs index ef9b5bf..c825c73 100644 --- a/poulpy-core/src/encryption/gglwe_ksk.rs +++ b/poulpy-core/src/encryption/gglwe_ksk.rs @@ -1,22 +1,26 @@ use poulpy_hal::{ api::{ - ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, + ModuleN, ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, + ScratchTakeBasic, }, layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch}, source::Source, }; -use crate::layouts::{ +use crate::{ + ScratchTakeCore, + layouts::{ GGLWE, GGLWEInfos, GLWEInfos, GLWESecret, GLWESwitchingKey, LWEInfos, RingDegree, prepared::GLWESecretPrepared, + }, }; impl GLWESwitchingKey> { pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGLWEInfos, - Module: SvpPPolBytesOf + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes, + Module: ModuleN + SvpPPolBytesOf + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes + SvpPPolAlloc, { (GGLWE::encrypt_sk_tmp_bytes(module, infos) | ScalarZnx::bytes_of(module.n(), 1)) + ScalarZnx::bytes_of(module.n(), infos.rank_in().into()) @@ -42,7 +46,8 @@ impl GLWESwitchingKey { source_xe: &mut Source, scratch: &mut Scratch, ) where - Module: VecZnxAddScalarInplace + Module: ModuleN + + VecZnxAddScalarInplace + VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxDftApply @@ -58,8 +63,9 @@ impl GLWESwitchingKey { + VecZnxSub + SvpPrepare + VecZnxSwitchRing - + SvpPPolBytesOf, - Scratch: ScratchAvailable, + + SvpPPolBytesOf + + SvpPPolAlloc, + Scratch: ScratchAvailable + ScratchTakeBasic + ScratchTakeCore, { #[cfg(debug_assertions)] { @@ -75,7 +81,7 @@ impl GLWESwitchingKey { let n: usize = sk_in.n().max(sk_out.n()).into(); - let (mut sk_in_tmp, scratch_1) = scratch.take_scalar_znx(n, sk_in.rank().into()); + let (mut sk_in_tmp, scratch_1) = scratch.take_scalar_znx(module, sk_in.rank().into()); (0..sk_in.rank().into()).for_each(|i| { module.vec_znx_switch_ring( &mut sk_in_tmp.as_vec_znx_mut(), @@ -85,9 +91,9 @@ impl GLWESwitchingKey { ); }); - let (mut sk_out_tmp, scratch_2) = scratch_1.take_glwe_secret_prepared(RingDegree(n as u32), sk_out.rank()); + let (mut sk_out_tmp, scratch_2) = scratch_1.take_glwe_secret_prepared(module, sk_out.rank()); { - let (mut tmp, _) = scratch_2.take_scalar_znx(n, 1); + let (mut tmp, _) = scratch_2.take_scalar_znx(module, 1); (0..sk_out.rank().into()).for_each(|i| { module.vec_znx_switch_ring(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i); module.svp_prepare(&mut sk_out_tmp.data, i, &tmp, 0); diff --git a/poulpy-core/src/encryption/gglwe_tsk.rs b/poulpy-core/src/encryption/gglwe_tsk.rs index 62d1a15..918eb8a 100644 --- a/poulpy-core/src/encryption/gglwe_tsk.rs +++ b/poulpy-core/src/encryption/gglwe_tsk.rs @@ -1,29 +1,33 @@ use poulpy_hal::{ api::{ - SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, + ModuleN, ScratchTakeBasic, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, }, layouts::{Backend, DataMut, DataRef, Module, Scratch}, + oep::VecZnxBigAllocBytesImpl, source::Source, }; -use crate::layouts::{ - GGLWEInfos, GLWEInfos, GLWESecret, GLWESwitchingKey, LWEInfos, Rank, RingDegree, TensorKey, prepared::GLWESecretPrepared, +use crate::{ + ScratchTakeCore, + layouts::{ + GetDist, GGLWEInfos, GLWEInfos, GLWESecret, GLWESwitchingKey, LWEInfos, Rank, TensorKey, prepared::GLWESecretPrepared, + }, }; impl TensorKey> { pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGLWEInfos, - Module: SvpPPolBytesOf + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes + VecZnxBigBytesOf, + Module: ModuleN + SvpPPolBytesOf + SvpPPolAlloc + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes + VecZnxBigBytesOf, { GLWESecretPrepared::bytes_of(module, infos.rank_out()) + module.bytes_of_vec_znx_dft(infos.rank_out().into(), 1) + module.bytes_of_vec_znx_big(1, 1) + module.bytes_of_vec_znx_dft(1, 1) - + GLWESecret::bytes_of(RingDegree(module.n() as u32), Rank(1)) + + GLWESecret::bytes_of(module, Rank(1)) + GLWESwitchingKey::encrypt_sk_tmp_bytes(module, infos) } } @@ -37,7 +41,9 @@ impl TensorKey { source_xe: &mut Source, scratch: &mut Scratch, ) where - Module: SvpApplyDftToDft + GLWESecret: GetDist, + Module: ModuleN + + SvpApplyDftToDft + VecZnxIdftApplyTmpA + VecZnxAddScalarInplace + VecZnxDftBytesOf @@ -55,8 +61,11 @@ impl TensorKey { + VecZnxSub + SvpPrepare + VecZnxSwitchRing - + SvpPPolBytesOf, - Scratch:, + + SvpPPolBytesOf + + VecZnxBigAllocBytesImpl + + VecZnxBigBytesOf + + SvpPPolAlloc, + Scratch: ScratchTakeBasic + ScratchTakeCore, { #[cfg(debug_assertions)] { @@ -64,21 +73,21 @@ impl TensorKey { assert_eq!(self.n(), sk.n()); } - let n: RingDegree = sk.n(); + // let n: RingDegree = sk.n(); let rank: Rank = self.rank_out(); - let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(n, rank); - sk_dft_prep.prepare(module, sk, scratch_1); + let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(module, rank); + sk_dft_prep.prepare(module, sk); - let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(n.into(), rank.into(), 1); + let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(module, rank.into(), 1); (0..rank.into()).for_each(|i| { module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); }); - let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(n.into(), 1, 1); - let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(n, Rank(1)); - let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(n.into(), 1, 1); + let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(module, 1, 1); + let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(module, Rank(1)); + let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(module, 1, 1); (0..rank.into()).for_each(|i| { (i..rank.into()).for_each(|j| { diff --git a/poulpy-core/src/encryption/ggsw_ct.rs b/poulpy-core/src/encryption/ggsw_ct.rs index b044ae3..3357380 100644 --- a/poulpy-core/src/encryption/ggsw_ct.rs +++ b/poulpy-core/src/encryption/ggsw_ct.rs @@ -1,11 +1,11 @@ use poulpy_hal::{ - api::{VecZnxAddScalarInplace, VecZnxDftBytesOf, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes}, + api::{ModuleN, VecZnxAddScalarInplace, VecZnxDftBytesOf, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes}, layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, ZnxZero}, source::Source, }; use crate::{ - SIGMA, + SIGMA, ScratchTakeCore, encryption::glwe_ct::GLWEEncryptSkInternal, layouts::{ GGSW, GGSWInfos, GGSWToMut, GLWE, GLWEInfos, LWEInfos, @@ -44,8 +44,8 @@ pub trait GGSWEncryptSk { impl GGSWEncryptSk for Module where - Module: GLWEEncryptSkInternal + VecZnxAddScalarInplace + VecZnxNormalizeInplace, - Scratch:, + Module: ModuleN + GLWEEncryptSkInternal + VecZnxAddScalarInplace + VecZnxNormalizeInplace, + Scratch: ScratchTakeCore, { fn ggsw_encrypt_sk( &self, @@ -80,7 +80,7 @@ where let dsize: usize = res.dsize().into(); let cols: usize = (rank + 1).into(); - let (mut tmp_pt, scratch_1) = scratch.take_glwe_pt(&res.glwe_layout()); + let (mut tmp_pt, scratch_1) = scratch.take_glwe_pt(self, &res.glwe_layout()); for row_i in 0..res.dnum().into() { tmp_pt.data.zero(); diff --git a/poulpy-core/src/encryption/glwe_ct.rs b/poulpy-core/src/encryption/glwe_ct.rs index 16bbadf..fa314d5 100644 --- a/poulpy-core/src/encryption/glwe_ct.rs +++ b/poulpy-core/src/encryption/glwe_ct.rs @@ -1,9 +1,9 @@ use poulpy_hal::{ api::{ - ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, + ModuleN, ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxBigAddNormal, VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, + VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, ScratchTakeBasic, }, layouts::{Backend, DataMut, Module, ScalarZnx, Scratch, VecZnx, VecZnxBig, VecZnxToMut, ZnxInfos, ZnxZero}, source::Source, @@ -19,19 +19,19 @@ use crate::{ }; impl GLWE> { - pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize + pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GLWEInfos, - Module: VecZnxNormalizeTmpBytes + VecZnxDftBytesOf, + Module: VecZnxNormalizeTmpBytes + VecZnxDftBytesOf, { let size: usize = infos.size(); assert_eq!(module.n() as u32, infos.n()); module.vec_znx_normalize_tmp_bytes() + 2 * VecZnx::bytes_of(module.n(), 1, size) + module.bytes_of_vec_znx_dft(1, size) } - pub fn encrypt_pk_tmp_bytes(module: &Module, infos: &A) -> usize + pub fn encrypt_pk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GLWEInfos, - Module: VecZnxDftBytesOf + SvpPPolBytesOf + VecZnxBigBytesOf + VecZnxNormalizeTmpBytes, + Module: VecZnxDftBytesOf + SvpPPolBytesOf + VecZnxBigBytesOf + VecZnxNormalizeTmpBytes, { let size: usize = infos.size(); assert_eq!(module.n() as u32, infos.n()); @@ -42,68 +42,68 @@ impl GLWE> { } impl GLWE { - pub fn encrypt_sk( + pub fn encrypt_sk( &mut self, - module: &Module, + module: &Module, pt: &P, sk: &S, source_xa: &mut Source, source_xe: &mut Source, - scratch: &mut Scratch, + scratch: &mut Scratch, ) where P: GLWEPlaintextToRef, - S: GLWESecretPreparedToRef, - Module: GLWEEncryptSk, + S: GLWESecretPreparedToRef, + Module: GLWEEncryptSk, { module.glwe_encrypt_sk(self, pt, sk, source_xa, source_xe, scratch); } - pub fn encrypt_zero_sk( + pub fn encrypt_zero_sk( &mut self, - module: &Module, + module: &Module, sk: &S, source_xa: &mut Source, source_xe: &mut Source, - scratch: &mut Scratch, + scratch: &mut Scratch, ) where - S: GLWESecretPreparedToRef, - Module: GLWEEncryptZeroSk, + S: GLWESecretPreparedToRef, + Module: GLWEEncryptZeroSk, { module.glwe_encrypt_zero_sk(self, sk, source_xa, source_xe, scratch); } - pub fn encrypt_pk( + pub fn encrypt_pk( &mut self, - module: &Module, + module: &Module, pt: &P, pk: &K, source_xu: &mut Source, source_xe: &mut Source, - scratch: &mut Scratch, + scratch: &mut Scratch, ) where P: GLWEPlaintextToRef, - K: GLWEPublicKeyPreparedToRef, - Module: GLWEEncryptPk, + K: GLWEPublicKeyPreparedToRef, + Module: GLWEEncryptPk, { module.glwe_encrypt_pk(self, pt, pk, source_xu, source_xe, scratch); } - pub fn encrypt_zero_pk( + pub fn encrypt_zero_pk( &mut self, - module: &Module, + module: &Module, pk: &K, source_xu: &mut Source, source_xe: &mut Source, - scratch: &mut Scratch, + scratch: &mut Scratch, ) where - K: GLWEPublicKeyPreparedToRef, - Module: GLWEEncryptZeroPk, + K: GLWEPublicKeyPreparedToRef, + Module: GLWEEncryptZeroPk, { module.glwe_encrypt_zero_pk(self, pk, source_xu, source_xe, scratch); } } -pub trait GLWEEncryptSk { +pub trait GLWEEncryptSk { fn glwe_encrypt_sk( &self, res: &mut R, @@ -111,17 +111,17 @@ pub trait GLWEEncryptSk { sk: &S, source_xa: &mut Source, source_xe: &mut Source, - scratch: &mut Scratch, + scratch: &mut Scratch, ) where R: GLWEToMut, P: GLWEPlaintextToRef, - S: GLWESecretPreparedToRef; + S: GLWESecretPreparedToRef; } -impl GLWEEncryptSk for Module +impl GLWEEncryptSk for Module where - Module: GLWEEncryptSkInternal + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf, - Scratch: ScratchAvailable, + Module: GLWEEncryptSkInternal + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf, + Scratch: ScratchAvailable, { fn glwe_encrypt_sk( &self, @@ -130,18 +130,18 @@ where sk: &S, source_xa: &mut Source, source_xe: &mut Source, - scratch: &mut Scratch, + scratch: &mut Scratch, ) where R: GLWEToMut, P: GLWEPlaintextToRef, - S: GLWESecretPreparedToRef, + S: GLWESecretPreparedToRef, { let mut res: GLWE<&mut [u8]> = res.to_mut(); let pt: GLWEPlaintext<&[u8]> = pt.to_ref(); #[cfg(debug_assertions)] { - let sk: GLWESecretPrepared<&[u8], B> = sk.to_ref(); + let sk: GLWESecretPrepared<&[u8], BE> = sk.to_ref(); assert_eq!(res.rank(), sk.rank()); assert_eq!(res.n(), self.n() as u32); assert_eq!(sk.n(), self.n() as u32); @@ -171,23 +171,23 @@ where } } -pub trait GLWEEncryptZeroSk { +pub trait GLWEEncryptZeroSk { fn glwe_encrypt_zero_sk( &self, res: &mut R, sk: &S, source_xa: &mut Source, source_xe: &mut Source, - scratch: &mut Scratch, + scratch: &mut Scratch, ) where R: GLWEToMut, - S: GLWESecretPreparedToRef; + S: GLWESecretPreparedToRef; } -impl GLWEEncryptZeroSk for Module +impl GLWEEncryptZeroSk for Module where - Module: GLWEEncryptSkInternal + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf, - Scratch: ScratchAvailable, + Module: GLWEEncryptSkInternal + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf, + Scratch: ScratchAvailable, { fn glwe_encrypt_zero_sk( &self, @@ -195,16 +195,16 @@ where sk: &S, source_xa: &mut Source, source_xe: &mut Source, - scratch: &mut Scratch, + scratch: &mut Scratch, ) where R: GLWEToMut, - S: GLWESecretPreparedToRef, + S: GLWESecretPreparedToRef, { let mut res: GLWE<&mut [u8]> = res.to_mut(); #[cfg(debug_assertions)] { - let sk: GLWESecretPrepared<&[u8], B> = sk.to_ref(); + let sk: GLWESecretPrepared<&[u8], BE> = sk.to_ref(); assert_eq!(res.rank(), sk.rank()); assert_eq!(res.n(), self.n() as u32); assert_eq!(sk.n(), self.n() as u32); @@ -233,7 +233,7 @@ where } } -pub trait GLWEEncryptPk { +pub trait GLWEEncryptPk { fn glwe_encrypt_pk( &self, res: &mut R, @@ -241,16 +241,16 @@ pub trait GLWEEncryptPk { pk: &K, source_xu: &mut Source, source_xe: &mut Source, - scratch: &mut Scratch, + scratch: &mut Scratch, ) where R: GLWEToMut, P: GLWEPlaintextToRef, - K: GLWEPublicKeyPreparedToRef; + K: GLWEPublicKeyPreparedToRef; } -impl GLWEEncryptPk for Module +impl GLWEEncryptPk for Module where - Module: GLWEEncryptPkInternal, + Module: GLWEEncryptPkInternal, { fn glwe_encrypt_pk( &self, @@ -259,32 +259,32 @@ where pk: &K, source_xu: &mut Source, source_xe: &mut Source, - scratch: &mut Scratch, + scratch: &mut Scratch, ) where R: GLWEToMut, P: GLWEPlaintextToRef, - K: GLWEPublicKeyPreparedToRef, + K: GLWEPublicKeyPreparedToRef, { self.glwe_encrypt_pk_internal(res, Some((pt, 0)), pk, source_xu, source_xe, scratch); } } -pub trait GLWEEncryptZeroPk { +pub trait GLWEEncryptZeroPk { fn glwe_encrypt_zero_pk( &self, res: &mut R, pk: &K, source_xu: &mut Source, source_xe: &mut Source, - scratch: &mut Scratch, + scratch: &mut Scratch, ) where R: GLWEToMut, - K: GLWEPublicKeyPreparedToRef; + K: GLWEPublicKeyPreparedToRef; } -impl GLWEEncryptZeroPk for Module +impl GLWEEncryptZeroPk for Module where - Module: GLWEEncryptPkInternal, + Module: GLWEEncryptPkInternal, { fn glwe_encrypt_zero_pk( &self, @@ -292,10 +292,10 @@ where pk: &K, source_xu: &mut Source, source_xe: &mut Source, - scratch: &mut Scratch, + scratch: &mut Scratch, ) where R: GLWEToMut, - K: GLWEPublicKeyPreparedToRef, + K: GLWEPublicKeyPreparedToRef, { self.glwe_encrypt_pk_internal( res, @@ -308,7 +308,7 @@ where } } -pub(crate) trait GLWEEncryptPkInternal { +pub(crate) trait GLWEEncryptPkInternal { fn glwe_encrypt_pk_internal( &self, res: &mut R, @@ -316,22 +316,25 @@ pub(crate) trait GLWEEncryptPkInternal { pk: &K, source_xu: &mut Source, source_xe: &mut Source, - scratch: &mut Scratch, + scratch: &mut Scratch, ) where R: GLWEToMut, P: GLWEPlaintextToRef, - K: GLWEPublicKeyPreparedToRef; + K: GLWEPublicKeyPreparedToRef; } -impl GLWEEncryptPkInternal for Module +impl GLWEEncryptPkInternal for Module where - Module: SvpPrepare - + SvpApplyDftToDft - + VecZnxIdftApplyConsume - + VecZnxBigAddNormal - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize, - Scratch:, + Module: SvpPrepare + + SvpApplyDftToDft + + VecZnxIdftApplyConsume + + VecZnxBigAddNormal + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize + + SvpPPolBytesOf + + ModuleN + + VecZnxDftBytesOf, + Scratch: ScratchTakeBasic, { fn glwe_encrypt_pk_internal( &self, @@ -340,14 +343,14 @@ where pk: &K, source_xu: &mut Source, source_xe: &mut Source, - scratch: &mut Scratch, + scratch: &mut Scratch, ) where R: GLWEToMut, P: GLWEPlaintextToRef, - K: GLWEPublicKeyPreparedToRef, + K: GLWEPublicKeyPreparedToRef, { let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); - let pk: &GLWEPublicKeyPrepared<&[u8], B> = &pk.to_ref(); + let pk: &GLWEPublicKeyPrepared<&[u8], BE> = &pk.to_ref(); #[cfg(debug_assertions)] { @@ -365,10 +368,10 @@ where let cols: usize = (res.rank() + 1).into(); // Generates u according to the underlying secret distribution. - let (mut u_dft, scratch_1) = scratch.take_svp_ppol(res.n().into(), 1); + let (mut u_dft, scratch_1) = scratch.take_svp_ppol(self, 1); { - let (mut u, _) = scratch_1.take_scalar_znx(res.n().into(), 1); + let (mut u, _) = scratch_1.take_scalar_znx(self, 1); match pk.dist { Distribution::NONE => panic!( "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through \ @@ -387,7 +390,7 @@ where // ct[i] = pk[i] * u + ei (+ m if col = i) (0..cols).for_each(|i| { - let (mut ci_dft, scratch_2) = scratch_1.take_vec_znx_dft(res.n().into(), 1, size_pk); + let (mut ci_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, 1, size_pk); // ci_dft = DFT(u) * DFT(pk[i]) self.svp_apply_dft_to_dft(&mut ci_dft, 0, &u_dft, 0, &pk.data, i); @@ -418,7 +421,7 @@ where } } -pub(crate) trait GLWEEncryptSkInternal { +pub(crate) trait GLWEEncryptSkInternal { fn glwe_encrypt_sk_internal( &self, base2k: usize, @@ -431,29 +434,30 @@ pub(crate) trait GLWEEncryptSkInternal { source_xa: &mut Source, source_xe: &mut Source, sigma: f64, - scratch: &mut Scratch, + scratch: &mut Scratch, ) where R: VecZnxToMut, P: GLWEPlaintextToRef, - S: GLWESecretPreparedToRef; + S: GLWESecretPreparedToRef; } -impl GLWEEncryptSkInternal for Module +impl GLWEEncryptSkInternal for Module where - Module: VecZnxDftBytesOf - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume + Module: ModuleN + + VecZnxDftBytesOf + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform + VecZnxSubInplace + VecZnxAddInplace - + VecZnxNormalizeInplace + + VecZnxNormalizeInplace + VecZnxAddNormal - + VecZnxNormalize + + VecZnxNormalize + VecZnxSub, - Scratch: ScratchAvailable, + Scratch: ScratchAvailable + ScratchTakeBasic, { fn glwe_encrypt_sk_internal( &self, @@ -467,14 +471,14 @@ where source_xa: &mut Source, source_xe: &mut Source, sigma: f64, - scratch: &mut Scratch, + scratch: &mut Scratch, ) where R: VecZnxToMut, P: GLWEPlaintextToRef, - S: GLWESecretPreparedToRef, + S: GLWESecretPreparedToRef, { let ct: &mut VecZnx<&mut [u8]> = &mut res.to_mut(); - let sk: GLWESecretPrepared<&[u8], B> = sk.to_ref(); + let sk: GLWESecretPrepared<&[u8], BE> = sk.to_ref(); #[cfg(debug_assertions)] { @@ -490,11 +494,11 @@ where let size: usize = ct.size(); - let (mut c0, scratch_1) = scratch.take_vec_znx(ct.n(), 1, size); + let (mut c0, scratch_1) = scratch.take_vec_znx(self, 1, size); c0.zero(); { - let (mut ci, scratch_2) = scratch_1.take_vec_znx(ct.n(), 1, size); + let (mut ci, scratch_2) = scratch_1.take_vec_znx(self, 1, size); // ct[i] = uniform // ct[0] -= c[i] * s[i], @@ -504,7 +508,7 @@ where // ct[i] = uniform (+ pt) self.vec_znx_fill_uniform(base2k, ct, col_ct, source_xa); - let (mut ci_dft, scratch_3) = scratch_2.take_vec_znx_dft(ct.n(), 1, size); + let (mut ci_dft, scratch_3) = scratch_2.take_vec_znx_dft(self, 1, size); // ci = ct[i] - pt // i.e. we act as we sample ct[i] already as uniform + pt @@ -522,7 +526,7 @@ where } self.svp_apply_dft_to_dft_inplace(&mut ci_dft, 0, &sk.data, i - 1); - let ci_big: VecZnxBig<&mut [u8], B> = self.vec_znx_idft_apply_consume(ci_dft); + let ci_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(ci_dft); // use c[0] as buffer, which is overwritten later by the normalization step self.vec_znx_big_normalize(base2k, &mut ci, 0, base2k, &ci_big, 0, scratch_3); diff --git a/poulpy-core/src/encryption/glwe_pk.rs b/poulpy-core/src/encryption/glwe_pk.rs index d89f515..9073cfe 100644 --- a/poulpy-core/src/encryption/glwe_pk.rs +++ b/poulpy-core/src/encryption/glwe_pk.rs @@ -47,7 +47,7 @@ where // Its ok to allocate scratch space here since pk is usually generated only once. let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWE::encrypt_sk_tmp_bytes(self, res)); - let mut tmp: GLWE> = GLWE::alloc_from_infos(res); + let mut tmp: GLWE> = GLWE::alloc_from_infos(self, res); tmp.encrypt_zero_sk(self, sk, source_xa, source_xe, scratch.borrow()); res.dist = sk.dist; diff --git a/poulpy-core/src/encryption/glwe_to_lwe_ksk.rs b/poulpy-core/src/encryption/glwe_to_lwe_ksk.rs index 971766e..cd3ffec 100644 --- a/poulpy-core/src/encryption/glwe_to_lwe_ksk.rs +++ b/poulpy-core/src/encryption/glwe_to_lwe_ksk.rs @@ -1,6 +1,6 @@ use poulpy_hal::{ api::{ - ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, + ModuleN, ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, @@ -9,18 +9,22 @@ use poulpy_hal::{ source::Source, }; -use crate::layouts::{ - GGLWEInfos, GLWESecret, GLWESwitchingKey, GLWEToLWESwitchingKey, LWEInfos, LWESecret, Rank, prepared::GLWESecretPrepared, +use crate::{ + ScratchTakeCore, + layouts::{ + GGLWEInfos, GLWESecret, GLWESwitchingKey, GLWEToLWESwitchingKey, LWEInfos, LWESecret, Rank, + prepared::GLWESecretPrepared, + }, }; impl GLWEToLWESwitchingKey> { pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGLWEInfos, - Module: SvpPPolBytesOf + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes, + Module: ModuleN + SvpPPolBytesOf + SvpPPolAlloc + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes, { GLWESecretPrepared::bytes_of(module, infos.rank_in()) - + (GLWESwitchingKey::encrypt_sk_tmp_bytes(module, infos) | GLWESecret::bytes_of(infos.n(), infos.rank_in())) + + (GLWESwitchingKey::encrypt_sk_tmp_bytes(module, infos) | GLWESecret::bytes_of(module, infos.rank_in())) } } @@ -37,7 +41,8 @@ impl GLWEToLWESwitchingKey { ) where DLwe: DataRef, DGlwe: DataRef, - Module: VecZnxAutomorphismInplace + Module: ModuleN + + VecZnxAutomorphismInplace + VecZnxAddScalarInplace + VecZnxDftBytesOf + VecZnxBigNormalize @@ -54,15 +59,16 @@ impl GLWEToLWESwitchingKey { + VecZnxSub + SvpPrepare + VecZnxSwitchRing - + SvpPPolBytesOf, - Scratch: ScratchAvailable, + + SvpPPolBytesOf + + SvpPPolAlloc, + Scratch: ScratchAvailable + ScratchTakeCore, { #[cfg(debug_assertions)] { assert!(sk_lwe.n().0 <= module.n() as u32); } - let (mut sk_lwe_as_glwe, scratch_1) = scratch.take_glwe_secret(sk_glwe.n(), Rank(1)); + let (mut sk_lwe_as_glwe, scratch_1) = scratch.take_glwe_secret(module, Rank(1)); sk_lwe_as_glwe.data.zero(); sk_lwe_as_glwe.data.at_mut(0, 0)[..sk_lwe.n().into()].copy_from_slice(sk_lwe.data.at(0, 0)); module.vec_znx_automorphism_inplace(-1, &mut sk_lwe_as_glwe.data.as_vec_znx_mut(), 0, scratch_1); diff --git a/poulpy-core/src/encryption/lwe_ct.rs b/poulpy-core/src/encryption/lwe_ct.rs index a01d95f..5a743eb 100644 --- a/poulpy-core/src/encryption/lwe_ct.rs +++ b/poulpy-core/src/encryption/lwe_ct.rs @@ -1,64 +1,79 @@ use poulpy_hal::{ api::{ScratchOwnedAlloc, ScratchOwnedBorrow, ZnAddNormal, ZnFillUniform, ZnNormalizeInplace}, - layouts::{Backend, DataMut, DataRef, Module, ScratchOwned, Zn, ZnxView, ZnxViewMut}, + layouts::{Backend, DataMut, Module, ScratchOwned, Zn, ZnxView, ZnxViewMut}, oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl}, source::Source, }; use crate::{ encryption::{SIGMA, SIGMA_BOUND}, - layouts::{LWE, LWEInfos, LWEPlaintext, LWESecret}, + layouts::{LWE, LWEInfos, LWEPlaintext, LWESecret, LWEToMut, LWEPlaintextToRef, LWESecretToRef}, }; impl LWE { - pub fn encrypt_sk( - &mut self, - module: &Module, - pt: &LWEPlaintext, - sk: &LWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - ) where - DataPt: DataRef, - DataSk: DataRef, - Module: ZnFillUniform + ZnAddNormal + ZnNormalizeInplace, - B: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, + + pub fn encrypt_sk(&mut self, module: &M, pt: &P, sk: &S, source_xa: &mut Source, source_xe: &mut Source) + where + P: LWEPlaintextToRef, + S: LWESecretToRef, + M: LWEEncryptSk, + BE: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, { + module.lwe_encrypt_sk(self, pt, sk, source_xa, source_xe); + } +} + + +pub trait LWEEncryptSk +where + Self: Sized + ZnFillUniform + ZnAddNormal + ZnNormalizeInplace, +{ + fn lwe_encrypt_sk(&self, res: &mut R, pt: &P, sk: &S, source_xa: &mut Source, source_xe: &mut Source) + where + R: LWEToMut, + P: LWEPlaintextToRef, + S: LWESecretToRef, + BE: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, + { + let res: &mut LWE<&mut [u8]> = &mut res.to_mut(); + let pt: &LWEPlaintext<&[u8]> = &pt.to_ref(); + let sk: &LWESecret<&[u8]> = &sk.to_ref(); + #[cfg(debug_assertions)] { - assert_eq!(self.n(), sk.n()) + assert_eq!(res.n(), sk.n()) } - let base2k: usize = self.base2k().into(); - let k: usize = self.k().into(); + let base2k: usize = res.base2k().into(); + let k: usize = res.k().into(); - module.zn_fill_uniform((self.n() + 1).into(), base2k, &mut self.data, 0, source_xa); + self.zn_fill_uniform((res.n() + 1).into(), base2k, &mut res.data, 0, source_xa); - let mut tmp_znx: Zn> = Zn::alloc(1, 1, self.size()); + let mut tmp_znx: Zn> = Zn::alloc(1, 1, res.size()); - let min_size = self.size().min(pt.size()); + let min_size = res.size().min(pt.size()); (0..min_size).for_each(|i| { tmp_znx.at_mut(0, i)[0] = pt.data.at(0, i)[0] - - self.data.at(0, i)[1..] + - res.data.at(0, i)[1..] .iter() .zip(sk.data.at(0, 0)) .map(|(x, y)| x * y) .sum::(); }); - (min_size..self.size()).for_each(|i| { - tmp_znx.at_mut(0, i)[0] -= self.data.at(0, i)[1..] + (min_size..res.size()).for_each(|i| { + tmp_znx.at_mut(0, i)[0] -= res.data.at(0, i)[1..] .iter() .zip(sk.data.at(0, 0)) .map(|(x, y)| x * y) .sum::(); }); - module.zn_add_normal( + self.zn_add_normal( 1, base2k, - &mut self.data, + &mut res.data, 0, k, source_xe, @@ -66,7 +81,7 @@ impl LWE { SIGMA_BOUND, ); - module.zn_normalize_inplace( + self.zn_normalize_inplace( 1, base2k, &mut tmp_znx, @@ -74,8 +89,13 @@ impl LWE { ScratchOwned::alloc(size_of::()).borrow(), ); - (0..self.size()).for_each(|i| { - self.data.at_mut(0, i)[0] = tmp_znx.at(0, i)[0]; - }); + (0..res.size()).for_each(|i| { + res.data.at_mut(0, i)[0] = tmp_znx.at(0, i)[0]; + }); } } + +impl LWEEncryptSk for Module where + Self: Sized + ZnFillUniform + ZnAddNormal + ZnNormalizeInplace, +{ +} \ No newline at end of file diff --git a/poulpy-core/src/encryption/lwe_ksk.rs b/poulpy-core/src/encryption/lwe_ksk.rs index 2fd60ff..ed1c5d6 100644 --- a/poulpy-core/src/encryption/lwe_ksk.rs +++ b/poulpy-core/src/encryption/lwe_ksk.rs @@ -1,6 +1,6 @@ use poulpy_hal::{ api::{ - SvpApplyDftToDftInplace, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, + ModuleN, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, @@ -9,16 +9,19 @@ use poulpy_hal::{ source::Source, }; -use crate::layouts::{ - GGLWEInfos, GLWESecret, GLWESwitchingKey, LWEInfos, LWESecret, LWESwitchingKey, Rank, RingDegree, +use crate::{ + ScratchTakeCore, + layouts::{ + GGLWEInfos, GLWESecret, GLWESwitchingKey, LWEInfos, LWESecret, LWESwitchingKey, Rank, prepared::GLWESecretPrepared, + }, }; impl LWESwitchingKey> { pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGLWEInfos, - Module: SvpPPolBytesOf + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes, + Module: ModuleN + SvpPPolBytesOf + SvpPPolAlloc + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes, { debug_assert_eq!( infos.dsize().0, @@ -35,7 +38,7 @@ impl LWESwitchingKey> { 1, "rank_out > 1 is not supported for LWESwitchingKey" ); - GLWESecret::bytes_of(RingDegree(module.n() as u32), Rank(1)) + GLWESecret::bytes_of(module, Rank(1)) + GLWESecretPrepared::bytes_of(module, Rank(1)) + GLWESwitchingKey::encrypt_sk_tmp_bytes(module, infos) } @@ -54,7 +57,8 @@ impl LWESwitchingKey { ) where DIn: DataRef, DOut: DataRef, - Module: VecZnxAutomorphismInplace + Module: ModuleN + + VecZnxAutomorphismInplace + VecZnxAddScalarInplace + VecZnxDftBytesOf + VecZnxBigNormalize @@ -71,8 +75,9 @@ impl LWESwitchingKey { + VecZnxSub + SvpPrepare + VecZnxSwitchRing - + SvpPPolBytesOf, - Scratch:, + + SvpPPolBytesOf + + SvpPPolAlloc, + Scratch: ScratchTakeCore, { #[cfg(debug_assertions)] { @@ -81,8 +86,8 @@ impl LWESwitchingKey { assert!(self.n().0 <= module.n() as u32); } - let (mut sk_in_glwe, scratch_1) = scratch.take_glwe_secret(self.n(), Rank(1)); - let (mut sk_out_glwe, scratch_2) = scratch_1.take_glwe_secret(self.n(), Rank(1)); + let (mut sk_in_glwe, scratch_1) = scratch.take_glwe_secret(module, Rank(1)); + let (mut sk_out_glwe, scratch_2) = scratch_1.take_glwe_secret(module, Rank(1)); sk_out_glwe.data.at_mut(0, 0)[..sk_lwe_out.n().into()].copy_from_slice(sk_lwe_out.data.at(0, 0)); sk_out_glwe.data.at_mut(0, 0)[sk_lwe_out.n().into()..].fill(0); diff --git a/poulpy-core/src/encryption/lwe_to_glwe_ksk.rs b/poulpy-core/src/encryption/lwe_to_glwe_ksk.rs index 041c7c4..32c75fc 100644 --- a/poulpy-core/src/encryption/lwe_to_glwe_ksk.rs +++ b/poulpy-core/src/encryption/lwe_to_glwe_ksk.rs @@ -1,6 +1,6 @@ use poulpy_hal::{ api::{ - ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, + ModuleN, ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, @@ -9,13 +9,16 @@ use poulpy_hal::{ source::Source, }; -use crate::layouts::{GGLWEInfos, GLWESecret, GLWESwitchingKey, LWEInfos, LWESecret, LWEToGLWESwitchingKey, Rank, RingDegree}; +use crate::{ + ScratchTakeCore, + layouts::{GGLWEInfos, GLWESecret, GLWESwitchingKey, LWEInfos, LWESecret, LWEToGLWESwitchingKey, Rank}, +}; impl LWEToGLWESwitchingKey> { pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where A: GGLWEInfos, - Module: SvpPPolBytesOf + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes, + Module: ModuleN + SvpPPolBytesOf + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes + SvpPPolAlloc, { debug_assert_eq!( infos.rank_in(), @@ -23,7 +26,7 @@ impl LWEToGLWESwitchingKey> { "rank_in != 1 is not supported for LWEToGLWESwitchingKey" ); GLWESwitchingKey::encrypt_sk_tmp_bytes(module, infos) - + GLWESecret::bytes_of(RingDegree(module.n() as u32), infos.rank_in()) + + GLWESecret::bytes_of(module, infos.rank_in()) } } @@ -40,7 +43,8 @@ impl LWEToGLWESwitchingKey { ) where DLwe: DataRef, DGlwe: DataRef, - Module: VecZnxAutomorphismInplace + Module: ModuleN + + VecZnxAutomorphismInplace + VecZnxAddScalarInplace + VecZnxDftBytesOf + VecZnxBigNormalize @@ -57,8 +61,9 @@ impl LWEToGLWESwitchingKey { + VecZnxSub + SvpPrepare + VecZnxSwitchRing - + SvpPPolBytesOf, - Scratch: ScratchAvailable, + + SvpPPolBytesOf + + SvpPPolAlloc, + Scratch: ScratchAvailable + ScratchTakeCore, { #[cfg(debug_assertions)] { @@ -67,7 +72,7 @@ impl LWEToGLWESwitchingKey { assert!(sk_lwe.n().0 <= module.n() as u32); } - let (mut sk_lwe_as_glwe, scratch_1) = scratch.take_glwe_secret(sk_glwe.n(), Rank(1)); + let (mut sk_lwe_as_glwe, scratch_1) = scratch.take_glwe_secret(module, Rank(1)); sk_lwe_as_glwe.data.at_mut(0, 0)[..sk_lwe.n().into()].copy_from_slice(sk_lwe.data.at(0, 0)); sk_lwe_as_glwe.data.at_mut(0, 0)[sk_lwe.n().into()..].fill(0); module.vec_znx_automorphism_inplace(-1, &mut sk_lwe_as_glwe.data.as_vec_znx_mut(), 0, scratch_1);