diff --git a/Cargo.lock b/Cargo.lock index 0037bc7..76ebe9b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -49,9 +49,9 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "bytemuck" -version = "1.23.2" +version = "1.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3995eaeebcdf32f91f980d360f78732ddc061097ab4e39991ae7a6ace9194677" +checksum = "1fbdf580320f38b612e485521afda1ee26d10cc9884efaaa750d383e13e3c5f4" [[package]] name = "byteorder" @@ -372,6 +372,7 @@ dependencies = [ name = "poulpy-core" version = "0.3.1" dependencies = [ + "bytemuck", "byteorder", "criterion", "itertools 0.14.0", @@ -534,9 +535,9 @@ checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "rug" -version = "1.27.0" +version = "1.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4207e8d668e5b8eb574bda8322088ccd0d7782d3d03c7e8d562e82ed82bdcbc3" +checksum = "58ad2e973fe3c3214251a840a621812a4f40468da814b1a3d6947d433c2af11f" dependencies = [ "az", "gmp-mpfr-sys", diff --git a/Cargo.toml b/Cargo.toml index ed44102..964e654 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,8 +7,8 @@ poulpy-hal = {path = "poulpy-hal"} poulpy-core = {path = "poulpy-core"} poulpy-backend = {path = "poulpy-backend"} poulpy-schemes = {path = "poulpy-schemes"} -rug = "1.27" -rand = "0.9.1" +rug = "1.28.0" +rand = "0.9.2" rand_chacha = "0.9.0" rand_core = "0.9.3" rand_distr = "0.5.1" @@ -16,4 +16,5 @@ itertools = "0.14.0" criterion = "0.7.0" byteorder = "1.5.0" zstd = "0.13.3" -once_cell = "1.21.3" \ No newline at end of file +once_cell = "1.21.3" +bytemuck = "1.24.0" \ No newline at end of file diff --git a/poulpy-core/Cargo.toml b/poulpy-core/Cargo.toml index 42eb830..cf15d04 100644 --- a/poulpy-core/Cargo.toml +++ b/poulpy-core/Cargo.toml @@ -15,6 +15,7 @@ poulpy-hal = {workspace = true} poulpy-backend = {workspace = true} itertools = {workspace = true} byteorder = {workspace = true} +bytemuck = {workspace = true} once_cell = {workspace = true} [[bench]] diff --git a/poulpy-core/src/automorphism/gglwe_atk.rs b/poulpy-core/src/automorphism/gglwe_atk.rs index ffc35cd..87e545e 100644 --- a/poulpy-core/src/automorphism/gglwe_atk.rs +++ b/poulpy-core/src/automorphism/gglwe_atk.rs @@ -1,11 +1,10 @@ use poulpy_hal::{ - api::VecZnxAutomorphism, - layouts::{Backend, DataMut, GaloisElement, Module, Scratch}, + api::{VecZnxAutomorphism, VecZnxAutomorphismInplace}, + layouts::{Backend, CyclotomicOrder, DataMut, GaloisElement, Module, Scratch}, }; use crate::{ - ScratchTakeCore, - automorphism::glwe_ct::GLWEAutomorphism, + GLWEKeyswitch, ScratchTakeCore, layouts::{ GGLWE, GGLWEInfos, GGLWEPreparedToRef, GGLWEToMut, GGLWEToRef, GLWE, GLWEAutomorphismKey, GetGaloisElement, SetGaloisElement, @@ -45,14 +44,10 @@ impl GLWEAutomorphismKey { } } -impl GLWEAutomorphismKeyAutomorphism for Module where - Self: GaloisElement + GLWEAutomorphism + VecZnxAutomorphism -{ -} - -pub trait GLWEAutomorphismKeyAutomorphism +impl GLWEAutomorphismKeyAutomorphism for Module where - Self: GaloisElement + GLWEAutomorphism + VecZnxAutomorphism, + Self: GaloisElement + GLWEKeyswitch + VecZnxAutomorphism + VecZnxAutomorphismInplace + CyclotomicOrder, + Scratch: ScratchTakeCore, { fn glwe_automorphism_key_automorphism_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize where @@ -68,7 +63,6 @@ where R: GGLWEToMut + SetGaloisElement + GGLWEInfos, A: GGLWEToRef + GetGaloisElement + GGLWEInfos, K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, - Scratch: ScratchTakeCore, { assert!( res.dnum().as_u32() <= a.dnum().as_u32(), @@ -163,3 +157,22 @@ where res.set_p((res.p() * key.p()) % self.cyclotomic_order()); } } + +pub trait GLWEAutomorphismKeyAutomorphism { + fn glwe_automorphism_key_automorphism_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize + where + R: GGLWEInfos, + A: GGLWEInfos, + K: GGLWEInfos; + + fn glwe_automorphism_key_automorphism(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) + where + R: GGLWEToMut + SetGaloisElement + GGLWEInfos, + A: GGLWEToRef + GetGaloisElement + GGLWEInfos, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos; + + fn glwe_automorphism_key_automorphism_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) + where + R: GGLWEToMut + SetGaloisElement + GetGaloisElement + GGLWEInfos, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos; +} diff --git a/poulpy-core/src/automorphism/ggsw_ct.rs b/poulpy-core/src/automorphism/ggsw_ct.rs index fb54f6d..8644f98 100644 --- a/poulpy-core/src/automorphism/ggsw_ct.rs +++ b/poulpy-core/src/automorphism/ggsw_ct.rs @@ -7,8 +7,8 @@ use crate::{ GGSWExpandRows, ScratchTakeCore, automorphism::glwe_ct::GLWEAutomorphism, layouts::{ - GGLWEInfos, GGLWEPreparedToRef, GGSW, GGSWInfos, GGSWToMut, GGSWToRef, GetGaloisElement, - prepared::{GLWETensorKeyPrepared, GLWETensorKeyPreparedToRef}, + GGLWEInfos, GGLWEPreparedToRef, GGLWEToGGSWKeyPrepared, GGLWEToGGSWKeyPreparedToRef, GGSW, GGSWInfos, GGSWToMut, + GGSWToRef, GetGaloisElement, }, }; @@ -36,7 +36,7 @@ impl GGSW { where A: GGSWToRef, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, - T: GLWETensorKeyPreparedToRef, + T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore, M: GGSWAutomorphism, { @@ -46,7 +46,7 @@ impl GGSW { pub fn automorphism_inplace(&mut self, module: &M, key: &K, tsk: &T, scratch: &mut Scratch) where K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, - T: GLWETensorKeyPreparedToRef, + T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore, M: GGSWAutomorphism, { @@ -67,11 +67,8 @@ where K: GGLWEInfos, T: GGLWEInfos, { - let out_size: usize = res_infos.size(); - let ci_dft: usize = self.bytes_of_vec_znx_dft((key_infos.rank_out() + 1).into(), out_size); - let ks_internal: usize = self.glwe_automorphism_tmp_bytes(res_infos, a_infos, key_infos); - let expand: usize = self.ggsw_expand_rows_tmp_bytes(res_infos, tsk_infos); - ci_dft + (ks_internal.max(expand)) + self.glwe_automorphism_tmp_bytes(res_infos, a_infos, key_infos) + .max(self.ggsw_expand_rows_tmp_bytes(res_infos, tsk_infos)) } fn ggsw_automorphism(&self, res: &mut R, a: &A, key: &K, tsk: &T, scratch: &mut Scratch) @@ -79,12 +76,12 @@ where R: GGSWToMut, A: GGSWToRef, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, - T: GLWETensorKeyPreparedToRef, + T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore, { let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); let a: &GGSW<&[u8]> = &a.to_ref(); - let tsk: &GLWETensorKeyPrepared<&[u8], BE> = &tsk.to_ref(); + let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref(); assert_eq!(res.dsize(), a.dsize()); assert!(res.dnum() <= a.dnum()); @@ -104,11 +101,11 @@ where where R: GGSWToMut, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, - T: GLWETensorKeyPreparedToRef, + T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore, { let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); - let tsk: &GLWETensorKeyPrepared<&[u8], BE> = &tsk.to_ref(); + let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref(); // Keyswitch the j-th row of the col 0 for row in 0..res.dnum().as_usize() { diff --git a/poulpy-core/src/automorphism/glwe_ct.rs b/poulpy-core/src/automorphism/glwe_ct.rs index 7161239..b382197 100644 --- a/poulpy-core/src/automorphism/glwe_ct.rs +++ b/poulpy-core/src/automorphism/glwe_ct.rs @@ -1,13 +1,13 @@ use poulpy_hal::{ api::{ - ScratchTakeBasic, VecZnxAutomorphismInplace, VecZnxBigAutomorphismInplace, VecZnxBigSubSmallInplace, - VecZnxBigSubSmallNegateInplace, + ScratchTakeBasic, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, VecZnxBigNormalize, + VecZnxBigSubSmallInplace, VecZnxBigSubSmallNegateInplace, VecZnxNormalize, }, layouts::{Backend, DataMut, Module, Scratch, VecZnxBig}, }; use crate::{ - GLWEKeyswitch, ScratchTakeCore, keyswitch_internal, + GLWEKeySwitchInternal, GLWEKeyswitch, ScratchTakeCore, layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos}, }; @@ -101,13 +101,71 @@ impl GLWE { } } -pub trait GLWEAutomorphism +pub trait GLWEAutomorphism { + fn glwe_automorphism_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize + where + R: GLWEInfos, + A: GLWEInfos, + K: GGLWEInfos; + + fn glwe_automorphism(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + A: GLWEToRef, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos; + + fn glwe_automorphism_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos; + + fn glwe_automorphism_add(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + A: GLWEToRef, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos; + + fn glwe_automorphism_add_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos; + + fn glwe_automorphism_sub(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + A: GLWEToRef, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos; + + fn glwe_automorphism_sub_negate(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + A: GLWEToRef, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos; + + fn glwe_automorphism_sub_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos; + + fn glwe_automorphism_sub_negate_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos; +} + +impl GLWEAutomorphism for Module where - Self: GLWEKeyswitch + Self: Sized + + GLWEKeyswitch + + GLWEKeySwitchInternal + + VecZnxNormalize + VecZnxAutomorphismInplace + VecZnxBigAutomorphismInplace + VecZnxBigSubSmallInplace - + VecZnxBigSubSmallNegateInplace, + + VecZnxBigSubSmallNegateInplace + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize, + Scratch: ScratchTakeCore, { fn glwe_automorphism_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize where @@ -160,7 +218,7 @@ where let a: &GLWE<&[u8]> = &a.to_ref(); let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, BE> = keyswitch_internal(self, res_dft, a, key, scratch_1); + let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1); for i in 0..res.rank().as_usize() + 1 { self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); @@ -186,7 +244,7 @@ where let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, BE> = keyswitch_internal(self, res_dft, res, key, scratch_1); + let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1); for i in 0..res.rank().as_usize() + 1 { self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); @@ -214,7 +272,7 @@ where let a: &GLWE<&[u8]> = &a.to_ref(); let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, BE> = keyswitch_internal(self, res_dft, a, key, scratch_1); + let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1); for i in 0..res.rank().as_usize() + 1 { self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); @@ -242,7 +300,7 @@ where let a: &GLWE<&[u8]> = &a.to_ref(); let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, BE> = keyswitch_internal(self, res_dft, a, key, scratch_1); + let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1); for i in 0..res.rank().as_usize() + 1 { self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); @@ -268,7 +326,7 @@ where let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, BE> = keyswitch_internal(self, res_dft, res, key, scratch_1); + let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1); for i in 0..res.rank().as_usize() + 1 { self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); @@ -294,7 +352,7 @@ where let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, BE> = keyswitch_internal(self, res_dft, res, key, scratch_1); + let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1); for i in 0..res.rank().as_usize() + 1 { self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); @@ -311,12 +369,3 @@ where } } } - -impl GLWEAutomorphism for Module where - Self: GLWEKeyswitch - + VecZnxAutomorphismInplace - + VecZnxBigAutomorphismInplace - + VecZnxBigSubSmallInplace - + VecZnxBigSubSmallNegateInplace -{ -} diff --git a/poulpy-core/src/conversion/gglwe_to_ggsw.rs b/poulpy-core/src/conversion/gglwe_to_ggsw.rs index b33759e..8554e50 100644 --- a/poulpy-core/src/conversion/gglwe_to_ggsw.rs +++ b/poulpy-core/src/conversion/gglwe_to_ggsw.rs @@ -1,17 +1,16 @@ use poulpy_hal::{ api::{ - ModuleN, ScratchAvailable, ScratchTakeBasic, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxDftAddInplace, VecZnxDftApply, - VecZnxDftBytesOf, VecZnxDftCopy, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, - VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize, + VecZnxBigNormalizeTmpBytes, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, }, - layouts::{Backend, DataMut, Module, Scratch, VmpPMat, ZnxInfos}, + layouts::{Backend, DataMut, Module, Scratch, VecZnxBig}, }; use crate::{ - GLWECopy, ScratchTakeCore, + GGLWEProduct, GLWECopy, ScratchTakeCore, layouts::{ - GGLWE, GGLWEInfos, GGLWEToRef, GGSW, GGSWInfos, GGSWToMut, GLWEInfos, LWEInfos, - prepared::{GLWETensorKeyPrepared, GLWETensorKeyPreparedToRef}, + GGLWE, GGLWEInfos, GGLWEToGGSWKeyPrepared, GGLWEToGGSWKeyPreparedToRef, GGLWEToRef, GGSW, GGSWInfos, GGSWToMut, GLWE, + GLWEInfos, LWEInfos, }, }; @@ -31,7 +30,7 @@ impl GGSW { where M: GGSWFromGGLWE, G: GGLWEToRef, - T: GLWETensorKeyPreparedToRef, + T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore, { module.ggsw_from_gglwe(self, gglwe, tsk, scratch); @@ -54,12 +53,12 @@ where where R: GGSWToMut, A: GGLWEToRef, - T: GLWETensorKeyPreparedToRef, + T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore, { let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); let a: &GGLWE<&[u8]> = &a.to_ref(); - let tsk: &GLWETensorKeyPrepared<&[u8], BE> = &tsk.to_ref(); + let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref(); assert_eq!(res.rank(), a.rank_out()); assert_eq!(res.dnum(), a.dnum()); @@ -85,177 +84,140 @@ pub trait GGSWFromGGLWE { where R: GGSWToMut, A: GGLWEToRef, - T: GLWETensorKeyPreparedToRef, + T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore; } -impl GGSWExpandRows for Module where - Self: Sized - + ModuleN - + VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigBytesOf - + VecZnxNormalizeTmpBytes - + VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigBytesOf - + VecZnxNormalizeTmpBytes - + VecZnxDftApply - + VecZnxDftCopy - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftAddInplace - + VecZnxBigNormalize - + VecZnxIdftApplyTmpA - + VecZnxNormalize -{ +pub trait GGSWExpandRows { + fn ggsw_expand_rows_tmp_bytes(&self, res_infos: &R, tsk_infos: &A) -> usize + where + R: GGSWInfos, + A: GGLWEInfos; + + fn ggsw_expand_row(&self, res: &mut R, tsk: &T, scratch: &mut Scratch) + where + R: GGSWToMut, + T: GGLWEToGGSWKeyPreparedToRef, + Scratch: ScratchTakeCore; } -pub trait GGSWExpandRows +impl GGSWExpandRows for Module where - Self: Sized - + ModuleN - + VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigBytesOf - + VecZnxNormalizeTmpBytes - + VecZnxDftApply - + VecZnxDftCopy - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftAddInplace + Self: GGLWEProduct + VecZnxBigNormalize - + VecZnxIdftApplyTmpA - + VecZnxNormalize, + + VecZnxBigNormalizeTmpBytes + + VecZnxBigBytesOf + + VecZnxDftBytesOf + + VecZnxDftApply + + VecZnxNormalize + + VecZnxBigAddSmallInplace + + VecZnxIdftApplyConsume, { fn ggsw_expand_rows_tmp_bytes(&self, res_infos: &R, tsk_infos: &A) -> usize where R: GGSWInfos, A: GGLWEInfos, { - let tsk_size: usize = tsk_infos.k().div_ceil(tsk_infos.base2k()) as usize; - let size_in: usize = res_infos - .k() - .div_ceil(tsk_infos.base2k()) - .div_ceil(tsk_infos.dsize().into()) as usize; + let base2k_in: usize = res_infos.base2k().into(); + let base2k_tsk: usize = tsk_infos.base2k().into(); - let tmp_dft_i: usize = self.bytes_of_vec_znx_dft((tsk_infos.rank_out() + 1).into(), tsk_size); - let tmp_a: usize = self.bytes_of_vec_znx_dft(1, size_in); - let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes( - tsk_size, - size_in, - size_in, - (tsk_infos.rank_in()).into(), // Verify if rank+1 - (tsk_infos.rank_out()).into(), // Verify if rank+1 - tsk_size, - ); - let tmp_idft: usize = self.bytes_of_vec_znx_big(1, tsk_size); - let norm: usize = self.vec_znx_normalize_tmp_bytes(); + let rank: usize = res_infos.rank().into(); + let cols: usize = rank + 1; - tmp_dft_i + ((tmp_a + vmp) | (tmp_idft + norm)) + let res_size = res_infos.size(); + let a_size: usize = (res_infos.size() * base2k_in).div_ceil(base2k_tsk); + + let a_dft = self.bytes_of_vec_znx_dft(cols - 1, a_size); + let res_dft = self.bytes_of_vec_znx_dft(cols, a_size); + let gglwe_prod: usize = self.gglwe_product_dft_tmp_bytes(res_size, a_size, tsk_infos); + let normalize = self.vec_znx_big_normalize_tmp_bytes(); + + (a_dft + res_dft + gglwe_prod).max(normalize) } fn ggsw_expand_row(&self, res: &mut R, tsk: &T, scratch: &mut Scratch) where R: GGSWToMut, - T: GLWETensorKeyPreparedToRef, + T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore, { let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); - let tsk: &GLWETensorKeyPrepared<&[u8], BE> = &tsk.to_ref(); + let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref(); - let basek_in: usize = res.base2k().into(); - let basek_tsk: usize = tsk.base2k().into(); + let base2k_in: usize = res.base2k().into(); + let base2k_tsk: usize = tsk.base2k().into(); assert!(scratch.available() >= self.ggsw_expand_rows_tmp_bytes(res, tsk)); let rank: usize = res.rank().into(); let cols: usize = rank + 1; - let a_size: usize = (res.size() * basek_in).div_ceil(basek_tsk); + let a_size: usize = (res.size() * base2k_in).div_ceil(base2k_tsk); // Keyswitch the j-th row of the col 0 - for row_i in 0..res.dnum().into() { - let a = &res.at(row_i, 0).data; + for row in 0..res.dnum().as_usize() { + let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols - 1, a_size); - // Pre-compute DFT of (a0, a1, a2) - let (mut ci_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, a_size); + { + let glwe_mi_1: &GLWE<&[u8]> = &res.at(row, 0); - if basek_in == basek_tsk { - for i in 0..cols { - self.vec_znx_dft_apply(1, 0, &mut ci_dft, i, a, i); - } - } else { - let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, a_size); - for i in 0..cols { - self.vec_znx_normalize(basek_tsk, &mut a_conv, 0, basek_in, a, i, scratch_2); - self.vec_znx_dft_apply(1, 0, &mut ci_dft, i, &a_conv, 0); + if base2k_in == base2k_tsk { + for col_i in 0..cols - 1 { + self.vec_znx_dft_apply(1, 0, &mut a_dft, col_i, glwe_mi_1.data(), col_i + 1); + } + } else { + let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, a_size); + for i in 0..cols - 1 { + self.vec_znx_normalize( + base2k_tsk, + &mut a_conv, + 0, + base2k_in, + glwe_mi_1.data(), + i + 1, + scratch_2, + ); + self.vec_znx_dft_apply(1, 0, &mut a_dft, i, &a_conv, 0); + } } } - for col_j in 1..cols { - // Example for rank 3: + // Example for rank 3: + // + // Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is + // actually composed of that many dnum and we focus on a specific row here + // implicitely given ci_dft. + // + // # Input + // + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) + // col 1: (0, 0, 0, 0) + // col 2: (0, 0, 0, 0) + // col 3: (0, 0, 0, 0) + // + // # Output + // + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) + // col 1: (-(b0s0 + b1s1 + b2s2) , b0 + M[i], b1 , b2 ) + // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M[i], c2 ) + // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M[i]) + for col in 1..cols { + let (mut res_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, tsk.size()); // Todo optimise + + // Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2 // - // Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is - // actually composed of that many dnum and we focus on a specific row here - // implicitely given ci_dft. + // # Example for col=1 // - // # Input - // - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) - // col 1: (0, 0, 0, 0) - // col 2: (0, 0, 0, 0) - // col 3: (0, 0, 0, 0) - // - // # Output - // - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) - // col 1: (-(b0s0 + b1s1 + b2s2) , b0 + M[i], b1 , b2 ) - // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M[i], c2 ) - // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M[i]) + // a0 * (-(f0s0 + f1s1 + f1s2) + s0^2, f0, f1, f2) = (-(a0f0s0 + a0f1s1 + a0f1s2) + a0s0^2, a0f0, a0f1, a0f2) + // + + // a1 * (-(g0s0 + g1s1 + g1s2) + s0s1, g0, g1, g2) = (-(a1g0s0 + a1g1s1 + a1g1s2) + a1s0s1, a1g0, a1g1, a1g2) + // + + // a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2) + // = + // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2) + self.gglwe_product_dft(&mut res_dft, &a_dft, tsk.at(col - 1), scratch_2); - let dsize: usize = tsk.dsize().into(); - - let (mut tmp_dft_i, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, tsk.size()); - let (mut tmp_a, scratch_3) = scratch_2.take_vec_znx_dft(self, 1, ci_dft.size().div_ceil(dsize)); - - { - // Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2 - // - // # Example for col=1 - // - // a0 * (-(f0s0 + f1s1 + f1s2) + s0^2, f0, f1, f2) = (-(a0f0s0 + a0f1s1 + a0f1s2) + a0s0^2, a0f0, a0f1, a0f2) - // + - // a1 * (-(g0s0 + g1s1 + g1s2) + s0s1, g0, g1, g2) = (-(a1g0s0 + a1g1s1 + a1g1s2) + a1s0s1, a1g0, a1g1, a1g2) - // + - // a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2) - // = - // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2) - for col_i in 1..cols { - let pmat: &VmpPMat<&[u8], BE> = &tsk.at(col_i - 1, col_j - 1).data; // Selects Enc(s[i]s[j]) - - // Extracts a[i] and multipies with Enc(s[i]s[j]) - for di in 0..dsize { - tmp_a.set_size((ci_dft.size() + di) / dsize); - - // Small optimization for dsize > 2 - // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then - // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}. - // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last dsize-1 limbs, but this introduce - // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same - // noise is kept with respect to the ideal functionality. - tmp_dft_i.set_size(tsk.size() - ((dsize - di) as isize - 2).max(0) as usize); - - self.vec_znx_dft_copy(dsize, dsize - 1 - di, &mut tmp_a, 0, &ci_dft, col_i); - if di == 0 && col_i == 1 { - self.vmp_apply_dft_to_dft(&mut tmp_dft_i, &tmp_a, pmat, scratch_3); - } else { - self.vmp_apply_dft_to_dft_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch_3); - } - } - } - } + let mut res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft); // Adds -(sum a[i] * s[i]) + m) on the i-th column of tmp_idft_i // @@ -266,18 +228,17 @@ where // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0 -(a0s0 + a1s1 + a2s2) + M[i], x1, x2) // = // (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2) - self.vec_znx_dft_add_inplace(&mut tmp_dft_i, col_j, &ci_dft, 0); - let (mut tmp_idft, scratch_3) = scratch_2.take_vec_znx_big(self, 1, tsk.size()); - for i in 0..cols { - self.vec_znx_idft_apply_tmpa(&mut tmp_idft, 0, &mut tmp_dft_i, i); + self.vec_znx_big_add_small_inplace(&mut res_big, col, res.at(row, 0).data(), 0); + + for j in 0..cols { self.vec_znx_big_normalize( - basek_in, - &mut res.at_mut(row_i, col_j).data, - i, - basek_tsk, - &tmp_idft, - 0, - scratch_3, + res.base2k().as_usize(), + res.at_mut(row, col).data_mut(), + j, + tsk.base2k().as_usize(), + &res_big, + j, + scratch_2, ); } } diff --git a/poulpy-core/src/conversion/lwe_to_glwe.rs b/poulpy-core/src/conversion/lwe_to_glwe.rs index c759ee5..c9b40c3 100644 --- a/poulpy-core/src/conversion/lwe_to_glwe.rs +++ b/poulpy-core/src/conversion/lwe_to_glwe.rs @@ -1,5 +1,5 @@ use poulpy_hal::{ - api::ScratchTakeBasic, + api::{ScratchTakeBasic, VecZnxNormalize, VecZnxNormalizeTmpBytes}, layouts::{Backend, DataMut, Module, Scratch, VecZnx, ZnxView, ZnxViewMut, ZnxZero}, }; @@ -8,11 +8,10 @@ use crate::{ layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWELayout, GLWEToMut, LWE, LWEInfos, LWEToRef}, }; -impl GLWEFromLWE for Module where Self: GLWEKeyswitch {} - -pub trait GLWEFromLWE +impl GLWEFromLWE for Module where - Self: GLWEKeyswitch, + Self: GLWEKeyswitch + VecZnxNormalizeTmpBytes + VecZnxNormalize, + Scratch: ScratchTakeCore, { fn glwe_from_lwe_tmp_bytes(&self, glwe_infos: &R, lwe_infos: &A, key_infos: &K) -> usize where @@ -41,7 +40,6 @@ where R: GLWEToMut, A: LWEToRef, K: GGLWEPreparedToRef + GGLWEInfos, - Scratch: ScratchTakeCore, { let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let lwe: &LWE<&[u8]> = &lwe.to_ref(); @@ -105,6 +103,23 @@ where } } +pub trait GLWEFromLWE +where + Self: GLWEKeyswitch, +{ + fn glwe_from_lwe_tmp_bytes(&self, glwe_infos: &R, lwe_infos: &A, key_infos: &K) -> usize + where + R: GLWEInfos, + A: LWEInfos, + K: GGLWEInfos; + + fn glwe_from_lwe(&self, res: &mut R, lwe: &A, ksk: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + A: LWEToRef, + K: GGLWEPreparedToRef + GGLWEInfos; +} + impl GLWE> { pub fn from_lwe_tmp_bytes(module: &M, glwe_infos: &R, lwe_infos: &A, key_infos: &K) -> usize where diff --git a/poulpy-core/src/encryption/compressed/gglwe_to_ggsw_key.rs b/poulpy-core/src/encryption/compressed/gglwe_to_ggsw_key.rs new file mode 100644 index 0000000..92f382c --- /dev/null +++ b/poulpy-core/src/encryption/compressed/gglwe_to_ggsw_key.rs @@ -0,0 +1,124 @@ +use poulpy_hal::{ + api::{ModuleN, ScratchTakeBasic, VecZnxCopy}, + layouts::{Backend, DataMut, Module, Scratch}, + source::Source, +}; + +use crate::{ + GGLWECompressedEncryptSk, GetDistribution, ScratchTakeCore, + layouts::{ + GGLWEInfos, GGLWEToGGSWKeyCompressed, GGLWEToGGSWKeyCompressedToMut, GLWEInfos, GLWESecret, GLWESecretTensor, + GLWESecretTensorFactory, GLWESecretToRef, + prepared::{GLWESecretPrepared, GLWESecretPreparedFactory}, + }, +}; + +impl GGLWEToGGSWKeyCompressed> { + pub fn encrypt_sk_tmp_bytes(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: GGLWEToGGSWKeyCompressedEncryptSk, + { + module.gglwe_to_ggsw_key_encrypt_sk_tmp_bytes(infos) + } +} + +impl GGLWEToGGSWKeyCompressed { + pub fn encrypt_sk( + &mut self, + module: &M, + sk: &S, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + M: GGLWEToGGSWKeyCompressedEncryptSk, + S: GLWESecretToRef + GetDistribution + GLWEInfos, + Scratch: ScratchTakeCore, + { + module.gglwe_to_ggsw_key_encrypt_sk(self, sk, seed_xa, source_xe, scratch); + } +} + +pub trait GGLWEToGGSWKeyCompressedEncryptSk { + fn gglwe_to_ggsw_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos; + + fn gglwe_to_ggsw_key_encrypt_sk( + &self, + res: &mut R, + sk: &S, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWEToGGSWKeyCompressedToMut + GGLWEInfos, + S: GLWESecretToRef + GetDistribution + GLWEInfos; +} + +impl GGLWEToGGSWKeyCompressedEncryptSk for Module +where + Self: ModuleN + GGLWECompressedEncryptSk + GLWESecretTensorFactory + GLWESecretPreparedFactory + VecZnxCopy, + Scratch: ScratchTakeCore, +{ + fn gglwe_to_ggsw_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + let sk_prepared: usize = GLWESecretPrepared::bytes_of(self, infos.rank()); + let sk_tensor: usize = GLWESecretTensor::bytes_of_from_infos(infos); + let gglwe_encrypt: usize = self.gglwe_compressed_encrypt_sk_tmp_bytes(infos); + let sk_ij = GLWESecret::bytes_of(self.n().into(), infos.rank()); + (sk_prepared + sk_tensor + sk_ij) + gglwe_encrypt.max(self.glwe_secret_tensor_prepare_tmp_bytes(infos.rank())) + } + + fn gglwe_to_ggsw_key_encrypt_sk( + &self, + res: &mut R, + sk: &S, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWEToGGSWKeyCompressedToMut + GGLWEInfos, + S: GLWESecretToRef + GetDistribution + GLWEInfos, + { + assert_eq!(res.rank(), sk.rank()); + assert_eq!(res.n(), sk.n()); + + let res: &mut GGLWEToGGSWKeyCompressed<&mut [u8]> = &mut res.to_mut(); + let rank: usize = res.rank_out().as_usize(); + + let (mut sk_prepared, scratch_1) = scratch.take_glwe_secret_prepared(self, res.rank()); + let (mut sk_tensor, scratch_2) = scratch_1.take_glwe_secret_tensor(self.n().into(), res.rank()); + sk_prepared.prepare(self, sk); + sk_tensor.prepare(self, sk, scratch_2); + + let (mut sk_ij, scratch_3) = scratch_2.take_scalar_znx(self.n(), rank); + + let mut source_xa = Source::new(seed_xa); + + for i in 0..rank { + for j in 0..rank { + self.vec_znx_copy( + &mut sk_ij.as_vec_znx_mut(), + j, + &sk_tensor.at(i, j).as_vec_znx(), + 0, + ); + } + + let (seed_xa_tmp, _) = source_xa.branch(); + + res.at_mut(i).encrypt_sk( + self, + &sk_ij, + &sk_prepared, + seed_xa_tmp, + source_xe, + scratch_3, + ); + } + } +} diff --git a/poulpy-core/src/encryption/compressed/glwe_tensor_key.rs b/poulpy-core/src/encryption/compressed/glwe_tensor_key.rs index 14c9217..12af7ee 100644 --- a/poulpy-core/src/encryption/compressed/glwe_tensor_key.rs +++ b/poulpy-core/src/encryption/compressed/glwe_tensor_key.rs @@ -1,17 +1,15 @@ use poulpy_hal::{ - api::{ - ModuleN, ScratchTakeBasic, SvpApplyDftToDft, SvpPPolBytesOf, SvpPrepare, VecZnxBigBytesOf, VecZnxBigNormalize, - VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyTmpA, - }, + api::ScratchTakeBasic, layouts::{Backend, DataMut, Module, Scratch}, source::Source, }; use crate::{ - GGLWECompressedEncryptSk, GLWETensorKeyEncryptSk, GetDistribution, ScratchTakeCore, + GGLWECompressedEncryptSk, GetDistribution, ScratchTakeCore, layouts::{ - GGLWEInfos, GLWEInfos, GLWESecret, GLWESecretPrepared, GLWESecretPreparedFactory, GLWESecretToRef, - GLWETensorKeyCompressedAtMut, LWEInfos, Rank, compressed::GLWETensorKeyCompressed, + GGLWECompressedSeedMut, GGLWECompressedToMut, GGLWEInfos, GGLWELayout, GLWEInfos, GLWESecretPrepared, + GLWESecretPreparedFactory, GLWESecretTensor, GLWESecretTensorFactory, GLWESecretToRef, + compressed::GLWETensorKeyCompressed, }, }; @@ -34,7 +32,7 @@ impl GLWETensorKeyCompressed { source_xe: &mut Source, scratch: &mut Scratch, ) where - S: GLWESecretToRef + GetDistribution, + S: GLWESecretToRef + GetDistribution + GLWEInfos, M: GLWETensorKeyCompressedEncryptSk, { module.glwe_tensor_key_compressed_encrypt_sk(self, sk, seed_xa, source_xe, scratch); @@ -46,7 +44,7 @@ pub trait GLWETensorKeyCompressedEncryptSk { where A: GGLWEInfos; - fn glwe_tensor_key_compressed_encrypt_sk( + fn glwe_tensor_key_compressed_encrypt_sk( &self, res: &mut R, sk: &S, @@ -54,40 +52,38 @@ pub trait GLWETensorKeyCompressedEncryptSk { source_xe: &mut Source, scratch: &mut Scratch, ) where - D: DataMut, - R: GLWETensorKeyCompressedAtMut + GGLWEInfos, - S: GLWESecretToRef + GetDistribution; + R: GGLWECompressedToMut + GGLWEInfos + GGLWECompressedSeedMut, + S: GLWESecretToRef + GetDistribution + GLWEInfos; } impl GLWETensorKeyCompressedEncryptSk for Module where - Self: ModuleN - + GGLWECompressedEncryptSk - + GLWETensorKeyEncryptSk - + VecZnxDftApply - + SvpApplyDftToDft - + VecZnxIdftApplyTmpA - + VecZnxBigNormalize - + SvpPrepare - + SvpPPolBytesOf - + VecZnxDftBytesOf - + VecZnxBigBytesOf - + GLWESecretPreparedFactory, + Self: GGLWECompressedEncryptSk + GLWESecretPreparedFactory + GLWESecretTensorFactory, Scratch: ScratchTakeBasic + ScratchTakeCore, { fn glwe_tensor_key_compressed_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize where A: GGLWEInfos, { - GLWESecretPrepared::bytes_of(self, infos.rank_out()) - + self.bytes_of_vec_znx_dft(infos.rank_out().into(), 1) - + self.bytes_of_vec_znx_big(1, 1) - + self.bytes_of_vec_znx_dft(1, 1) - + GLWESecret::bytes_of(self.n().into(), Rank(1)) - + self.gglwe_compressed_encrypt_sk_tmp_bytes(infos) + let sk_prepared: usize = GLWESecretPrepared::bytes_of(self, infos.rank_out()); + let sk_tensor: usize = GLWESecretTensor::bytes_of_from_infos(infos); + + let tensor_infos: GGLWELayout = GGLWELayout { + n: infos.n(), + base2k: infos.base2k(), + k: infos.k(), + rank_in: GLWESecretTensor::pairs(infos.rank().into()).into(), + rank_out: infos.rank_out(), + dnum: infos.dnum(), + dsize: infos.dsize(), + }; + + let gglwe_encrypt: usize = self.gglwe_compressed_encrypt_sk_tmp_bytes(&tensor_infos); + + (sk_prepared + sk_tensor) + gglwe_encrypt.max(self.glwe_secret_tensor_prepare_tmp_bytes(infos.rank())) } - fn glwe_tensor_key_compressed_encrypt_sk( + fn glwe_tensor_key_compressed_encrypt_sk( &self, res: &mut R, sk: &S, @@ -95,62 +91,24 @@ where source_xe: &mut Source, scratch: &mut Scratch, ) where - D: DataMut, - R: GGLWEInfos + GLWETensorKeyCompressedAtMut, - S: GLWESecretToRef + GetDistribution, + R: GGLWEInfos + GGLWECompressedToMut + GGLWECompressedSeedMut, + S: GLWESecretToRef + GetDistribution + GLWEInfos, { - let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(self, res.rank()); - sk_dft_prep.prepare(self, sk); + assert_eq!(res.rank_out(), sk.rank()); + assert_eq!(res.n(), sk.n()); - let sk: &GLWESecret<&[u8]> = &sk.to_ref(); + let (mut sk_prepared, scratch_1) = scratch.take_glwe_secret_prepared(self, res.rank()); + let (mut sk_tensor, scratch_2) = scratch_1.take_glwe_secret_tensor(self.n().into(), res.rank()); + sk_prepared.prepare(self, sk); + sk_tensor.prepare(self, sk, scratch_2); - #[cfg(debug_assertions)] - { - assert_eq!(res.rank_out(), sk.rank()); - assert_eq!(res.n(), sk.n()); - } - - // let n: usize = sk.n().into(); - let rank: usize = res.rank_out().into(); - - let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, rank, 1); - - for i in 0..rank { - self.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); - } - - let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(self, 1, 1); - let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(self.n().into(), Rank(1)); - let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(self, 1, 1); - - let mut source_xa: Source = Source::new(seed_xa); - - for i in 0..rank { - for j in i..rank { - self.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i); - - self.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0); - self.vec_znx_big_normalize( - res.base2k().into(), - &mut sk_ij.data.as_vec_znx_mut(), - 0, - res.base2k().into(), - &sk_ij_big, - 0, - scratch_5, - ); - - let (seed_xa_tmp, _) = source_xa.branch(); - - self.gglwe_compressed_encrypt_sk( - res.at_mut(i, j), - &sk_ij.data, - &sk_dft_prep, - seed_xa_tmp, - source_xe, - scratch_5, - ); - } - } + self.gglwe_compressed_encrypt_sk( + res, + &sk_tensor.data, + &sk_prepared, + seed_xa, + source_xe, + scratch_2, + ); } } diff --git a/poulpy-core/src/encryption/compressed/mod.rs b/poulpy-core/src/encryption/compressed/mod.rs index e96eeb5..1b21e1f 100644 --- a/poulpy-core/src/encryption/compressed/mod.rs +++ b/poulpy-core/src/encryption/compressed/mod.rs @@ -1,4 +1,5 @@ mod gglwe; +mod gglwe_to_ggsw_key; mod ggsw; mod glwe_automorphism_key; mod glwe_ct; @@ -6,6 +7,7 @@ mod glwe_switching_key; mod glwe_tensor_key; pub use gglwe::*; +pub use gglwe_to_ggsw_key::*; pub use ggsw::*; pub use glwe_automorphism_key::*; pub use glwe_ct::*; diff --git a/poulpy-core/src/encryption/gglwe.rs b/poulpy-core/src/encryption/gglwe.rs index ba78cde..a50b565 100644 --- a/poulpy-core/src/encryption/gglwe.rs +++ b/poulpy-core/src/encryption/gglwe.rs @@ -148,7 +148,7 @@ where // Example for ksk rank 2 to rank 3: // // (-(a0*s0 + a1*s1 + a2*s2) + s0', a0, a1, a2) - // (-(b0*s0 + b1*s1 + b2*s2) + s0', b0, b1, b2) + // (-(b0*s0 + b1*s1 + b2*s2) + s1', b0, b1, b2) // // Example ksk rank 2 to rank 1 // diff --git a/poulpy-core/src/encryption/gglwe_to_ggsw_key.rs b/poulpy-core/src/encryption/gglwe_to_ggsw_key.rs new file mode 100644 index 0000000..017455f --- /dev/null +++ b/poulpy-core/src/encryption/gglwe_to_ggsw_key.rs @@ -0,0 +1,112 @@ +use poulpy_hal::{ + api::{ModuleN, ScratchTakeBasic, VecZnxCopy}, + layouts::{Backend, DataMut, Module, Scratch}, + source::Source, +}; + +use crate::{ + GGLWEEncryptSk, GetDistribution, ScratchTakeCore, + layouts::{ + GGLWEInfos, GGLWEToGGSWKey, GGLWEToGGSWKeyToMut, GLWEInfos, GLWESecret, GLWESecretTensor, GLWESecretTensorFactory, + GLWESecretToRef, + prepared::{GLWESecretPrepared, GLWESecretPreparedFactory}, + }, +}; + +impl GGLWEToGGSWKey> { + pub fn encrypt_sk_tmp_bytes(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: GGLWEToGGSWKeyEncryptSk, + { + module.gglwe_to_ggsw_key_encrypt_sk_tmp_bytes(infos) + } +} + +impl GGLWEToGGSWKey { + pub fn encrypt_sk( + &mut self, + module: &M, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + M: GGLWEToGGSWKeyEncryptSk, + S: GLWESecretToRef + GetDistribution + GLWEInfos, + Scratch: ScratchTakeCore, + { + module.gglwe_to_ggsw_key_encrypt_sk(self, sk, source_xa, source_xe, scratch); + } +} + +pub trait GGLWEToGGSWKeyEncryptSk { + fn gglwe_to_ggsw_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos; + + fn gglwe_to_ggsw_key_encrypt_sk( + &self, + res: &mut R, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWEToGGSWKeyToMut, + S: GLWESecretToRef + GetDistribution + GLWEInfos; +} + +impl GGLWEToGGSWKeyEncryptSk for Module +where + Self: ModuleN + GGLWEEncryptSk + GLWESecretTensorFactory + GLWESecretPreparedFactory + VecZnxCopy, + Scratch: ScratchTakeCore, +{ + fn gglwe_to_ggsw_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + let sk_prepared: usize = GLWESecretPrepared::bytes_of(self, infos.rank()); + let sk_tensor: usize = GLWESecretTensor::bytes_of_from_infos(infos); + let gglwe_encrypt: usize = self.gglwe_encrypt_sk_tmp_bytes(infos); + let sk_ij = GLWESecret::bytes_of(self.n().into(), infos.rank()); + (sk_prepared + sk_tensor + sk_ij) + gglwe_encrypt.max(self.glwe_secret_tensor_prepare_tmp_bytes(infos.rank())) + } + + fn gglwe_to_ggsw_key_encrypt_sk( + &self, + res: &mut R, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWEToGGSWKeyToMut, + S: GLWESecretToRef + GetDistribution + GLWEInfos, + { + let res: &mut GGLWEToGGSWKey<&mut [u8]> = &mut res.to_mut(); + + let rank: usize = res.rank_out().as_usize(); + + let (mut sk_prepared, scratch_1) = scratch.take_glwe_secret_prepared(self, res.rank()); + let (mut sk_tensor, scratch_2) = scratch_1.take_glwe_secret_tensor(self.n().into(), res.rank()); + sk_prepared.prepare(self, sk); + sk_tensor.prepare(self, sk, scratch_2); + + let (mut sk_ij, scratch_3) = scratch_2.take_scalar_znx(self.n(), rank); + + for i in 0..rank { + for j in 0..rank { + self.vec_znx_copy( + &mut sk_ij.as_vec_znx_mut(), + j, + &sk_tensor.at(i, j).as_vec_znx(), + 0, + ); + } + + res.at_mut(i) + .encrypt_sk(self, &sk_ij, &sk_prepared, source_xa, source_xe, scratch_3); + } + } +} diff --git a/poulpy-core/src/encryption/glwe_tensor_key.rs b/poulpy-core/src/encryption/glwe_tensor_key.rs index b7afae5..08df09b 100644 --- a/poulpy-core/src/encryption/glwe_tensor_key.rs +++ b/poulpy-core/src/encryption/glwe_tensor_key.rs @@ -1,8 +1,5 @@ use poulpy_hal::{ - api::{ - ModuleN, ScratchTakeBasic, SvpApplyDftToDft, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, - VecZnxIdftApplyTmpA, - }, + api::ModuleN, layouts::{Backend, DataMut, Module, Scratch}, source::Source, }; @@ -10,7 +7,8 @@ use poulpy_hal::{ use crate::{ GGLWEEncryptSk, GetDistribution, ScratchTakeCore, layouts::{ - GGLWE, GGLWEInfos, GLWEInfos, GLWESecret, GLWESecretToRef, GLWETensorKey, GLWETensorKeyToMut, LWEInfos, Rank, + GGLWEInfos, GGLWELayout, GGLWEToMut, GLWEInfos, GLWESecretTensor, GLWESecretTensorFactory, GLWESecretToRef, + GLWETensorKey, prepared::{GLWESecretPrepared, GLWESecretPreparedFactory}, }, }; @@ -55,33 +53,35 @@ pub trait GLWETensorKeyEncryptSk { source_xe: &mut Source, scratch: &mut Scratch, ) where - R: GLWETensorKeyToMut, + R: GGLWEToMut + GGLWEInfos, S: GLWESecretToRef + GetDistribution + GLWEInfos; } impl GLWETensorKeyEncryptSk for Module where - Self: ModuleN - + GGLWEEncryptSk - + VecZnxDftBytesOf - + VecZnxBigBytesOf - + GLWESecretPreparedFactory - + VecZnxDftApply - + SvpApplyDftToDft - + VecZnxIdftApplyTmpA - + VecZnxBigNormalize, + Self: ModuleN + GGLWEEncryptSk + GLWESecretPreparedFactory + GLWESecretTensorFactory, Scratch: ScratchTakeCore, { fn glwe_tensor_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize where A: GGLWEInfos, { - GLWESecretPrepared::bytes_of(self, infos.rank_out()) - + self.bytes_of_vec_znx_dft(infos.rank_out().into(), 1) - + self.bytes_of_vec_znx_big(1, 1) - + self.bytes_of_vec_znx_dft(1, 1) - + GLWESecret::bytes_of(self.n().into(), Rank(1)) - + GGLWE::encrypt_sk_tmp_bytes(self, infos) + let sk_prepared: usize = GLWESecretPrepared::bytes_of(self, infos.rank_out()); + let sk_tensor: usize = GLWESecretTensor::bytes_of_from_infos(infos); + + let tensor_infos: GGLWELayout = GGLWELayout { + n: infos.n(), + base2k: infos.base2k(), + k: infos.k(), + rank_in: GLWESecretTensor::pairs(infos.rank().into()).into(), + rank_out: infos.rank_out(), + dnum: infos.dnum(), + dsize: infos.dsize(), + }; + + let gglwe_encrypt: usize = self.gglwe_encrypt_sk_tmp_bytes(&tensor_infos); + + (sk_prepared + sk_tensor) + gglwe_encrypt.max(self.glwe_secret_tensor_prepare_tmp_bytes(infos.rank())) } fn glwe_tensor_key_encrypt_sk( @@ -92,56 +92,24 @@ where source_xe: &mut Source, scratch: &mut Scratch, ) where - R: GLWETensorKeyToMut, + R: GGLWEToMut + GGLWEInfos, S: GLWESecretToRef + GetDistribution + GLWEInfos, { - let res: &mut GLWETensorKey<&mut [u8]> = &mut res.to_mut(); - - // let n: RingDegree = sk.n(); - let rank: Rank = res.rank_out(); - - let (mut sk_prepared, scratch_1) = scratch.take_glwe_secret_prepared(self, sk.rank()); - sk_prepared.prepare(self, sk); - - let sk: &GLWESecret<&[u8]> = &sk.to_ref(); - assert_eq!(res.rank_out(), sk.rank()); assert_eq!(res.n(), sk.n()); - let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, rank.into(), 1); + let (mut sk_prepared, scratch_1) = scratch.take_glwe_secret_prepared(self, res.rank()); + let (mut sk_tensor, scratch_2) = scratch_1.take_glwe_secret_tensor(self.n().into(), res.rank()); + sk_prepared.prepare(self, sk); + sk_tensor.prepare(self, sk, scratch_2); - (0..rank.into()).for_each(|i| { - self.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); - }); - - let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(self, 1, 1); - let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(self.n().into(), Rank(1)); - let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(self, 1, 1); - - (0..rank.into()).for_each(|i| { - (i..rank.into()).for_each(|j| { - self.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_prepared.data, j, &sk_dft, i); - - self.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0); - self.vec_znx_big_normalize( - res.base2k().into(), - &mut sk_ij.data.as_vec_znx_mut(), - 0, - res.base2k().into(), - &sk_ij_big, - 0, - scratch_5, - ); - - res.at_mut(i, j).encrypt_sk( - self, - &sk_ij.data, - &sk_prepared, - source_xa, - source_xe, - scratch_5, - ); - }); - }) + self.gglwe_encrypt_sk( + res, + &sk_tensor.data, + &sk_prepared, + source_xa, + source_xe, + scratch_2, + ); } } diff --git a/poulpy-core/src/encryption/glwe_to_lwe_switching_key.rs b/poulpy-core/src/encryption/glwe_to_lwe_key.rs similarity index 83% rename from poulpy-core/src/encryption/glwe_to_lwe_switching_key.rs rename to poulpy-core/src/encryption/glwe_to_lwe_key.rs index 71877a4..0609fb6 100644 --- a/poulpy-core/src/encryption/glwe_to_lwe_switching_key.rs +++ b/poulpy-core/src/encryption/glwe_to_lwe_key.rs @@ -7,23 +7,22 @@ use poulpy_hal::{ use crate::{ GGLWEEncryptSk, ScratchTakeCore, layouts::{ - GGLWE, GGLWEInfos, GGLWEToMut, GLWESecret, GLWESecretToRef, GLWEToLWESwitchingKey, LWEInfos, LWESecret, LWESecretToRef, - Rank, + GGLWE, GGLWEInfos, GGLWEToMut, GLWESecret, GLWESecretToRef, GLWEToLWEKey, LWEInfos, LWESecret, LWESecretToRef, Rank, prepared::{GLWESecretPrepared, GLWESecretPreparedFactory}, }, }; -impl GLWEToLWESwitchingKey> { +impl GLWEToLWEKey> { pub fn encrypt_sk_tmp_bytes(module: &M, infos: &A) -> usize where A: GGLWEInfos, M: GLWEToLWESwitchingKeyEncryptSk, { - module.glwe_to_lwe_switching_key_encrypt_sk_tmp_bytes(infos) + module.glwe_to_lwe_key_encrypt_sk_tmp_bytes(infos) } } -impl GLWEToLWESwitchingKey { +impl GLWEToLWEKey { pub fn encrypt_sk( &mut self, module: &M, @@ -38,16 +37,16 @@ impl GLWEToLWESwitchingKey { S2: GLWESecretToRef, Scratch: ScratchTakeCore, { - module.glwe_to_lwe_switching_key_encrypt_sk(self, sk_lwe, sk_glwe, source_xa, source_xe, scratch); + module.glwe_to_lwe_key_encrypt_sk(self, sk_lwe, sk_glwe, source_xa, source_xe, scratch); } } pub trait GLWEToLWESwitchingKeyEncryptSk { - fn glwe_to_lwe_switching_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + fn glwe_to_lwe_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize where A: GGLWEInfos; - fn glwe_to_lwe_switching_key_encrypt_sk( + fn glwe_to_lwe_key_encrypt_sk( &self, res: &mut R, sk_lwe: &S1, @@ -70,7 +69,7 @@ where + VecZnxAutomorphismInplaceTmpBytes, Scratch: ScratchTakeCore, { - fn glwe_to_lwe_switching_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + fn glwe_to_lwe_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize where A: GGLWEInfos, { @@ -79,7 +78,7 @@ where .max(GLWESecret::bytes_of(self.n().into(), infos.rank_in()) + self.vec_znx_automorphism_inplace_tmp_bytes()) } - fn glwe_to_lwe_switching_key_encrypt_sk( + fn glwe_to_lwe_key_encrypt_sk( &self, res: &mut R, sk_lwe: &S1, diff --git a/poulpy-core/src/encryption/lwe_to_glwe_switching_key.rs b/poulpy-core/src/encryption/lwe_to_glwe_key.rs similarity index 81% rename from poulpy-core/src/encryption/lwe_to_glwe_switching_key.rs rename to poulpy-core/src/encryption/lwe_to_glwe_key.rs index af31420..c5fcd15 100644 --- a/poulpy-core/src/encryption/lwe_to_glwe_switching_key.rs +++ b/poulpy-core/src/encryption/lwe_to_glwe_key.rs @@ -8,21 +8,21 @@ use crate::{ GGLWEEncryptSk, ScratchTakeCore, layouts::{ GGLWE, GGLWEInfos, GGLWEToMut, GLWESecret, GLWESecretPreparedFactory, GLWESecretPreparedToRef, LWEInfos, LWESecret, - LWESecretToRef, LWEToGLWESwitchingKey, Rank, + LWESecretToRef, LWEToGLWEKey, Rank, }, }; -impl LWEToGLWESwitchingKey> { +impl LWEToGLWEKey> { pub fn encrypt_sk_tmp_bytes(module: &M, infos: &A) -> usize where A: GGLWEInfos, M: LWEToGLWESwitchingKeyEncryptSk, { - module.lwe_to_glwe_switching_key_encrypt_sk_tmp_bytes(infos) + module.lwe_to_glwe_key_encrypt_sk_tmp_bytes(infos) } } -impl LWEToGLWESwitchingKey { +impl LWEToGLWEKey { pub fn encrypt_sk( &mut self, module: &M, @@ -37,16 +37,16 @@ impl LWEToGLWESwitchingKey { M: LWEToGLWESwitchingKeyEncryptSk, Scratch: ScratchTakeCore, { - module.lwe_to_glwe_switching_key_encrypt_sk(self, sk_lwe, sk_glwe, source_xa, source_xe, scratch); + module.lwe_to_glwe_key_encrypt_sk(self, sk_lwe, sk_glwe, source_xa, source_xe, scratch); } } pub trait LWEToGLWESwitchingKeyEncryptSk { - fn lwe_to_glwe_switching_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + fn lwe_to_glwe_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize where A: GGLWEInfos; - fn lwe_to_glwe_switching_key_encrypt_sk( + fn lwe_to_glwe_key_encrypt_sk( &self, res: &mut R, sk_lwe: &S1, @@ -69,20 +69,20 @@ where + VecZnxAutomorphismInplaceTmpBytes, Scratch: ScratchTakeCore, { - fn lwe_to_glwe_switching_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + fn lwe_to_glwe_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize where A: GGLWEInfos, { debug_assert_eq!( infos.rank_in(), Rank(1), - "rank_in != 1 is not supported for LWEToGLWESwitchingKey" + "rank_in != 1 is not supported for LWEToGLWEKeyPrepared" ); GLWESecret::bytes_of(self.n().into(), infos.rank_in()) + GGLWE::encrypt_sk_tmp_bytes(self, infos).max(self.vec_znx_automorphism_inplace_tmp_bytes()) } - fn lwe_to_glwe_switching_key_encrypt_sk( + fn lwe_to_glwe_key_encrypt_sk( &self, res: &mut R, sk_lwe: &S1, diff --git a/poulpy-core/src/encryption/mod.rs b/poulpy-core/src/encryption/mod.rs index 7a391a6..d64757f 100644 --- a/poulpy-core/src/encryption/mod.rs +++ b/poulpy-core/src/encryption/mod.rs @@ -1,28 +1,30 @@ mod compressed; mod gglwe; +mod gglwe_to_ggsw_key; mod ggsw; mod glwe; mod glwe_automorphism_key; mod glwe_public_key; mod glwe_switching_key; mod glwe_tensor_key; -mod glwe_to_lwe_switching_key; +mod glwe_to_lwe_key; mod lwe; mod lwe_switching_key; -mod lwe_to_glwe_switching_key; +mod lwe_to_glwe_key; pub use compressed::*; pub use gglwe::*; +pub use gglwe_to_ggsw_key::*; pub use ggsw::*; pub use glwe::*; pub use glwe_automorphism_key::*; pub use glwe_public_key::*; pub use glwe_switching_key::*; pub use glwe_tensor_key::*; -pub use glwe_to_lwe_switching_key::*; +pub use glwe_to_lwe_key::*; pub use lwe::*; pub use lwe_switching_key::*; -pub use lwe_to_glwe_switching_key::*; +pub use lwe_to_glwe_key::*; pub const SIGMA: f64 = 3.2; pub(crate) const SIGMA_BOUND: f64 = 6.0 * SIGMA; diff --git a/poulpy-core/src/glwe_packer.rs b/poulpy-core/src/glwe_packer.rs new file mode 100644 index 0000000..da8c93e --- /dev/null +++ b/poulpy-core/src/glwe_packer.rs @@ -0,0 +1,388 @@ +use std::collections::HashMap; + +use poulpy_hal::{ + api::ModuleLogN, + layouts::{Backend, GaloisElement, Module, Scratch}, +}; + +use crate::{ + GLWEAdd, GLWEAutomorphism, GLWECopy, GLWENormalize, GLWERotate, GLWEShift, GLWESub, ScratchTakeCore, + glwe_trace::GLWETrace, + layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos}, +}; + +/// [GLWEPacker] enables only the fly GLWE packing +/// with constant memory of Log(N) ciphertexts. +/// Main difference with usual GLWE packing is that +/// the output is bit-reversed. +pub struct GLWEPacker { + accumulators: Vec, + log_batch: usize, + counter: usize, +} + +/// [Accumulator] stores intermediate packing result. +/// There are Log(N) such accumulators in a [GLWEPacker]. +struct Accumulator { + data: GLWE>, + value: bool, // Implicit flag for zero ciphertext + control: bool, // Can be combined with incoming value +} + +impl Accumulator { + /// Allocates a new [Accumulator]. + /// + /// #Arguments + /// + /// * `module`: static backend FFT tables. + /// * `base2k`: base 2 logarithm of the GLWE ciphertext in memory digit representation. + /// * `k`: base 2 precision of the GLWE ciphertext precision over the Torus. + /// * `rank`: rank of the GLWE ciphertext. + pub fn alloc(infos: &A) -> Self + where + A: GLWEInfos, + { + Self { + data: GLWE::alloc_from_infos(infos), + value: false, + control: false, + } + } +} + +impl GLWEPacker { + /// Instantiates a new [GLWEPacker]. + /// + /// # Arguments + /// + /// * `log_batch`: packs coefficients which are multiples of X^{N/2^log_batch}. + /// i.e. with `log_batch=0` only the constant coefficient is packed + /// and N GLWE ciphertext can be packed. With `log_batch=2` all coefficients + /// which are multiples of X^{N/4} are packed. Meaning that N/4 ciphertexts + /// can be packed. + pub fn alloc(infos: &A, log_batch: usize) -> Self + where + A: GLWEInfos, + { + let mut accumulators: Vec = Vec::::new(); + let log_n: usize = infos.n().log2(); + (0..log_n - log_batch).for_each(|_| accumulators.push(Accumulator::alloc(infos))); + GLWEPacker { + accumulators, + log_batch, + counter: 0, + } + } + + /// Implicit reset of the internal state (to be called before a new packing procedure). + fn reset(&mut self) { + for i in 0..self.accumulators.len() { + self.accumulators[i].value = false; + self.accumulators[i].control = false; + } + self.counter = 0; + } + + /// Number of scratch space bytes required to call [Self::add]. + pub fn tmp_bytes(module: &M, res_infos: &R, key_infos: &K) -> usize + where + R: GLWEInfos, + K: GGLWEInfos, + M: GLWEPackerOps, + { + GLWE::bytes_of_from_infos(res_infos) + + module + .glwe_rsh_tmp_byte() + .max(module.glwe_automorphism_tmp_bytes(res_infos, res_infos, key_infos)) + } + + pub fn galois_elements(module: &M) -> Vec + where + M: GLWETrace, + { + module.glwe_trace_galois_elements() + } + + /// Adds a GLWE ciphertext to the [GLWEPacker]. + /// #Arguments + /// + /// * `module`: static backend FFT tables. + /// * `res`: space to append fully packed ciphertext. Only when the number + /// of packed ciphertexts reaches N/2^log_batch is a result written. + /// * `a`: ciphertext to pack. Can optionally give None to pack a 0 ciphertext. + /// * `auto_keys`: a [HashMap] containing the [AutomorphismKeyExec]s. + /// * `scratch`: scratch space of size at least [Self::tmp_bytes]. + pub fn add(&mut self, module: &M, a: Option<&A>, auto_keys: &HashMap, scratch: &mut Scratch) + where + A: GLWEToRef + GLWEInfos, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + M: GLWEPackerOps, + Scratch: ScratchTakeCore, + { + assert!( + (self.counter as u32) < self.accumulators[0].data.n(), + "Packing limit of {} reached", + self.accumulators[0].data.n().0 as usize >> self.log_batch + ); + + module.packer_add(self, a, self.log_batch, auto_keys, scratch); + self.counter += 1 << self.log_batch; + } + + /// Flush result to`res`. + pub fn flush(&mut self, module: &M, res: &mut R) + where + R: GLWEToMut, + M: GLWEPackerOps, + { + assert!(self.counter as u32 == self.accumulators[0].data.n()); + // Copy result GLWE into res GLWE + module.glwe_copy( + res, + &self.accumulators[module.log_n() - self.log_batch - 1].data, + ); + + self.reset(); + } +} + +impl GLWEPackerOps for Module where + Self: Sized + + ModuleLogN + + GLWEAutomorphism + + GaloisElement + + GLWERotate + + GLWESub + + GLWEShift + + GLWEAdd + + GLWENormalize + + GLWECopy + + GLWEAutomorphism + + GaloisElement + + GLWERotate + + GLWESub + + GLWEShift + + GLWEAdd + + GLWENormalize +{ +} + +pub trait GLWEPackerOps +where + Self: Sized + + ModuleLogN + + GLWEAutomorphism + + GaloisElement + + GLWERotate + + GLWESub + + GLWEShift + + GLWEAdd + + GLWENormalize + + GLWECopy + + GLWEAutomorphism + + GaloisElement + + GLWERotate + + GLWESub + + GLWEShift + + GLWEAdd + + GLWENormalize, +{ + fn packer_add( + &self, + packer: &mut GLWEPacker, + a: Option<&A>, + i: usize, + auto_keys: &HashMap, + scratch: &mut Scratch, + ) where + A: GLWEToRef + GLWEInfos, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + Scratch: ScratchTakeCore, + { + pack_core(self, a, &mut packer.accumulators, i, auto_keys, scratch) + } +} + +fn pack_core( + module: &M, + a: Option<&A>, + accumulators: &mut [Accumulator], + i: usize, + auto_keys: &HashMap, + scratch: &mut Scratch, +) where + A: GLWEToRef + GLWEInfos, + M: ModuleLogN + + GLWEAutomorphism + + GaloisElement + + GLWERotate + + GLWESub + + GLWEShift + + GLWEAdd + + GLWENormalize + + GLWECopy + + GLWEAutomorphism + + GaloisElement + + GLWERotate + + GLWESub + + GLWEShift + + GLWEAdd + + GLWENormalize, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + Scratch: ScratchTakeCore, +{ + let log_n: usize = module.log_n(); + + if i == log_n { + return; + } + + // Isolate the first accumulator + let (acc_prev, acc_next) = accumulators.split_at_mut(1); + + // Control = true accumlator is free to overide + if !acc_prev[0].control { + let acc_mut_ref: &mut Accumulator = &mut acc_prev[0]; // from split_at_mut + + // No previous value -> copies and sets flags accordingly + if let Some(a_ref) = a { + module.glwe_copy(&mut acc_mut_ref.data, a_ref); + acc_mut_ref.value = true + } else { + acc_mut_ref.value = false + } + acc_mut_ref.control = true; // Able to be combined on next call + } else { + // Compresses acc_prev <- combine(acc_prev, a). + combine(module, &mut acc_prev[0], a, i, auto_keys, scratch); + acc_prev[0].control = false; + + // Propagates to next accumulator + if acc_prev[0].value { + pack_core( + module, + Some(&acc_prev[0].data), + acc_next, + i + 1, + auto_keys, + scratch, + ); + } else { + pack_core( + module, + None::<&GLWE>>, + acc_next, + i + 1, + auto_keys, + scratch, + ); + } + } +} + +fn combine( + module: &M, + acc: &mut Accumulator, + b: Option<&B>, + i: usize, + auto_keys: &HashMap, + scratch: &mut Scratch, +) where + B: GLWEToRef + GLWEInfos, + B: GLWEToRef + GLWEInfos, + M: ModuleLogN + + GLWEAutomorphism + + GaloisElement + + GLWERotate + + GLWESub + + GLWEShift + + GLWEAdd + + GLWENormalize + + GLWECopy + + GLWEAutomorphism + + GaloisElement + + GLWERotate + + GLWESub + + GLWEShift + + GLWEAdd + + GLWENormalize, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + Scratch: ScratchTakeCore, +{ + let log_n: usize = acc.data.n().log2(); + let a: &mut GLWE> = &mut acc.data; + + let gal_el: i64 = if i == 0 { + -1 + } else { + module.galois_element(1 << (i - 1)) + }; + + let t: i64 = 1 << (log_n - i - 1); + + // Goal is to evaluate: a = a + b*X^t + phi(a - b*X^t)) + // We also use the identity: AUTO(a * X^t, g) = -X^t * AUTO(a, g) + // where t = 2^(log_n - i - 1) and g = 5^{2^(i - 1)} + // Different cases for wether a and/or b are zero. + // + // Implicite RSH without modulus switch, introduces extra I(X) * Q/2 on decryption. + // Necessary so that the scaling of the plaintext remains constant. + // It however is ok to do so here because coefficients are eventually + // either mapped to garbage or twice their value which vanishes I(X) + // since 2*(I(X) * Q/2) = I(X) * Q = 0 mod Q. + if acc.value { + if let Some(b) = b { + let (mut tmp_b, scratch_1) = scratch.take_glwe(a); + + // a = a * X^-t + module.glwe_rotate_inplace(-t, a, scratch_1); + + // tmp_b = a * X^-t - b + module.glwe_sub(&mut tmp_b, a, b); + module.glwe_rsh(1, &mut tmp_b, scratch_1); + + // a = a * X^-t + b + module.glwe_add_inplace(a, b); + module.glwe_rsh(1, a, scratch_1); + + module.glwe_normalize_inplace(&mut tmp_b, scratch_1); + + // tmp_b = phi(a * X^-t - b) + if let Some(auto_key) = auto_keys.get(&gal_el) { + module.glwe_automorphism_inplace(&mut tmp_b, auto_key, scratch_1); + } else { + panic!("auto_key[{gal_el}] not found"); + } + + // a = a * X^-t + b - phi(a * X^-t - b) + module.glwe_sub_inplace(a, &tmp_b); + module.glwe_normalize_inplace(a, scratch_1); + + // a = a + b * X^t - phi(a * X^-t - b) * X^t + // = a + b * X^t - phi(a * X^-t - b) * - phi(X^t) + // = a + b * X^t + phi(a - b * X^t) + module.glwe_rotate_inplace(t, a, scratch_1); + } else { + module.glwe_rsh(1, a, scratch); + // a = a + phi(a) + if let Some(auto_key) = auto_keys.get(&gal_el) { + module.glwe_automorphism_add_inplace(a, auto_key, scratch); + } else { + panic!("auto_key[{gal_el}] not found"); + } + } + } else if let Some(b) = b { + let (mut tmp_b, scratch_1) = scratch.take_glwe(a); + module.glwe_rotate(t, &mut tmp_b, b); + module.glwe_rsh(1, &mut tmp_b, scratch_1); + + // a = (b* X^t - phi(b* X^t)) + if let Some(auto_key) = auto_keys.get(&gal_el) { + module.glwe_automorphism_sub_negate(a, &tmp_b, auto_key, scratch_1); + } else { + panic!("auto_key[{gal_el}] not found"); + } + + acc.value = true; + } +} diff --git a/poulpy-core/src/glwe_packing.rs b/poulpy-core/src/glwe_packing.rs index 09540b2..6debd0d 100644 --- a/poulpy-core/src/glwe_packing.rs +++ b/poulpy-core/src/glwe_packing.rs @@ -7,166 +7,23 @@ use poulpy_hal::{ use crate::{ GLWEAdd, GLWEAutomorphism, GLWECopy, GLWENormalize, GLWERotate, GLWEShift, GLWESub, ScratchTakeCore, - glwe_trace::GLWETrace, - layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos}, + layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWEInfos, GLWEToMut, GLWEToRef, GetGaloisElement}, }; - -/// [GLWEPacker] enables only the fly GLWE packing -/// with constant memory of Log(N) ciphertexts. -/// Main difference with usual GLWE packing is that -/// the output is bit-reversed. -pub struct GLWEPacker { - accumulators: Vec, - log_batch: usize, - counter: usize, +pub trait GLWEPacking { + /// Packs [x_0: GLWE(m_0), x_1: GLWE(m_1), ..., x_i: GLWE(m_i)] + /// to [0: GLWE(m_0 * X^x_0 + m_1 * X^x_1 + ... + m_i * X^x_i)] + fn glwe_pack( + &self, + cts: &mut HashMap, + log_gap_out: usize, + keys: &HashMap, + scratch: &mut Scratch, + ) where + R: GLWEToMut + GLWEToRef + GLWEInfos, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos; } -/// [Accumulator] stores intermediate packing result. -/// There are Log(N) such accumulators in a [GLWEPacker]. -struct Accumulator { - data: GLWE>, - value: bool, // Implicit flag for zero ciphertext - control: bool, // Can be combined with incoming value -} - -impl Accumulator { - /// Allocates a new [Accumulator]. - /// - /// #Arguments - /// - /// * `module`: static backend FFT tables. - /// * `base2k`: base 2 logarithm of the GLWE ciphertext in memory digit representation. - /// * `k`: base 2 precision of the GLWE ciphertext precision over the Torus. - /// * `rank`: rank of the GLWE ciphertext. - pub fn alloc(infos: &A) -> Self - where - A: GLWEInfos, - { - Self { - data: GLWE::alloc_from_infos(infos), - value: false, - control: false, - } - } -} - -impl GLWEPacker { - /// Instantiates a new [GLWEPacker]. - /// - /// # Arguments - /// - /// * `log_batch`: packs coefficients which are multiples of X^{N/2^log_batch}. - /// i.e. with `log_batch=0` only the constant coefficient is packed - /// and N GLWE ciphertext can be packed. With `log_batch=2` all coefficients - /// which are multiples of X^{N/4} are packed. Meaning that N/4 ciphertexts - /// can be packed. - pub fn alloc(infos: &A, log_batch: usize) -> Self - where - A: GLWEInfos, - { - let mut accumulators: Vec = Vec::::new(); - let log_n: usize = infos.n().log2(); - (0..log_n - log_batch).for_each(|_| accumulators.push(Accumulator::alloc(infos))); - GLWEPacker { - accumulators, - log_batch, - counter: 0, - } - } - - /// Implicit reset of the internal state (to be called before a new packing procedure). - fn reset(&mut self) { - for i in 0..self.accumulators.len() { - self.accumulators[i].value = false; - self.accumulators[i].control = false; - } - self.counter = 0; - } - - /// Number of scratch space bytes required to call [Self::add]. - pub fn tmp_bytes(module: &M, res_infos: &R, key_infos: &K) -> usize - where - R: GLWEInfos, - K: GGLWEInfos, - M: GLWEPacking, - { - GLWE::bytes_of_from_infos(res_infos) - + module - .glwe_rsh_tmp_byte() - .max(module.glwe_automorphism_tmp_bytes(res_infos, res_infos, key_infos)) - } - - pub fn galois_elements(module: &M) -> Vec - where - M: GLWETrace, - { - module.glwe_trace_galois_elements() - } - - /// Adds a GLWE ciphertext to the [GLWEPacker]. - /// #Arguments - /// - /// * `module`: static backend FFT tables. - /// * `res`: space to append fully packed ciphertext. Only when the number - /// of packed ciphertexts reaches N/2^log_batch is a result written. - /// * `a`: ciphertext to pack. Can optionally give None to pack a 0 ciphertext. - /// * `auto_keys`: a [HashMap] containing the [AutomorphismKeyExec]s. - /// * `scratch`: scratch space of size at least [Self::tmp_bytes]. - pub fn add(&mut self, module: &M, a: Option<&A>, auto_keys: &HashMap, scratch: &mut Scratch) - where - A: GLWEToRef + GLWEInfos, - K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, - M: GLWEPacking, - Scratch: ScratchTakeCore, - { - assert!( - (self.counter as u32) < self.accumulators[0].data.n(), - "Packing limit of {} reached", - self.accumulators[0].data.n().0 as usize >> self.log_batch - ); - - pack_core( - module, - a, - &mut self.accumulators, - self.log_batch, - auto_keys, - scratch, - ); - self.counter += 1 << self.log_batch; - } - - /// Flush result to`res`. - pub fn flush(&mut self, module: &M, res: &mut R) - where - R: GLWEToMut, - M: GLWEPacking, - { - assert!(self.counter as u32 == self.accumulators[0].data.n()); - // Copy result GLWE into res GLWE - module.glwe_copy( - res, - &self.accumulators[module.log_n() - self.log_batch - 1].data, - ); - - self.reset(); - } -} - -impl GLWEPacking for Module where - Self: GLWEAutomorphism - + GaloisElement - + ModuleLogN - + GLWERotate - + GLWESub - + GLWEShift - + GLWEAdd - + GLWENormalize - + GLWECopy -{ -} - -pub trait GLWEPacking +impl GLWEPacking for Module where Self: GLWEAutomorphism + GaloisElement @@ -177,6 +34,7 @@ where + GLWEAdd + GLWENormalize + GLWECopy, + Scratch: ScratchTakeCore, { /// Packs [x_0: GLWE(m_0), x_1: GLWE(m_1), ..., x_i: GLWE(m_i)] /// to [0: GLWE(m_0 * X^x_0 + m_1 * X^x_1 + ... + m_i * X^x_i)] @@ -189,7 +47,6 @@ where ) where R: GLWEToMut + GLWEToRef + GLWEInfos, K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, - Scratch: ScratchTakeCore, { #[cfg(debug_assertions)] { @@ -223,169 +80,6 @@ where } } -fn pack_core( - module: &M, - a: Option<&A>, - accumulators: &mut [Accumulator], - i: usize, - auto_keys: &HashMap, - scratch: &mut Scratch, -) where - A: GLWEToRef + GLWEInfos, - K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, - M: ModuleLogN - + GLWEAutomorphism - + GaloisElement - + GLWERotate - + GLWESub - + GLWEShift - + GLWEAdd - + GLWENormalize - + GLWECopy, - Scratch: ScratchTakeCore, -{ - let log_n: usize = module.log_n(); - - if i == log_n { - return; - } - - // Isolate the first accumulator - let (acc_prev, acc_next) = accumulators.split_at_mut(1); - - // Control = true accumlator is free to overide - if !acc_prev[0].control { - let acc_mut_ref: &mut Accumulator = &mut acc_prev[0]; // from split_at_mut - - // No previous value -> copies and sets flags accordingly - if let Some(a_ref) = a { - module.glwe_copy(&mut acc_mut_ref.data, a_ref); - acc_mut_ref.value = true - } else { - acc_mut_ref.value = false - } - acc_mut_ref.control = true; // Able to be combined on next call - } else { - // Compresses acc_prev <- combine(acc_prev, a). - combine(module, &mut acc_prev[0], a, i, auto_keys, scratch); - acc_prev[0].control = false; - - // Propagates to next accumulator - if acc_prev[0].value { - pack_core( - module, - Some(&acc_prev[0].data), - acc_next, - i + 1, - auto_keys, - scratch, - ); - } else { - pack_core( - module, - None::<&GLWE>>, - acc_next, - i + 1, - auto_keys, - scratch, - ); - } - } -} - -/// [combine] merges two ciphertexts together. -fn combine( - module: &M, - acc: &mut Accumulator, - b: Option<&B>, - i: usize, - auto_keys: &HashMap, - scratch: &mut Scratch, -) where - B: GLWEToRef + GLWEInfos, - M: GLWEAutomorphism + GaloisElement + GLWERotate + GLWESub + GLWEShift + GLWEAdd + GLWENormalize, - B: GLWEToRef + GLWEInfos, - K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, - Scratch: ScratchTakeCore, -{ - let log_n: usize = acc.data.n().log2(); - let a: &mut GLWE> = &mut acc.data; - - let gal_el: i64 = if i == 0 { - -1 - } else { - module.galois_element(1 << (i - 1)) - }; - - let t: i64 = 1 << (log_n - i - 1); - - // Goal is to evaluate: a = a + b*X^t + phi(a - b*X^t)) - // We also use the identity: AUTO(a * X^t, g) = -X^t * AUTO(a, g) - // where t = 2^(log_n - i - 1) and g = 5^{2^(i - 1)} - // Different cases for wether a and/or b are zero. - // - // Implicite RSH without modulus switch, introduces extra I(X) * Q/2 on decryption. - // Necessary so that the scaling of the plaintext remains constant. - // It however is ok to do so here because coefficients are eventually - // either mapped to garbage or twice their value which vanishes I(X) - // since 2*(I(X) * Q/2) = I(X) * Q = 0 mod Q. - if acc.value { - if let Some(b) = b { - let (mut tmp_b, scratch_1) = scratch.take_glwe(a); - - // a = a * X^-t - module.glwe_rotate_inplace(-t, a, scratch_1); - - // tmp_b = a * X^-t - b - module.glwe_sub(&mut tmp_b, a, b); - module.glwe_rsh(1, &mut tmp_b, scratch_1); - - // a = a * X^-t + b - module.glwe_add_inplace(a, b); - module.glwe_rsh(1, a, scratch_1); - - module.glwe_normalize_inplace(&mut tmp_b, scratch_1); - - // tmp_b = phi(a * X^-t - b) - if let Some(auto_key) = auto_keys.get(&gal_el) { - module.glwe_automorphism_inplace(&mut tmp_b, auto_key, scratch_1); - } else { - panic!("auto_key[{gal_el}] not found"); - } - - // a = a * X^-t + b - phi(a * X^-t - b) - module.glwe_sub_inplace(a, &tmp_b); - module.glwe_normalize_inplace(a, scratch_1); - - // a = a + b * X^t - phi(a * X^-t - b) * X^t - // = a + b * X^t - phi(a * X^-t - b) * - phi(X^t) - // = a + b * X^t + phi(a - b * X^t) - module.glwe_rotate_inplace(t, a, scratch_1); - } else { - module.glwe_rsh(1, a, scratch); - // a = a + phi(a) - if let Some(auto_key) = auto_keys.get(&gal_el) { - module.glwe_automorphism_add_inplace(a, auto_key, scratch); - } else { - panic!("auto_key[{gal_el}] not found"); - } - } - } else if let Some(b) = b { - let (mut tmp_b, scratch_1) = scratch.take_glwe(a); - module.glwe_rotate(t, &mut tmp_b, b); - module.glwe_rsh(1, &mut tmp_b, scratch_1); - - // a = (b* X^t - phi(b* X^t)) - if let Some(auto_key) = auto_keys.get(&gal_el) { - module.glwe_automorphism_sub_negate(a, &tmp_b, auto_key, scratch_1); - } else { - panic!("auto_key[{gal_el}] not found"); - } - - acc.value = true; - } -} - #[allow(clippy::too_many_arguments)] fn pack_internal( module: &M, diff --git a/poulpy-core/src/glwe_trace.rs b/poulpy-core/src/glwe_trace.rs index c2ba15c..0ba7b81 100644 --- a/poulpy-core/src/glwe_trace.rs +++ b/poulpy-core/src/glwe_trace.rs @@ -1,8 +1,8 @@ use std::collections::HashMap; use poulpy_hal::{ - api::ModuleLogN, - layouts::{Backend, DataMut, GaloisElement, Module, Scratch, VecZnx, galois_element}, + api::{ModuleLogN, VecZnxNormalize, VecZnxNormalizeTmpBytes}, + layouts::{Backend, CyclotomicOrder, DataMut, GaloisElement, Module, Scratch, VecZnx, galois_element}, }; use crate::{ @@ -27,7 +27,7 @@ impl GLWE> { K: GGLWEInfos, M: GLWETrace, { - module.glwe_automorphism_tmp_bytes(res_infos, a_infos, key_infos) + module.glwe_trace_tmp_bytes(res_infos, a_infos, key_infos) } } @@ -65,11 +65,6 @@ impl GLWE { } } -impl GLWETrace for Module where - Self: ModuleLogN + GaloisElement + GLWEAutomorphism + GLWEShift + GLWECopy -{ -} - #[inline(always)] pub fn trace_galois_elements(log_n: usize, cyclotomic_order: i64) -> Vec { (0..log_n) @@ -83,9 +78,17 @@ pub fn trace_galois_elements(log_n: usize, cyclotomic_order: i64) -> Vec { .collect() } -pub trait GLWETrace +impl GLWETrace for Module where - Self: ModuleLogN + GaloisElement + GLWEAutomorphism + GLWEShift + GLWECopy, + Self: ModuleLogN + + GaloisElement + + GLWEAutomorphism + + GLWEShift + + GLWECopy + + CyclotomicOrder + + VecZnxNormalizeTmpBytes + + VecZnxNormalize, + Scratch: ScratchTakeCore, { fn glwe_trace_galois_elements(&self) -> Vec { trace_galois_elements(self.log_n(), self.cyclotomic_order()) @@ -115,7 +118,6 @@ where R: GLWEToMut, A: GLWEToRef, K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, - Scratch: ScratchTakeCore, { self.glwe_copy(res, a); self.glwe_trace_inplace(res, start, end, keys, scratch); @@ -125,7 +127,6 @@ where where R: GLWEToMut, K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, - Scratch: ScratchTakeCore, { let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); @@ -212,3 +213,31 @@ where } } } + +pub trait GLWETrace { + fn glwe_trace_galois_elements(&self) -> Vec; + + fn glwe_trace_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize + where + R: GLWEInfos, + A: GLWEInfos, + K: GGLWEInfos; + + fn glwe_trace( + &self, + res: &mut R, + start: usize, + end: usize, + a: &A, + keys: &HashMap, + scratch: &mut Scratch, + ) where + R: GLWEToMut, + A: GLWEToRef, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos; + + fn glwe_trace_inplace(&self, res: &mut R, start: usize, end: usize, keys: &HashMap, scratch: &mut Scratch) + where + R: GLWEToMut, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos; +} diff --git a/poulpy-core/src/keyswitching/ggsw.rs b/poulpy-core/src/keyswitching/ggsw.rs index 231b071..3dfb0b1 100644 --- a/poulpy-core/src/keyswitching/ggsw.rs +++ b/poulpy-core/src/keyswitching/ggsw.rs @@ -1,9 +1,9 @@ -use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch, VecZnx}; +use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch}; use crate::{ GGSWExpandRows, ScratchTakeCore, keyswitching::GLWEKeyswitch, - layouts::{GGLWEInfos, GGLWEPreparedToRef, GGSW, GGSWInfos, GGSWToMut, GGSWToRef, prepared::GLWETensorKeyPreparedToRef}, + layouts::{GGLWEInfos, GGLWEPreparedToRef, GGLWEToGGSWKeyPreparedToRef, GGSW, GGSWInfos, GGSWToMut, GGSWToRef}, }; impl GGSW> { @@ -30,7 +30,7 @@ impl GGSW { where A: GGSWToRef, K: GGLWEPreparedToRef, - T: GLWETensorKeyPreparedToRef, + T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore, M: GGSWKeyswitch, { @@ -40,7 +40,7 @@ impl GGSW { pub fn keyswitch_inplace(&mut self, module: &M, key: &K, tsk: &T, scratch: &mut Scratch) where K: GGLWEPreparedToRef, - T: GLWETensorKeyPreparedToRef, + T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore, M: GGSWKeyswitch, { @@ -48,9 +48,7 @@ impl GGSW { } } -impl GGSWKeyswitch for Module where Self: GLWEKeyswitch + GGSWExpandRows {} - -pub trait GGSWKeyswitch +impl GGSWKeyswitch for Module where Self: GLWEKeyswitch + GGSWExpandRows, { @@ -65,25 +63,26 @@ where assert_eq!(tsk_infos.rank_in(), tsk_infos.rank_out()); assert_eq!(key_infos.rank_in(), tsk_infos.rank_in()); - let rank: usize = key_infos.rank_out().into(); + self.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) + .max(self.ggsw_expand_rows_tmp_bytes(res_infos, tsk_infos)) + } - let size_out: usize = res_infos.k().div_ceil(res_infos.base2k()) as usize; - let res_znx: usize = VecZnx::bytes_of(self.n(), rank + 1, size_out); - let ci_dft: usize = self.bytes_of_vec_znx_dft(rank + 1, size_out); - let ks: usize = self.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos); - let expand_rows: usize = self.ggsw_expand_rows_tmp_bytes(res_infos, tsk_infos); - let res_dft: usize = self.bytes_of_vec_znx_dft(rank + 1, size_out); + fn ggsw_keyswitch_inplace(&self, res: &mut R, key: &K, tsk: &T, scratch: &mut Scratch) + where + R: GGSWToMut, + K: GGLWEPreparedToRef, + T: GGLWEToGGSWKeyPreparedToRef, + Scratch: ScratchTakeCore, + { + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); - if a_infos.base2k() == tsk_infos.base2k() { - res_znx + ci_dft + (ks | expand_rows | res_dft) - } else { - let a_conv: usize = VecZnx::bytes_of( - self.n(), - 1, - res_infos.k().div_ceil(tsk_infos.base2k()) as usize, - ) + self.vec_znx_normalize_tmp_bytes(); - res_znx + ci_dft + (a_conv | ks | expand_rows | res_dft) + for row in 0..res.dnum().into() { + // Key-switch column 0, i.e. + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) + self.glwe_keyswitch_inplace(&mut res.at_mut(row, 0), key, scratch); } + + self.ggsw_expand_row(res, tsk, scratch); } fn ggsw_keyswitch(&self, res: &mut R, a: &A, key: &K, tsk: &T, scratch: &mut Scratch) @@ -91,7 +90,7 @@ where R: GGSWToMut, A: GGSWToRef, K: GGLWEPreparedToRef, - T: GLWETensorKeyPreparedToRef, + T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore, { let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); @@ -108,22 +107,31 @@ where self.ggsw_expand_row(res, tsk, scratch); } +} + +pub trait GGSWKeyswitch +where + Self: GLWEKeyswitch + GGSWExpandRows, +{ + fn ggsw_keyswitch_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K, tsk_infos: &T) -> usize + where + R: GGSWInfos, + A: GGSWInfos, + K: GGLWEInfos, + T: GGLWEInfos; + + fn ggsw_keyswitch(&self, res: &mut R, a: &A, key: &K, tsk: &T, scratch: &mut Scratch) + where + R: GGSWToMut, + A: GGSWToRef, + K: GGLWEPreparedToRef, + T: GGLWEToGGSWKeyPreparedToRef, + Scratch: ScratchTakeCore; fn ggsw_keyswitch_inplace(&self, res: &mut R, key: &K, tsk: &T, scratch: &mut Scratch) where R: GGSWToMut, K: GGLWEPreparedToRef, - T: GLWETensorKeyPreparedToRef, - Scratch: ScratchTakeCore, - { - let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); - - for row in 0..res.dnum().into() { - // Key-switch column 0, i.e. - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) - self.glwe_keyswitch_inplace(&mut res.at_mut(row, 0), key, scratch); - } - - self.ggsw_expand_row(res, tsk, scratch); - } + T: GGLWEToGGSWKeyPreparedToRef, + Scratch: ScratchTakeCore; } diff --git a/poulpy-core/src/keyswitching/glwe.rs b/poulpy-core/src/keyswitching/glwe.rs index a021777..72def40 100644 --- a/poulpy-core/src/keyswitching/glwe.rs +++ b/poulpy-core/src/keyswitching/glwe.rs @@ -1,10 +1,10 @@ use poulpy_hal::{ api::{ ModuleN, ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, - VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, + VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, - layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, ZnxInfos}, + layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VecZnxDftToRef, VmpPMat, ZnxInfos}, }; use crate::{ @@ -45,46 +45,10 @@ impl GLWE { } } -impl GLWEKeyswitch for Module where - Self: Sized - + ModuleN - + VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VecZnxNormalizeTmpBytes - + VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalize - + VecZnxNormalizeTmpBytes -{ -} - -pub trait GLWEKeyswitch +impl GLWEKeyswitch for Module where - Self: Sized - + ModuleN - + VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VecZnxNormalizeTmpBytes - + VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, + Self: Sized + GLWEKeySwitchInternal + VecZnxBigNormalizeTmpBytes + VecZnxBigNormalize, + Scratch: ScratchTakeCore, { fn glwe_keyswitch_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &B) -> usize where @@ -92,34 +56,10 @@ where A: GLWEInfos, B: GGLWEInfos, { - let in_size: usize = a_infos - .k() - .div_ceil(key_infos.base2k()) - .div_ceil(key_infos.dsize().into()) as usize; - let out_size: usize = res_infos.size(); - let ksk_size: usize = key_infos.size(); - let res_dft: usize = self.bytes_of_vec_znx_dft((key_infos.rank_out() + 1).into(), ksk_size); // TODO OPTIMIZE - let ai_dft: usize = self.bytes_of_vec_znx_dft((key_infos.rank_in()).into(), in_size); - let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes( - out_size, - in_size, - in_size, - (key_infos.rank_in()).into(), - (key_infos.rank_out() + 1).into(), - ksk_size, - ) + self.bytes_of_vec_znx_dft((key_infos.rank_in()).into(), in_size); - let normalize_big: usize = self.vec_znx_big_normalize_tmp_bytes(); - if a_infos.base2k() == key_infos.base2k() { - res_dft + ((ai_dft + vmp) | normalize_big) - } else if key_infos.dsize() == 1 { - // In this case, we only need one column, temporary, that we can drop once a_dft is computed. - let normalize_conv: usize = VecZnx::bytes_of(self.n(), 1, in_size) + self.vec_znx_normalize_tmp_bytes(); - res_dft + (((ai_dft + normalize_conv) | vmp) | normalize_big) - } else { - // Since we stride over a to get a_dft when dsize > 1, we need to store the full columns of a with in the base conversion. - let normalize_conv: usize = VecZnx::bytes_of(self.n(), (key_infos.rank_in()).into(), in_size); - res_dft + ((ai_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big) - } + let cols: usize = res_infos.rank().as_usize() + 1; + self.glwe_keyswitch_internal_tmp_bytes(res_infos, a_infos, key_infos) + .max(self.vec_znx_big_normalize_tmp_bytes()) + + self.bytes_of_vec_znx_dft(cols, key_infos.size()) } fn glwe_keyswitch(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) @@ -127,7 +67,6 @@ where R: GLWEToMut, A: GLWEToRef, K: GGLWEPreparedToRef, - Scratch: ScratchTakeCore, { let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let a: &GLWE<&[u8]> = &a.to_ref(); @@ -164,8 +103,8 @@ where let base2k_out: usize = b.base2k().into(); let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), b.size()); // Todo optimise - let res_big: VecZnxBig<&mut [u8], BE> = keyswitch_internal(self, res_dft, a, b, scratch_1); - (0..(res.rank() + 1).into()).for_each(|i| { + let res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, a, b, scratch_1); + for i in 0..(res.rank() + 1).into() { self.vec_znx_big_normalize( basek_out, &mut res.data, @@ -175,37 +114,36 @@ where i, scratch_1, ); - }) + } } fn glwe_keyswitch_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) where R: GLWEToMut, K: GGLWEPreparedToRef, - Scratch: ScratchTakeCore, { let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); - let a: &GGLWEPrepared<&[u8], BE> = &key.to_ref(); + let key: &GGLWEPrepared<&[u8], BE> = &key.to_ref(); assert_eq!( res.rank(), - a.rank_in(), + key.rank_in(), "res.rank(): {} != a.rank_in(): {}", res.rank(), - a.rank_in() + key.rank_in() ); assert_eq!( res.rank(), - a.rank_out(), + key.rank_out(), "res.rank(): {} != b.rank_out(): {}", res.rank(), - a.rank_out() + key.rank_out() ); assert_eq!(res.n(), self.n() as u32); - assert_eq!(a.n(), self.n() as u32); + assert_eq!(key.n(), self.n() as u32); - let scrach_needed: usize = self.glwe_keyswitch_tmp_bytes(res, res, a); + let scrach_needed: usize = self.glwe_keyswitch_tmp_bytes(res, res, key); assert!( scratch.available() >= scrach_needed, @@ -214,11 +152,11 @@ where ); let base2k_in: usize = res.base2k().into(); - let base2k_out: usize = a.base2k().into(); + let base2k_out: usize = key.base2k().into(); - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), a.size()); // Todo optimise - let res_big: VecZnxBig<&mut [u8], BE> = keyswitch_internal(self, res_dft, res, a, scratch_1); - (0..(res.rank() + 1).into()).for_each(|i| { + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // Todo optimise + let res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1); + for i in 0..(res.rank() + 1).into() { self.vec_znx_big_normalize( base2k_in, &mut res.data, @@ -228,143 +166,235 @@ where i, scratch_1, ); - }) + } } } -impl GLWE> {} +pub trait GLWEKeyswitch { + fn glwe_keyswitch_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &B) -> usize + where + R: GLWEInfos, + A: GLWEInfos, + B: GGLWEInfos; -impl GLWE {} + fn glwe_keyswitch(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + A: GLWEToRef, + K: GGLWEPreparedToRef; -pub(crate) fn keyswitch_internal( - module: &M, - mut res: VecZnxDft, - a: &A, - key: &K, - scratch: &mut Scratch, -) -> VecZnxBig -where - DR: DataMut, - A: GLWEToRef, - K: GGLWEPreparedToRef, - M: ModuleN - + VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd + fn glwe_keyswitch_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + K: GGLWEPreparedToRef; +} + +impl GLWEKeySwitchInternal for Module where + Self: GGLWEProduct + VecZnxDftApply + + VecZnxNormalize + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalize, - Scratch: ScratchTakeCore, + + VecZnxNormalizeTmpBytes { - let a: &GLWE<&[u8]> = &a.to_ref(); - let key: &GGLWEPrepared<&[u8], BE> = &key.to_ref(); +} - let base2k_in: usize = a.base2k().into(); - let base2k_out: usize = key.base2k().into(); - let cols: usize = (a.rank() + 1).into(); - let a_size: usize = (a.size() * base2k_in).div_ceil(base2k_out); - let pmat: &VmpPMat<&[u8], BE> = &key.data; +pub(crate) trait GLWEKeySwitchInternal +where + Self: GGLWEProduct + + VecZnxDftApply + + VecZnxNormalize + + VecZnxIdftApplyConsume + + VecZnxBigAddSmallInplace + + VecZnxNormalizeTmpBytes, +{ + fn glwe_keyswitch_internal_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize + where + R: GLWEInfos, + A: GLWEInfos, + K: GGLWEInfos, + { + let cols: usize = (a_infos.rank() + 1).into(); + let a_size: usize = a_infos.size(); - if key.dsize() == 1 { - let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols - 1, a.size()); + let a_conv = if a_infos.base2k() == key_infos.base2k() { + 0 + } else { + VecZnx::bytes_of(self.n(), 1, a_size) + self.vec_znx_normalize_tmp_bytes() + }; + + self.gglwe_product_dft_tmp_bytes(res_infos.size(), a_size, key_infos) + self.bytes_of_vec_znx_dft(cols, a_size) + a_conv + } + + fn glwe_keyswitch_internal( + &self, + mut res: VecZnxDft, + a: &A, + key: &K, + scratch: &mut Scratch, + ) -> VecZnxBig + where + DR: DataMut, + A: GLWEToRef, + K: GGLWEPreparedToRef, + Scratch: ScratchTakeCore, + { + let a: &GLWE<&[u8]> = &a.to_ref(); + let key: &GGLWEPrepared<&[u8], BE> = &key.to_ref(); + + let base2k_in: usize = a.base2k().into(); + let base2k_out: usize = key.base2k().into(); + let cols: usize = (a.rank() + 1).into(); + let a_size: usize = (a.size() * base2k_in).div_ceil(base2k_out); + + let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols - 1, a_size); if base2k_in == base2k_out { - (0..cols - 1).for_each(|col_i| { - module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, a.data(), col_i + 1); - }); + for col_i in 0..cols - 1 { + self.vec_znx_dft_apply(1, 0, &mut a_dft, col_i, a.data(), col_i + 1); + } } else { - let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(module.n(), 1, a_size); - (0..cols - 1).for_each(|col_i| { - module.vec_znx_normalize( + let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, a_size); + for i in 0..cols - 1 { + self.vec_znx_normalize( base2k_out, &mut a_conv, 0, base2k_in, a.data(), - col_i + 1, + i + 1, scratch_2, ); - module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, &a_conv, 0); - }); + self.vec_znx_dft_apply(1, 0, &mut a_dft, i, &a_conv, 0); + } } - module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_1); - } else { - let dsize: usize = key.dsize().into(); + self.gglwe_product_dft(&mut res, &a_dft, key, scratch_1); - let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols - 1, a_size.div_ceil(dsize)); - ai_dft.data_mut().fill(0); + let mut res_big: VecZnxBig = self.vec_znx_idft_apply_consume(res); + self.vec_znx_big_add_small_inplace(&mut res_big, 0, a.data(), 0); + res_big + } +} - if base2k_in == base2k_out { - for di in 0..dsize { - ai_dft.set_size((a_size + di) / dsize); +impl GGLWEProduct for Module where + Self: Sized + + ModuleN + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxDftCopy +{ +} - // Small optimization for dsize > 2 - // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then - // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}. - // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last dsize-1 limbs, but this introduce - // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same - // noise is kept with respect to the ideal functionality. - res.set_size(pmat.size() - ((dsize - di) as isize - 2).max(0) as usize); +pub(crate) trait GGLWEProduct +where + Self: Sized + + ModuleN + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxDftCopy, +{ + fn gglwe_product_dft_tmp_bytes(&self, res_size: usize, a_size: usize, key_infos: &K) -> usize + where + K: GGLWEInfos, + { + let dsize: usize = key_infos.dsize().as_usize(); - for j in 0..cols - 1 { - module.vec_znx_dft_apply(dsize, dsize - di - 1, &mut ai_dft, j, a.data(), j + 1); - } - - if di == 0 { - module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_1); - } else { - module.vmp_apply_dft_to_dft_add(&mut res, &ai_dft, pmat, di, scratch_1); - } - } + if dsize == 1 { + self.vmp_apply_dft_to_dft_tmp_bytes( + res_size, + a_size, + key_infos.dnum().into(), + (key_infos.rank_in()).into(), + (key_infos.rank_out() + 1).into(), + key_infos.size(), + ) } else { - let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(module.n(), cols - 1, a_size); - for j in 0..cols - 1 { - module.vec_znx_normalize( - base2k_out, - &mut a_conv, - j, - base2k_in, - a.data(), - j + 1, - scratch_2, - ); - } + let dnum: usize = key_infos.dnum().into(); + let a_size: usize = a_size.div_ceil(dsize).min(dnum); + let ai_dft: usize = self.bytes_of_vec_znx_dft(key_infos.rank_in().into(), a_size); - for di in 0..dsize { - ai_dft.set_size((a_size + di) / dsize); + let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes( + res_size, + a_size, + dnum, + (key_infos.rank_in()).into(), + (key_infos.rank_out() + 1).into(), + key_infos.size(), + ); - // Small optimization for dsize > 2 - // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then - // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}. - // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last dsize-1 limbs, but this introduce - // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same - // noise is kept with respect to the ideal functionality. - res.set_size(pmat.size() - ((dsize - di) as isize - 2).max(0) as usize); - - for j in 0..cols - 1 { - module.vec_znx_dft_apply(dsize, dsize - di - 1, &mut ai_dft, j, &a_conv, j); - } - - if di == 0 { - module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_2); - } else { - module.vmp_apply_dft_to_dft_add(&mut res, &ai_dft, pmat, di, scratch_2); - } - } + ai_dft + vmp } - - res.set_size(res.max_size()); } - let mut res_big: VecZnxBig = module.vec_znx_idft_apply_consume(res); - module.vec_znx_big_add_small_inplace(&mut res_big, 0, a.data(), 0); - res_big + fn gglwe_product_dft(&self, res: &mut VecZnxDft, a: &A, key: &K, scratch: &mut Scratch) + where + DR: DataMut, + A: VecZnxDftToRef, + K: GGLWEPreparedToRef, + Scratch: ScratchTakeCore, + { + let a: &VecZnxDft<&[u8], BE> = &a.to_ref(); + let key: &GGLWEPrepared<&[u8], BE> = &key.to_ref(); + + let cols: usize = a.cols(); + let a_size: usize = a.size(); + let pmat: &VmpPMat<&[u8], BE> = &key.data; + + // If dsize == 1, then the digit decomposition is equal to Base2K and we can simply + // can the vmp API. + if key.dsize() == 1 { + self.vmp_apply_dft_to_dft(res, a, pmat, scratch); + // If dsize != 1, then the digit decomposition is k * Base2K with k > 1. + // As such we need to perform a bivariate polynomial convolution in (X, Y) / (X^{N}+1) with Y = 2^-K + // (instead of yn univariate one in X). + // + // Since the basis in Y is small (in practice degree 6-7 max), we perform it naiveley. + // To do so, we group the different limbs of ai_dft by their respective degree in Y + // which are multiples of the current digit. + // For example if dsize = 3, with ai_dft = [a0, a1, a2, a3, a4, a5, a6], + // we group them as [[a0, a3, a5], [a1, a4, a6], [a2, a5, 0]] + // and evaluate sum(a_di * pmat * 2^{di*Base2k}) + } else { + let dsize: usize = key.dsize().into(); + let dnum: usize = key.dnum().into(); + + // We bound ai_dft size by the number of rows of the matrix + let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize).min(dnum)); + ai_dft.data_mut().fill(0); + + for di in 0..dsize { + // Sets ai_dft size according to the current digit (if dsize does not divides a_size), + // bounded by the number of rows (digits) in the prepared matrix. + ai_dft.set_size(((a_size + di) / dsize).min(dnum)); + + // Small optimization for dsize > 2 + // VMP produce some error e, and since we aggregate vmp * 2^{di * Base2k}, then + // we also aggregate ei * 2^{di * Base2k}, with the largest error being ei * 2^{(dsize-1) * Base2k}. + // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. + // It is possible to further ignore the last dsize-1 limbs, but this introduce + // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same + // noise is kept with respect to the ideal functionality. + res.set_size(pmat.size() - ((dsize - di) as isize - 2).max(0) as usize); + + for j in 0..cols { + self.vec_znx_dft_copy(dsize, dsize - di - 1, &mut ai_dft, j, a, j); + } + + if di == 0 { + // res = pmat * ai_dft + self.vmp_apply_dft_to_dft(res, &ai_dft, pmat, scratch_1); + } else { + // res = (pmat * ai_dft) * 2^{di * Base2k} + self.vmp_apply_dft_to_dft_add(res, &ai_dft, pmat, di, scratch_1); + } + } + + res.set_size(res.max_size()); + } + } } diff --git a/poulpy-core/src/layouts/compressed/gglwe_to_ggsw_key.rs b/poulpy-core/src/layouts/compressed/gglwe_to_ggsw_key.rs new file mode 100644 index 0000000..e158a0c --- /dev/null +++ b/poulpy-core/src/layouts/compressed/gglwe_to_ggsw_key.rs @@ -0,0 +1,237 @@ +use poulpy_hal::{ + layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo}, + source::Source, +}; + +use crate::layouts::{ + Base2K, Degree, Dnum, Dsize, GGLWECompressed, GGLWECompressedToMut, GGLWECompressedToRef, GGLWEDecompress, GGLWEInfos, + GGLWEToGGSWKey, GGLWEToGGSWKeyToMut, GLWEInfos, LWEInfos, Rank, TorusPrecision, +}; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; + +use std::fmt; + +#[derive(PartialEq, Eq, Clone)] +pub struct GGLWEToGGSWKeyCompressed { + pub(crate) keys: Vec>, +} + +impl LWEInfos for GGLWEToGGSWKeyCompressed { + fn n(&self) -> Degree { + self.keys[0].n() + } + + fn base2k(&self) -> Base2K { + self.keys[0].base2k() + } + + fn k(&self) -> TorusPrecision { + self.keys[0].k() + } + + fn size(&self) -> usize { + self.keys[0].size() + } +} + +impl GLWEInfos for GGLWEToGGSWKeyCompressed { + fn rank(&self) -> Rank { + self.keys[0].rank_out() + } +} + +impl GGLWEInfos for GGLWEToGGSWKeyCompressed { + fn rank_in(&self) -> Rank { + self.rank_out() + } + + fn rank_out(&self) -> Rank { + self.keys[0].rank_out() + } + + fn dsize(&self) -> Dsize { + self.keys[0].dsize() + } + + fn dnum(&self) -> Dnum { + self.keys[0].dnum() + } +} + +impl fmt::Debug for GGLWEToGGSWKeyCompressed { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{self}") + } +} + +impl FillUniform for GGLWEToGGSWKeyCompressed { + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + self.keys + .iter_mut() + .for_each(|key: &mut GGLWECompressed| key.fill_uniform(log_bound, source)) + } +} + +impl fmt::Display for GGLWEToGGSWKeyCompressed { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "(GGLWEToGGSWKeyCompressed)",)?; + for (i, key) in self.keys.iter().enumerate() { + write!(f, "{i}: {key}")?; + } + Ok(()) + } +} + +impl GGLWEToGGSWKeyCompressed> { + pub fn alloc_from_infos(infos: &A) -> Self + where + A: GGLWEInfos, + { + assert_eq!( + infos.rank_in(), + infos.rank_out(), + "rank_in != rank_out is not supported for GGLWEToGGSWKeyCompressed" + ); + Self::alloc( + infos.n(), + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) + } + + pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self { + GGLWEToGGSWKeyCompressed { + keys: (0..rank.as_usize()) + .map(|_| GGLWECompressed::alloc(n, base2k, k, rank, rank, dnum, dsize)) + .collect(), + } + } + + pub fn bytes_of_from_infos(infos: &A) -> usize + where + A: GGLWEInfos, + { + assert_eq!( + infos.rank_in(), + infos.rank_out(), + "rank_in != rank_out is not supported for GGLWEToGGSWKeyCompressed" + ); + Self::bytes_of( + infos.n(), + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) + } + + pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { + rank.as_usize() * GGLWECompressed::bytes_of(n, base2k, k, rank, dnum, dsize) + } +} + +impl GGLWEToGGSWKeyCompressed { + // Returns a mutable reference to GGLWE_{s}([s[i]*s[0], s[i]*s[1], ..., s[i]*s[rank]]) + pub fn at_mut(&mut self, i: usize) -> &mut GGLWECompressed { + assert!((i as u32) < self.rank()); + &mut self.keys[i] + } +} + +impl GGLWEToGGSWKeyCompressed { + // Returns a reference to GGLWE_{s}(s[i] * s[j]) + pub fn at(&self, i: usize) -> &GGLWECompressed { + assert!((i as u32) < self.rank()); + &self.keys[i] + } +} + +impl ReaderFrom for GGLWEToGGSWKeyCompressed { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + let len: usize = reader.read_u64::()? as usize; + if self.keys.len() != len { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("self.keys.len()={} != read len={}", self.keys.len(), len), + )); + } + for key in &mut self.keys { + key.read_from(reader)?; + } + Ok(()) + } +} + +impl WriterTo for GGLWEToGGSWKeyCompressed { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + writer.write_u64::(self.keys.len() as u64)?; + for key in &self.keys { + key.write_to(writer)?; + } + Ok(()) + } +} + +pub trait GGLWEToGGSWKeyDecompress +where + Self: GGLWEDecompress, +{ + fn decompress_gglwe_to_ggsw_key(&self, res: &mut R, other: &O) + where + R: GGLWEToGGSWKeyToMut, + O: GGLWEToGGSWKeyCompressedToRef, + { + let res: &mut GGLWEToGGSWKey<&mut [u8]> = &mut res.to_mut(); + let other: &GGLWEToGGSWKeyCompressed<&[u8]> = &other.to_ref(); + + assert_eq!(res.keys.len(), other.keys.len()); + + for (a, b) in res.keys.iter_mut().zip(other.keys.iter()) { + self.decompress_gglwe(a, b); + } + } +} + +impl GGLWEToGGSWKey { + pub fn decompress(&mut self, module: &M, other: &O) + where + M: GGLWEToGGSWKeyDecompress, + O: GGLWEToGGSWKeyCompressedToRef, + { + module.decompress_gglwe_to_ggsw_key(self, other); + } +} + +pub trait GGLWEToGGSWKeyCompressedToRef { + fn to_ref(&self) -> GGLWEToGGSWKeyCompressed<&[u8]>; +} + +impl GGLWEToGGSWKeyCompressedToRef for GGLWEToGGSWKeyCompressed +where + GGLWECompressed: GGLWECompressedToRef, +{ + fn to_ref(&self) -> GGLWEToGGSWKeyCompressed<&[u8]> { + GGLWEToGGSWKeyCompressed { + keys: self.keys.iter().map(|c| c.to_ref()).collect(), + } + } +} + +pub trait GGLWEToGGSWKeyCompressedToMut { + fn to_mut(&mut self) -> GGLWEToGGSWKeyCompressed<&mut [u8]>; +} + +impl GGLWEToGGSWKeyCompressedToMut for GGLWEToGGSWKeyCompressed +where + GGLWECompressed: GGLWECompressedToMut, +{ + fn to_mut(&mut self) -> GGLWEToGGSWKeyCompressed<&mut [u8]> { + GGLWEToGGSWKeyCompressed { + keys: self.keys.iter_mut().map(|c| c.to_mut()).collect(), + } + } +} diff --git a/poulpy-core/src/layouts/compressed/glwe_tensor_key.rs b/poulpy-core/src/layouts/compressed/glwe_tensor_key.rs index 6939ff2..c6e9297 100644 --- a/poulpy-core/src/layouts/compressed/glwe_tensor_key.rs +++ b/poulpy-core/src/layouts/compressed/glwe_tensor_key.rs @@ -4,31 +4,34 @@ use poulpy_hal::{ }; use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWECompressed, GGLWECompressedToMut, GGLWECompressedToRef, GGLWEDecompress, GGLWEInfos, - GLWEInfos, GLWETensorKey, GLWETensorKeyToMut, LWEInfos, Rank, TorusPrecision, + Base2K, Degree, Dnum, Dsize, GGLWECompressed, GGLWECompressedSeedMut, GGLWECompressedToMut, GGLWECompressedToRef, + GGLWEDecompress, GGLWEInfos, GGLWEToMut, GLWEInfos, GLWETensorKey, LWEInfos, Rank, TorusPrecision, }; -use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::fmt; #[derive(PartialEq, Eq, Clone)] -pub struct GLWETensorKeyCompressed { - pub(crate) keys: Vec>, +pub struct GLWETensorKeyCompressed(pub(crate) GGLWECompressed); + +impl GGLWECompressedSeedMut for GLWETensorKeyCompressed { + fn seed_mut(&mut self) -> &mut Vec<[u8; 32]> { + &mut self.0.seed + } } impl LWEInfos for GLWETensorKeyCompressed { fn n(&self) -> Degree { - self.keys[0].n() + self.0.n() } fn base2k(&self) -> Base2K { - self.keys[0].base2k() + self.0.base2k() } fn k(&self) -> TorusPrecision { - self.keys[0].k() + self.0.k() } fn size(&self) -> usize { - self.keys[0].size() + self.0.size() } } impl GLWEInfos for GLWETensorKeyCompressed { @@ -43,15 +46,15 @@ impl GGLWEInfos for GLWETensorKeyCompressed { } fn rank_out(&self) -> Rank { - self.keys[0].rank_out() + self.0.rank_out() } fn dsize(&self) -> Dsize { - self.keys[0].dsize() + self.0.dsize() } fn dnum(&self) -> Dnum { - self.keys[0].dnum() + self.0.dnum() } } @@ -63,18 +66,14 @@ impl fmt::Debug for GLWETensorKeyCompressed { impl FillUniform for GLWETensorKeyCompressed { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { - self.keys - .iter_mut() - .for_each(|key: &mut GGLWECompressed| key.fill_uniform(log_bound, source)) + self.0.fill_uniform(log_bound, source); } } impl fmt::Display for GLWETensorKeyCompressed { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!(f, "(GLWETensorKeyCompressed)",)?; - for (i, key) in self.keys.iter().enumerate() { - write!(f, "{i}: {key}")?; - } + write!(f, "{}", self.0)?; Ok(()) } } @@ -96,11 +95,15 @@ impl GLWETensorKeyCompressed> { pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self { let pairs: u32 = (((rank.as_u32() + 1) * rank.as_u32()) >> 1).max(1); - GLWETensorKeyCompressed { - keys: (0..pairs) - .map(|_| GGLWECompressed::alloc(n, base2k, k, Rank(1), rank, dnum, dsize)) - .collect(), - } + GLWETensorKeyCompressed(GGLWECompressed::alloc( + n, + base2k, + k, + Rank(pairs), + rank, + dnum, + dsize, + )) } pub fn bytes_of_from_infos(infos: &A) -> usize @@ -118,88 +121,35 @@ impl GLWETensorKeyCompressed> { } pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { - let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize; - pairs * GGLWECompressed::bytes_of(n, base2k, k, Rank(1), dnum, dsize) + let pairs: u32 = (((rank.as_u32() + 1) * rank.as_u32()) >> 1).max(1); + GGLWECompressed::bytes_of(n, base2k, k, Rank(pairs), dnum, dsize) } } impl ReaderFrom for GLWETensorKeyCompressed { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { - let len: usize = reader.read_u64::()? as usize; - if self.keys.len() != len { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - format!("self.keys.len()={} != read len={}", self.keys.len(), len), - )); - } - for key in &mut self.keys { - key.read_from(reader)?; - } + self.0.read_from(reader)?; Ok(()) } } impl WriterTo for GLWETensorKeyCompressed { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { - writer.write_u64::(self.keys.len() as u64)?; - for key in &self.keys { - key.write_to(writer)?; - } + self.0.write_to(writer)?; Ok(()) } } -pub trait GLWETensorKeyCompressedAtRef { - fn at(&self, i: usize, j: usize) -> &GGLWECompressed; -} - -impl GLWETensorKeyCompressedAtRef for GLWETensorKeyCompressed { - fn at(&self, mut i: usize, mut j: usize) -> &GGLWECompressed { - if i > j { - std::mem::swap(&mut i, &mut j); - }; - let rank: usize = self.rank_out().into(); - &self.keys[i * rank + j - (i * (i + 1) / 2)] - } -} - -pub trait GLWETensorKeyCompressedAtMut { - fn at_mut(&mut self, i: usize, j: usize) -> &mut GGLWECompressed; -} - -impl GLWETensorKeyCompressedAtMut for GLWETensorKeyCompressed { - fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GGLWECompressed { - if i > j { - std::mem::swap(&mut i, &mut j); - }; - let rank: usize = self.rank_out().into(); - &mut self.keys[i * rank + j - (i * (i + 1) / 2)] - } -} - pub trait GLWETensorKeyDecompress where Self: GGLWEDecompress, { fn decompress_tensor_key(&self, res: &mut R, other: &O) where - R: GLWETensorKeyToMut, - O: GLWETensorKeyCompressedToRef, + R: GGLWEToMut, + O: GGLWECompressedToRef, { - let res: &mut GLWETensorKey<&mut [u8]> = &mut res.to_mut(); - let other: &GLWETensorKeyCompressed<&[u8]> = &other.to_ref(); - - assert_eq!( - res.keys.len(), - other.keys.len(), - "invalid receiver: res.keys.len()={} != other.keys.len()={}", - res.keys.len(), - other.keys.len() - ); - - for (a, b) in res.keys.iter_mut().zip(other.keys.iter()) { - self.decompress_gglwe(a, b); - } + self.decompress_gglwe(res, other); } } @@ -208,39 +158,27 @@ impl GLWETensorKeyDecompress for Module where Self: GGLWEDecompre impl GLWETensorKey { pub fn decompress(&mut self, module: &M, other: &O) where - O: GLWETensorKeyCompressedToRef, + O: GGLWECompressedToRef, M: GLWETensorKeyDecompress, { module.decompress_tensor_key(self, other); } } -pub trait GLWETensorKeyCompressedToMut { - fn to_mut(&mut self) -> GLWETensorKeyCompressed<&mut [u8]>; -} - -impl GLWETensorKeyCompressedToMut for GLWETensorKeyCompressed +impl GGLWECompressedToMut for GLWETensorKeyCompressed where GGLWECompressed: GGLWECompressedToMut, { - fn to_mut(&mut self) -> GLWETensorKeyCompressed<&mut [u8]> { - GLWETensorKeyCompressed { - keys: self.keys.iter_mut().map(|c| c.to_mut()).collect(), - } + fn to_mut(&mut self) -> GGLWECompressed<&mut [u8]> { + self.0.to_mut() } } -pub trait GLWETensorKeyCompressedToRef { - fn to_ref(&self) -> GLWETensorKeyCompressed<&[u8]>; -} - -impl GLWETensorKeyCompressedToRef for GLWETensorKeyCompressed +impl GGLWECompressedToRef for GLWETensorKeyCompressed where GGLWECompressed: GGLWECompressedToRef, { - fn to_ref(&self) -> GLWETensorKeyCompressed<&[u8]> { - GLWETensorKeyCompressed { - keys: self.keys.iter().map(|c| c.to_ref()).collect(), - } + fn to_ref(&self) -> GGLWECompressed<&[u8]> { + self.0.to_ref() } } diff --git a/poulpy-core/src/layouts/compressed/glwe_to_lwe_switching_key.rs b/poulpy-core/src/layouts/compressed/glwe_to_lwe_key.rs similarity index 95% rename from poulpy-core/src/layouts/compressed/glwe_to_lwe_switching_key.rs rename to poulpy-core/src/layouts/compressed/glwe_to_lwe_key.rs index 6ac325c..5552d11 100644 --- a/poulpy-core/src/layouts/compressed/glwe_to_lwe_switching_key.rs +++ b/poulpy-core/src/layouts/compressed/glwe_to_lwe_key.rs @@ -7,7 +7,7 @@ use poulpy_hal::{ use crate::layouts::{ Base2K, Degree, Dnum, Dsize, GGLWECompressed, GGLWECompressedToMut, GGLWECompressedToRef, GGLWEInfos, GGLWEToMut, GLWEInfos, - GLWESwitchingKeyDegrees, GLWESwitchingKeyDegreesMut, GLWEToLWESwitchingKey, LWEInfos, Rank, TorusPrecision, + GLWESwitchingKeyDegrees, GLWESwitchingKeyDegreesMut, GLWEToLWEKey, LWEInfos, Rank, TorusPrecision, compressed::{GLWESwitchingKeyCompressed, GLWESwitchingKeyDecompress}, }; @@ -147,7 +147,7 @@ pub trait GLWEToLWESwitchingKeyDecompress where Self: GLWESwitchingKeyDecompress, { - fn decompress_glwe_to_lwe_switching_key(&self, res: &mut R, other: &O) + fn decompress_glwe_to_lwe_key(&self, res: &mut R, other: &O) where R: GGLWEToMut + GLWESwitchingKeyDegreesMut, O: GGLWECompressedToRef + GLWESwitchingKeyDegrees, @@ -158,13 +158,13 @@ where impl GLWEToLWESwitchingKeyDecompress for Module where Self: GLWESwitchingKeyDecompress {} -impl GLWEToLWESwitchingKey { +impl GLWEToLWEKey { pub fn decompress(&mut self, module: &M, other: &O) where O: GGLWECompressedToRef + GLWESwitchingKeyDegrees, M: GLWEToLWESwitchingKeyDecompress, { - module.decompress_glwe_to_lwe_switching_key(self, other); + module.decompress_glwe_to_lwe_key(self, other); } } diff --git a/poulpy-core/src/layouts/compressed/lwe_to_glwe_switching_key.rs b/poulpy-core/src/layouts/compressed/lwe_to_glwe_key.rs similarity index 73% rename from poulpy-core/src/layouts/compressed/lwe_to_glwe_switching_key.rs rename to poulpy-core/src/layouts/compressed/lwe_to_glwe_key.rs index 7a724c9..984ed05 100644 --- a/poulpy-core/src/layouts/compressed/lwe_to_glwe_switching_key.rs +++ b/poulpy-core/src/layouts/compressed/lwe_to_glwe_key.rs @@ -5,15 +5,15 @@ use poulpy_hal::{ use crate::layouts::{ Base2K, Degree, Dnum, Dsize, GGLWECompressed, GGLWECompressedToMut, GGLWECompressedToRef, GGLWEInfos, GGLWEToMut, GLWEInfos, - GLWESwitchingKeyDegrees, GLWESwitchingKeyDegreesMut, LWEInfos, LWEToGLWESwitchingKey, Rank, TorusPrecision, + GLWESwitchingKeyDegrees, GLWESwitchingKeyDegreesMut, LWEInfos, LWEToGLWEKey, Rank, TorusPrecision, compressed::{GLWESwitchingKeyCompressed, GLWESwitchingKeyDecompress}, }; use std::fmt; #[derive(PartialEq, Eq, Clone)] -pub struct LWEToGLWESwitchingKeyCompressed(pub(crate) GLWESwitchingKeyCompressed); +pub struct LWEToGLWEKeyCompressed(pub(crate) GLWESwitchingKeyCompressed); -impl LWEInfos for LWEToGLWESwitchingKeyCompressed { +impl LWEInfos for LWEToGLWEKeyCompressed { fn n(&self) -> Degree { self.0.n() } @@ -29,13 +29,13 @@ impl LWEInfos for LWEToGLWESwitchingKeyCompressed { self.0.size() } } -impl GLWEInfos for LWEToGLWESwitchingKeyCompressed { +impl GLWEInfos for LWEToGLWEKeyCompressed { fn rank(&self) -> Rank { self.rank_out() } } -impl GGLWEInfos for LWEToGLWESwitchingKeyCompressed { +impl GGLWEInfos for LWEToGLWEKeyCompressed { fn dsize(&self) -> Dsize { self.0.dsize() } @@ -53,37 +53,37 @@ impl GGLWEInfos for LWEToGLWESwitchingKeyCompressed { } } -impl fmt::Debug for LWEToGLWESwitchingKeyCompressed { +impl fmt::Debug for LWEToGLWEKeyCompressed { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{self}") } } -impl FillUniform for LWEToGLWESwitchingKeyCompressed { +impl FillUniform for LWEToGLWEKeyCompressed { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.0.fill_uniform(log_bound, source); } } -impl fmt::Display for LWEToGLWESwitchingKeyCompressed { +impl fmt::Display for LWEToGLWEKeyCompressed { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "(LWEToGLWESwitchingKeyCompressed) {}", self.0) } } -impl ReaderFrom for LWEToGLWESwitchingKeyCompressed { +impl ReaderFrom for LWEToGLWEKeyCompressed { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.0.read_from(reader) } } -impl WriterTo for LWEToGLWESwitchingKeyCompressed { +impl WriterTo for LWEToGLWEKeyCompressed { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { self.0.write_to(writer) } } -impl LWEToGLWESwitchingKeyCompressed> { +impl LWEToGLWEKeyCompressed> { pub fn alloc_from_infos(infos: &A) -> Self where A: GGLWEInfos, @@ -108,7 +108,7 @@ impl LWEToGLWESwitchingKeyCompressed> { } pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self { - LWEToGLWESwitchingKeyCompressed(GLWESwitchingKeyCompressed::alloc( + LWEToGLWEKeyCompressed(GLWESwitchingKeyCompressed::alloc( n, base2k, k, @@ -141,11 +141,11 @@ impl LWEToGLWESwitchingKeyCompressed> { } } -pub trait LWEToGLWESwitchingKeyDecompress +pub trait LWEToGLWEKeyDecompress where Self: GLWESwitchingKeyDecompress, { - fn decompress_lwe_to_glwe_switching_key(&self, res: &mut R, other: &O) + fn decompress_lwe_to_glwe_key(&self, res: &mut R, other: &O) where R: GGLWEToMut + GLWESwitchingKeyDegreesMut, O: GGLWECompressedToRef + GLWESwitchingKeyDegrees, @@ -154,25 +154,25 @@ where } } -impl LWEToGLWESwitchingKeyDecompress for Module where Self: GLWESwitchingKeyDecompress {} +impl LWEToGLWEKeyDecompress for Module where Self: GLWESwitchingKeyDecompress {} -impl LWEToGLWESwitchingKey { +impl LWEToGLWEKey { pub fn decompress(&mut self, module: &M, other: &O) where O: GGLWECompressedToRef + GLWESwitchingKeyDegrees, - M: LWEToGLWESwitchingKeyDecompress, + M: LWEToGLWEKeyDecompress, { - module.decompress_lwe_to_glwe_switching_key(self, other); + module.decompress_lwe_to_glwe_key(self, other); } } -impl GGLWECompressedToRef for LWEToGLWESwitchingKeyCompressed { +impl GGLWECompressedToRef for LWEToGLWEKeyCompressed { fn to_ref(&self) -> GGLWECompressed<&[u8]> { self.0.to_ref() } } -impl GGLWECompressedToMut for LWEToGLWESwitchingKeyCompressed { +impl GGLWECompressedToMut for LWEToGLWEKeyCompressed { fn to_mut(&mut self) -> GGLWECompressed<&mut [u8]> { self.0.to_mut() } diff --git a/poulpy-core/src/layouts/compressed/mod.rs b/poulpy-core/src/layouts/compressed/mod.rs index b85d48d..8dd6145 100644 --- a/poulpy-core/src/layouts/compressed/mod.rs +++ b/poulpy-core/src/layouts/compressed/mod.rs @@ -1,21 +1,23 @@ mod gglwe; +mod gglwe_to_ggsw_key; mod ggsw; mod glwe; mod glwe_automorphism_key; mod glwe_switching_key; mod glwe_tensor_key; -mod glwe_to_lwe_switching_key; +mod glwe_to_lwe_key; mod lwe; mod lwe_switching_key; -mod lwe_to_glwe_switching_key; +mod lwe_to_glwe_key; pub use gglwe::*; +pub use gglwe_to_ggsw_key::*; pub use ggsw::*; pub use glwe::*; pub use glwe_automorphism_key::*; pub use glwe_switching_key::*; pub use glwe_tensor_key::*; -pub use glwe_to_lwe_switching_key::*; +pub use glwe_to_lwe_key::*; pub use lwe::*; pub use lwe_switching_key::*; -pub use lwe_to_glwe_switching_key::*; +pub use lwe_to_glwe_key::*; diff --git a/poulpy-core/src/layouts/gglwe_to_ggsw_key.rs b/poulpy-core/src/layouts/gglwe_to_ggsw_key.rs new file mode 100644 index 0000000..398dfd2 --- /dev/null +++ b/poulpy-core/src/layouts/gglwe_to_ggsw_key.rs @@ -0,0 +1,254 @@ +use poulpy_hal::{ + layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo}, + source::Source, +}; + +use crate::layouts::{ + Base2K, Degree, Dnum, Dsize, GGLWE, GGLWEInfos, GGLWEToMut, GGLWEToRef, GLWEInfos, LWEInfos, Rank, TorusPrecision, +}; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; + +use std::fmt; + +#[derive(PartialEq, Eq, Copy, Clone, Debug)] +pub struct GGLWEToGGSWKeyLayout { + pub n: Degree, + pub base2k: Base2K, + pub k: TorusPrecision, + pub rank: Rank, + pub dnum: Dnum, + pub dsize: Dsize, +} + +#[derive(PartialEq, Eq, Clone)] +pub struct GGLWEToGGSWKey { + pub(crate) keys: Vec>, +} + +impl LWEInfos for GGLWEToGGSWKey { + fn n(&self) -> Degree { + self.keys[0].n() + } + + fn base2k(&self) -> Base2K { + self.keys[0].base2k() + } + + fn k(&self) -> TorusPrecision { + self.keys[0].k() + } + + fn size(&self) -> usize { + self.keys[0].size() + } +} + +impl GLWEInfos for GGLWEToGGSWKey { + fn rank(&self) -> Rank { + self.keys[0].rank_out() + } +} + +impl GGLWEInfos for GGLWEToGGSWKey { + fn rank_in(&self) -> Rank { + self.rank_out() + } + + fn rank_out(&self) -> Rank { + self.keys[0].rank_out() + } + + fn dsize(&self) -> Dsize { + self.keys[0].dsize() + } + + fn dnum(&self) -> Dnum { + self.keys[0].dnum() + } +} + +impl LWEInfos for GGLWEToGGSWKeyLayout { + fn n(&self) -> Degree { + self.n + } + + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } +} + +impl GLWEInfos for GGLWEToGGSWKeyLayout { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWEInfos for GGLWEToGGSWKeyLayout { + fn rank_in(&self) -> Rank { + self.rank + } + + fn dsize(&self) -> Dsize { + self.dsize + } + + fn rank_out(&self) -> Rank { + self.rank + } + + fn dnum(&self) -> Dnum { + self.dnum + } +} + +impl fmt::Debug for GGLWEToGGSWKey { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{self}") + } +} + +impl FillUniform for GGLWEToGGSWKey { + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + self.keys + .iter_mut() + .for_each(|key: &mut GGLWE| key.fill_uniform(log_bound, source)) + } +} + +impl fmt::Display for GGLWEToGGSWKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "(GGLWEToGGSWKey)",)?; + for (i, key) in self.keys.iter().enumerate() { + write!(f, "{i}: {key}")?; + } + Ok(()) + } +} + +impl GGLWEToGGSWKey> { + pub fn alloc_from_infos(infos: &A) -> Self + where + A: GGLWEInfos, + { + assert_eq!( + infos.rank_in(), + infos.rank_out(), + "rank_in != rank_out is not supported for GGLWEToGGSWKey" + ); + Self::alloc( + infos.n(), + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) + } + + pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self { + GGLWEToGGSWKey { + keys: (0..rank.as_usize()) + .map(|_| GGLWE::alloc(n, base2k, k, rank, rank, dnum, dsize)) + .collect(), + } + } + + pub fn bytes_of_from_infos(infos: &A) -> usize + where + A: GGLWEInfos, + { + assert_eq!( + infos.rank_in(), + infos.rank_out(), + "rank_in != rank_out is not supported for GGLWEToGGSWKey" + ); + Self::bytes_of( + infos.n(), + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) + } + + pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { + rank.as_usize() * GGLWE::bytes_of(n, base2k, k, rank, rank, dnum, dsize) + } +} + +impl GGLWEToGGSWKey { + // Returns a mutable reference to GGLWE_{s}([s[i]*s[0], s[i]*s[1], ..., s[i]*s[rank]]) + pub fn at_mut(&mut self, i: usize) -> &mut GGLWE { + assert!((i as u32) < self.rank()); + &mut self.keys[i] + } +} + +impl GGLWEToGGSWKey { + // Returns a reference to GGLWE_{s}(s[i] * s[j]) + pub fn at(&self, i: usize) -> &GGLWE { + assert!((i as u32) < self.rank()); + &self.keys[i] + } +} + +impl ReaderFrom for GGLWEToGGSWKey { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + let len: usize = reader.read_u64::()? as usize; + if self.keys.len() != len { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("self.keys.len()={} != read len={}", self.keys.len(), len), + )); + } + for key in &mut self.keys { + key.read_from(reader)?; + } + Ok(()) + } +} + +impl WriterTo for GGLWEToGGSWKey { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + writer.write_u64::(self.keys.len() as u64)?; + for key in &self.keys { + key.write_to(writer)?; + } + Ok(()) + } +} + +pub trait GGLWEToGGSWKeyToRef { + fn to_ref(&self) -> GGLWEToGGSWKey<&[u8]>; +} + +impl GGLWEToGGSWKeyToRef for GGLWEToGGSWKey +where + GGLWE: GGLWEToRef, +{ + fn to_ref(&self) -> GGLWEToGGSWKey<&[u8]> { + GGLWEToGGSWKey { + keys: self.keys.iter().map(|c| c.to_ref()).collect(), + } + } +} + +pub trait GGLWEToGGSWKeyToMut { + fn to_mut(&mut self) -> GGLWEToGGSWKey<&mut [u8]>; +} + +impl GGLWEToGGSWKeyToMut for GGLWEToGGSWKey +where + GGLWE: GGLWEToMut, +{ + fn to_mut(&mut self) -> GGLWEToGGSWKey<&mut [u8]> { + GGLWEToGGSWKey { + keys: self.keys.iter_mut().map(|c| c.to_mut()).collect(), + } + } +} diff --git a/poulpy-core/src/layouts/glwe_secret_tensor.rs b/poulpy-core/src/layouts/glwe_secret_tensor.rs new file mode 100644 index 0000000..287eda8 --- /dev/null +++ b/poulpy-core/src/layouts/glwe_secret_tensor.rs @@ -0,0 +1,221 @@ +use poulpy_hal::{ + api::{ + ModuleN, ScratchTakeBasic, SvpApplyDftToDft, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyTmpA, + }, + layouts::{ + Backend, Data, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToMut, ScalarZnxToRef, Scratch, ZnxInfos, ZnxView, + ZnxViewMut, + }, +}; + +use crate::{ + ScratchTakeCore, + dist::Distribution, + layouts::{ + Base2K, Degree, GLWEInfos, GLWESecret, GLWESecretPreparedFactory, GLWESecretToMut, GLWESecretToRef, LWEInfos, Rank, + TorusPrecision, + }, +}; + +pub struct GLWESecretTensor { + pub(crate) data: ScalarZnx, + pub(crate) rank: Rank, + pub(crate) dist: Distribution, +} + +impl GLWESecretTensor> { + pub(crate) fn pairs(rank: usize) -> usize { + (((rank + 1) * rank) >> 1).max(1) + } +} + +impl LWEInfos for GLWESecretTensor { + fn base2k(&self) -> Base2K { + Base2K(0) + } + + fn k(&self) -> TorusPrecision { + TorusPrecision(0) + } + + fn n(&self) -> Degree { + Degree(self.data.n() as u32) + } + + fn size(&self) -> usize { + 1 + } +} + +impl GLWESecretTensor { + pub fn at(&self, mut i: usize, mut j: usize) -> ScalarZnx<&[u8]> { + if i > j { + std::mem::swap(&mut i, &mut j); + }; + let rank: usize = self.rank().into(); + ScalarZnx { + data: bytemuck::cast_slice(self.data.at(i * rank + j - (i * (i + 1) / 2), 0)), + n: self.n().into(), + cols: 1, + } + } +} + +impl GLWESecretTensor { + pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> ScalarZnx<&mut [u8]> { + if i > j { + std::mem::swap(&mut i, &mut j); + }; + let rank: usize = self.rank().into(); + ScalarZnx { + n: self.n().into(), + data: bytemuck::cast_slice_mut(self.data.at_mut(i * rank + j - (i * (i + 1) / 2), 0)), + cols: 1, + } + } +} + +impl GLWEInfos for GLWESecretTensor { + fn rank(&self) -> Rank { + self.rank + } +} + +impl GLWESecretToRef for GLWESecretTensor { + fn to_ref(&self) -> GLWESecret<&[u8]> { + GLWESecret { + data: self.data.to_ref(), + dist: self.dist, + } + } +} + +impl GLWESecretToMut for GLWESecretTensor { + fn to_mut(&mut self) -> GLWESecret<&mut [u8]> { + GLWESecret { + dist: self.dist, + data: self.data.to_mut(), + } + } +} + +impl GLWESecretTensor> { + pub fn alloc_from_infos(infos: &A) -> Self + where + A: GLWEInfos, + { + Self::alloc(infos.n(), infos.rank()) + } + + pub fn alloc(n: Degree, rank: Rank) -> Self { + GLWESecretTensor { + data: ScalarZnx::alloc(n.into(), Self::pairs(rank.into())), + rank, + dist: Distribution::NONE, + } + } + + pub fn bytes_of_from_infos(infos: &A) -> usize + where + A: GLWEInfos, + { + Self::bytes_of(infos.n(), Self::pairs(infos.rank().into()).into()) + } + + pub fn bytes_of(n: Degree, rank: Rank) -> usize { + ScalarZnx::bytes_of(n.into(), Self::pairs(rank.into())) + } +} + +impl GLWESecretTensor { + pub fn prepare(&mut self, module: &M, other: &S, scratch: &mut Scratch) + where + M: GLWESecretTensorFactory, + S: GLWESecretToRef + GLWEInfos, + Scratch: ScratchTakeCore, + { + module.glwe_secret_tensor_prepare(self, other, scratch); + } +} + +pub trait GLWESecretTensorFactory { + fn glwe_secret_tensor_prepare_tmp_bytes(&self, rank: Rank) -> usize; + + fn glwe_secret_tensor_prepare(&self, res: &mut R, other: &O, scratch: &mut Scratch) + where + R: GLWESecretToMut + GLWEInfos, + O: GLWESecretToRef + GLWEInfos; +} + +impl GLWESecretTensorFactory for Module +where + Self: ModuleN + + GLWESecretPreparedFactory + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDft + + VecZnxIdftApplyTmpA + + VecZnxBigNormalize + + VecZnxDftBytesOf + + VecZnxBigBytesOf + + VecZnxBigNormalizeTmpBytes, + Scratch: ScratchTakeCore, +{ + fn glwe_secret_tensor_prepare_tmp_bytes(&self, rank: Rank) -> usize { + self.bytes_of_glwe_secret_prepared(rank) + + self.bytes_of_vec_znx_dft(rank.into(), 1) + + self.bytes_of_vec_znx_dft(1, 1) + + self.bytes_of_vec_znx_big(1, 1) + + self.vec_znx_big_normalize_tmp_bytes() + } + + fn glwe_secret_tensor_prepare(&self, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: GLWESecretToMut + GLWEInfos, + A: GLWESecretToRef + GLWEInfos, + { + let res: &mut GLWESecret<&mut [u8]> = &mut res.to_mut(); + let a: &GLWESecret<&[u8]> = &a.to_ref(); + + println!("res.rank: {} a.rank: {}", res.rank(), a.rank()); + + assert_eq!(res.rank(), GLWESecretTensor::pairs(a.rank().into()) as u32); + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.n(), self.n() as u32); + + let rank: usize = a.rank().into(); + + let (mut a_prepared, scratch_1) = scratch.take_glwe_secret_prepared(self, rank.into()); + a_prepared.prepare(self, a); + + let base2k: usize = 17; + + let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, rank, 1); + for i in 0..rank { + self.vec_znx_dft_apply(1, 0, &mut a_dft, i, &a.data.as_vec_znx(), i); + } + + let (mut a_ij_big, scratch_3) = scratch_2.take_vec_znx_big(self, 1, 1); + let (mut a_ij_dft, scratch_4) = scratch_3.take_vec_znx_dft(self, 1, 1); + + // sk_tensor = sk (x) sk + // For example: (s0, s1) (x) (s0, s1) = (s0^2, s0s1, s1^2) + for i in 0..rank { + for j in i..rank { + let idx: usize = i * rank + j - (i * (i + 1) / 2); + self.svp_apply_dft_to_dft(&mut a_ij_dft, 0, &a_prepared.data, j, &a_dft, i); + self.vec_znx_idft_apply_tmpa(&mut a_ij_big, 0, &mut a_ij_dft, 0); + self.vec_znx_big_normalize( + base2k, + &mut res.data.as_vec_znx_mut(), + idx, + base2k, + &a_ij_big, + 0, + scratch_4, + ); + } + } + } +} diff --git a/poulpy-core/src/layouts/glwe_tensor_key.rs b/poulpy-core/src/layouts/glwe_tensor_key.rs index bc0100f..032a892 100644 --- a/poulpy-core/src/layouts/glwe_tensor_key.rs +++ b/poulpy-core/src/layouts/glwe_tensor_key.rs @@ -6,7 +6,6 @@ use poulpy_hal::{ use crate::layouts::{ Base2K, Degree, Dnum, Dsize, GGLWE, GGLWEInfos, GGLWEToMut, GGLWEToRef, GLWEInfos, LWEInfos, Rank, TorusPrecision, }; -use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::fmt; @@ -21,31 +20,29 @@ pub struct GLWETensorKeyLayout { } #[derive(PartialEq, Eq, Clone)] -pub struct GLWETensorKey { - pub(crate) keys: Vec>, -} +pub struct GLWETensorKey(pub(crate) GGLWE); impl LWEInfos for GLWETensorKey { fn n(&self) -> Degree { - self.keys[0].n() + self.0.n() } fn base2k(&self) -> Base2K { - self.keys[0].base2k() + self.0.base2k() } fn k(&self) -> TorusPrecision { - self.keys[0].k() + self.0.k() } fn size(&self) -> usize { - self.keys[0].size() + self.0.size() } } impl GLWEInfos for GLWETensorKey { fn rank(&self) -> Rank { - self.keys[0].rank_out() + self.0.rank_out() } } @@ -55,15 +52,15 @@ impl GGLWEInfos for GLWETensorKey { } fn rank_out(&self) -> Rank { - self.keys[0].rank_out() + self.0.rank_out() } fn dsize(&self) -> Dsize { - self.keys[0].dsize() + self.0.dsize() } fn dnum(&self) -> Dnum { - self.keys[0].dnum() + self.0.dnum() } } @@ -113,18 +110,14 @@ impl fmt::Debug for GLWETensorKey { impl FillUniform for GLWETensorKey { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { - self.keys - .iter_mut() - .for_each(|key: &mut GGLWE| key.fill_uniform(log_bound, source)) + self.0.fill_uniform(log_bound, source) } } impl fmt::Display for GLWETensorKey { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!(f, "(GLWETensorKey)",)?; - for (i, key) in self.keys.iter().enumerate() { - write!(f, "{i}: {key}")?; - } + write!(f, "{}", self.0)?; Ok(()) } } @@ -151,11 +144,7 @@ impl GLWETensorKey> { pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self { let pairs: u32 = (((rank.0 + 1) * rank.0) >> 1).max(1); - GLWETensorKey { - keys: (0..pairs) - .map(|_| GGLWE::alloc(n, base2k, k, Rank(1), rank, dnum, dsize)) - .collect(), - } + GLWETensorKey(GGLWE::alloc(n, base2k, k, Rank(pairs), rank, dnum, dsize)) } pub fn bytes_of_from_infos(infos: &A) -> usize @@ -178,85 +167,39 @@ impl GLWETensorKey> { } pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { - let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize; - pairs * GGLWE::bytes_of(n, base2k, k, Rank(1), rank, dnum, dsize) - } -} - -impl GLWETensorKey { - // Returns a mutable reference to GGLWE_{s}(s[i] * s[j]) - pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GGLWE { - if i > j { - std::mem::swap(&mut i, &mut j); - }; - let rank: usize = self.rank_out().into(); - &mut self.keys[i * rank + j - (i * (i + 1) / 2)] - } -} - -impl GLWETensorKey { - // Returns a reference to GGLWE_{s}(s[i] * s[j]) - pub fn at(&self, mut i: usize, mut j: usize) -> &GGLWE { - if i > j { - std::mem::swap(&mut i, &mut j); - }; - let rank: usize = self.rank_out().into(); - &self.keys[i * rank + j - (i * (i + 1) / 2)] + let pairs: u32 = (((rank.0 + 1) * rank.0) >> 1).max(1); + GGLWE::bytes_of(n, base2k, k, Rank(pairs), rank, dnum, dsize) } } impl ReaderFrom for GLWETensorKey { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { - let len: usize = reader.read_u64::()? as usize; - if self.keys.len() != len { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - format!("self.keys.len()={} != read len={}", self.keys.len(), len), - )); - } - for key in &mut self.keys { - key.read_from(reader)?; - } + self.0.read_from(reader)?; Ok(()) } } impl WriterTo for GLWETensorKey { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { - writer.write_u64::(self.keys.len() as u64)?; - for key in &self.keys { - key.write_to(writer)?; - } + self.0.write_to(writer)?; Ok(()) } } -pub trait GLWETensorKeyToRef { - fn to_ref(&self) -> GLWETensorKey<&[u8]>; -} - -impl GLWETensorKeyToRef for GLWETensorKey +impl GGLWEToRef for GLWETensorKey where GGLWE: GGLWEToRef, { - fn to_ref(&self) -> GLWETensorKey<&[u8]> { - GLWETensorKey { - keys: self.keys.iter().map(|c| c.to_ref()).collect(), - } + fn to_ref(&self) -> GGLWE<&[u8]> { + self.0.to_ref() } } -pub trait GLWETensorKeyToMut { - fn to_mut(&mut self) -> GLWETensorKey<&mut [u8]>; -} - -impl GLWETensorKeyToMut for GLWETensorKey +impl GGLWEToMut for GLWETensorKey where GGLWE: GGLWEToMut, { - fn to_mut(&mut self) -> GLWETensorKey<&mut [u8]> { - GLWETensorKey { - keys: self.keys.iter_mut().map(|c| c.to_mut()).collect(), - } + fn to_mut(&mut self) -> GGLWE<&mut [u8]> { + self.0.to_mut() } } diff --git a/poulpy-core/src/layouts/glwe_to_lwe_switching_key.rs b/poulpy-core/src/layouts/glwe_to_lwe_key.rs similarity index 79% rename from poulpy-core/src/layouts/glwe_to_lwe_switching_key.rs rename to poulpy-core/src/layouts/glwe_to_lwe_key.rs index bc3ee4b..2541d32 100644 --- a/poulpy-core/src/layouts/glwe_to_lwe_switching_key.rs +++ b/poulpy-core/src/layouts/glwe_to_lwe_key.rs @@ -59,9 +59,9 @@ impl GGLWEInfos for GLWEToLWEKeyLayout { /// A special [GLWESwitchingKey] required to for the conversion from [GLWE] to [LWE]. #[derive(PartialEq, Eq, Clone)] -pub struct GLWEToLWESwitchingKey(pub(crate) GLWESwitchingKey); +pub struct GLWEToLWEKey(pub(crate) GLWESwitchingKey); -impl LWEInfos for GLWEToLWESwitchingKey { +impl LWEInfos for GLWEToLWEKey { fn base2k(&self) -> Base2K { self.0.base2k() } @@ -79,12 +79,12 @@ impl LWEInfos for GLWEToLWESwitchingKey { } } -impl GLWEInfos for GLWEToLWESwitchingKey { +impl GLWEInfos for GLWEToLWEKey { fn rank(&self) -> Rank { self.rank_out() } } -impl GGLWEInfos for GLWEToLWESwitchingKey { +impl GGLWEInfos for GLWEToLWEKey { fn rank_in(&self) -> Rank { self.0.rank_in() } @@ -102,37 +102,37 @@ impl GGLWEInfos for GLWEToLWESwitchingKey { } } -impl fmt::Debug for GLWEToLWESwitchingKey { +impl fmt::Debug for GLWEToLWEKey { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{self}") } } -impl FillUniform for GLWEToLWESwitchingKey { +impl FillUniform for GLWEToLWEKey { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.0.fill_uniform(log_bound, source); } } -impl fmt::Display for GLWEToLWESwitchingKey { +impl fmt::Display for GLWEToLWEKey { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "(GLWEToLWESwitchingKey) {}", self.0) + write!(f, "(GLWEToLWEKey) {}", self.0) } } -impl ReaderFrom for GLWEToLWESwitchingKey { +impl ReaderFrom for GLWEToLWEKey { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.0.read_from(reader) } } -impl WriterTo for GLWEToLWESwitchingKey { +impl WriterTo for GLWEToLWEKey { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { self.0.write_to(writer) } } -impl GLWEToLWESwitchingKey> { +impl GLWEToLWEKey> { pub fn alloc_from_infos(infos: &A) -> Self where A: GGLWEInfos, @@ -140,12 +140,12 @@ impl GLWEToLWESwitchingKey> { assert_eq!( infos.rank_out().0, 1, - "rank_out > 1 is not supported for GLWEToLWESwitchingKey" + "rank_out > 1 is not supported for GLWEToLWEKey" ); assert_eq!( infos.dsize().0, 1, - "dsize > 1 is not supported for GLWEToLWESwitchingKey" + "dsize > 1 is not supported for GLWEToLWEKey" ); Self::alloc( infos.n(), @@ -157,7 +157,7 @@ impl GLWEToLWESwitchingKey> { } pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> Self { - GLWEToLWESwitchingKey(GLWESwitchingKey::alloc( + GLWEToLWEKey(GLWESwitchingKey::alloc( n, base2k, k, @@ -196,19 +196,19 @@ impl GLWEToLWESwitchingKey> { } } -impl GGLWEToRef for GLWEToLWESwitchingKey { +impl GGLWEToRef for GLWEToLWEKey { fn to_ref(&self) -> GGLWE<&[u8]> { self.0.to_ref() } } -impl GGLWEToMut for GLWEToLWESwitchingKey { +impl GGLWEToMut for GLWEToLWEKey { fn to_mut(&mut self) -> GGLWE<&mut [u8]> { self.0.to_mut() } } -impl GLWESwitchingKeyDegreesMut for GLWEToLWESwitchingKey { +impl GLWESwitchingKeyDegreesMut for GLWEToLWEKey { fn input_degree(&mut self) -> &mut Degree { &mut self.0.input_degree } @@ -218,7 +218,7 @@ impl GLWESwitchingKeyDegreesMut for GLWEToLWESwitchingKey { } } -impl GLWESwitchingKeyDegrees for GLWEToLWESwitchingKey { +impl GLWESwitchingKeyDegrees for GLWEToLWEKey { fn input_degree(&self) -> &Degree { &self.0.input_degree } diff --git a/poulpy-core/src/layouts/lwe_to_glwe_switching_key.rs b/poulpy-core/src/layouts/lwe_to_glwe_key.rs similarity index 73% rename from poulpy-core/src/layouts/lwe_to_glwe_switching_key.rs rename to poulpy-core/src/layouts/lwe_to_glwe_key.rs index caa676d..5a44f61 100644 --- a/poulpy-core/src/layouts/lwe_to_glwe_switching_key.rs +++ b/poulpy-core/src/layouts/lwe_to_glwe_key.rs @@ -11,7 +11,7 @@ use crate::layouts::{ }; #[derive(PartialEq, Eq, Copy, Clone, Debug)] -pub struct LWEToGLWESwitchingKeyLayout { +pub struct LWEToGLWEKeyLayout { pub n: Degree, pub base2k: Base2K, pub k: TorusPrecision, @@ -19,7 +19,7 @@ pub struct LWEToGLWESwitchingKeyLayout { pub dnum: Dnum, } -impl LWEInfos for LWEToGLWESwitchingKeyLayout { +impl LWEInfos for LWEToGLWEKeyLayout { fn base2k(&self) -> Base2K { self.base2k } @@ -33,13 +33,13 @@ impl LWEInfos for LWEToGLWESwitchingKeyLayout { } } -impl GLWEInfos for LWEToGLWESwitchingKeyLayout { +impl GLWEInfos for LWEToGLWEKeyLayout { fn rank(&self) -> Rank { self.rank_out() } } -impl GGLWEInfos for LWEToGLWESwitchingKeyLayout { +impl GGLWEInfos for LWEToGLWEKeyLayout { fn rank_in(&self) -> Rank { Rank(1) } @@ -58,9 +58,9 @@ impl GGLWEInfos for LWEToGLWESwitchingKeyLayout { } #[derive(PartialEq, Eq, Clone)] -pub struct LWEToGLWESwitchingKey(pub(crate) GLWESwitchingKey); +pub struct LWEToGLWEKey(pub(crate) GLWESwitchingKey); -impl LWEInfos for LWEToGLWESwitchingKey { +impl LWEInfos for LWEToGLWEKey { fn base2k(&self) -> Base2K { self.0.base2k() } @@ -78,12 +78,12 @@ impl LWEInfos for LWEToGLWESwitchingKey { } } -impl GLWEInfos for LWEToGLWESwitchingKey { +impl GLWEInfos for LWEToGLWEKey { fn rank(&self) -> Rank { self.rank_out() } } -impl GGLWEInfos for LWEToGLWESwitchingKey { +impl GGLWEInfos for LWEToGLWEKey { fn dsize(&self) -> Dsize { self.0.dsize() } @@ -101,37 +101,37 @@ impl GGLWEInfos for LWEToGLWESwitchingKey { } } -impl fmt::Debug for LWEToGLWESwitchingKey { +impl fmt::Debug for LWEToGLWEKey { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{self}") } } -impl FillUniform for LWEToGLWESwitchingKey { +impl FillUniform for LWEToGLWEKey { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.0.fill_uniform(log_bound, source); } } -impl fmt::Display for LWEToGLWESwitchingKey { +impl fmt::Display for LWEToGLWEKey { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "(LWEToGLWESwitchingKey) {}", self.0) + write!(f, "(LWEToGLWEKey) {}", self.0) } } -impl ReaderFrom for LWEToGLWESwitchingKey { +impl ReaderFrom for LWEToGLWEKey { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.0.read_from(reader) } } -impl WriterTo for LWEToGLWESwitchingKey { +impl WriterTo for LWEToGLWEKey { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { self.0.write_to(writer) } } -impl LWEToGLWESwitchingKey> { +impl LWEToGLWEKey> { pub fn alloc_from_infos(infos: &A) -> Self where A: GGLWEInfos, @@ -139,12 +139,12 @@ impl LWEToGLWESwitchingKey> { assert_eq!( infos.rank_in().0, 1, - "rank_in > 1 is not supported for LWEToGLWESwitchingKey" + "rank_in > 1 is not supported for LWEToGLWEKey" ); assert_eq!( infos.dsize().0, 1, - "dsize > 1 is not supported for LWEToGLWESwitchingKey" + "dsize > 1 is not supported for LWEToGLWEKey" ); Self::alloc( @@ -157,7 +157,7 @@ impl LWEToGLWESwitchingKey> { } pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self { - LWEToGLWESwitchingKey(GLWESwitchingKey::alloc( + LWEToGLWEKey(GLWESwitchingKey::alloc( n, base2k, k, @@ -175,12 +175,12 @@ impl LWEToGLWESwitchingKey> { assert_eq!( infos.rank_in().0, 1, - "rank_in > 1 is not supported for LWEToGLWESwitchingKey" + "rank_in > 1 is not supported for LWEToGLWEKey" ); assert_eq!( infos.dsize().0, 1, - "dsize > 1 is not supported for LWEToGLWESwitchingKey" + "dsize > 1 is not supported for LWEToGLWEKey" ); Self::bytes_of( infos.n(), @@ -196,19 +196,19 @@ impl LWEToGLWESwitchingKey> { } } -impl GGLWEToRef for LWEToGLWESwitchingKey { +impl GGLWEToRef for LWEToGLWEKey { fn to_ref(&self) -> GGLWE<&[u8]> { self.0.to_ref() } } -impl GGLWEToMut for LWEToGLWESwitchingKey { +impl GGLWEToMut for LWEToGLWEKey { fn to_mut(&mut self) -> GGLWE<&mut [u8]> { self.0.to_mut() } } -impl GLWESwitchingKeyDegreesMut for LWEToGLWESwitchingKey { +impl GLWESwitchingKeyDegreesMut for LWEToGLWEKey { fn input_degree(&mut self) -> &mut Degree { &mut self.0.input_degree } @@ -218,7 +218,7 @@ impl GLWESwitchingKeyDegreesMut for LWEToGLWESwitchingKey { } } -impl GLWESwitchingKeyDegrees for LWEToGLWESwitchingKey { +impl GLWESwitchingKeyDegrees for LWEToGLWEKey { fn input_degree(&self) -> &Degree { &self.0.input_degree } diff --git a/poulpy-core/src/layouts/mod.rs b/poulpy-core/src/layouts/mod.rs index 7c5cc5b..2dbc700 100644 --- a/poulpy-core/src/layouts/mod.rs +++ b/poulpy-core/src/layouts/mod.rs @@ -1,40 +1,44 @@ mod gglwe; +mod gglwe_to_ggsw_key; mod ggsw; mod glwe; mod glwe_automorphism_key; mod glwe_plaintext; mod glwe_public_key; mod glwe_secret; +mod glwe_secret_tensor; mod glwe_switching_key; mod glwe_tensor; mod glwe_tensor_key; -mod glwe_to_lwe_switching_key; +mod glwe_to_lwe_key; mod lwe; mod lwe_plaintext; mod lwe_secret; mod lwe_switching_key; -mod lwe_to_glwe_switching_key; +mod lwe_to_glwe_key; pub mod compressed; pub mod prepared; pub use compressed::*; pub use gglwe::*; +pub use gglwe_to_ggsw_key::*; pub use ggsw::*; pub use glwe::*; pub use glwe_automorphism_key::*; pub use glwe_plaintext::*; pub use glwe_public_key::*; pub use glwe_secret::*; +pub use glwe_secret_tensor::*; pub use glwe_switching_key::*; pub use glwe_tensor::*; pub use glwe_tensor_key::*; -pub use glwe_to_lwe_switching_key::*; +pub use glwe_to_lwe_key::*; pub use lwe::*; pub use lwe_plaintext::*; pub use lwe_secret::*; pub use lwe_switching_key::*; -pub use lwe_to_glwe_switching_key::*; +pub use lwe_to_glwe_key::*; pub use prepared::*; use poulpy_hal::layouts::{Backend, Module}; diff --git a/poulpy-core/src/layouts/prepared/gglwe_to_ggsw_key.rs b/poulpy-core/src/layouts/prepared/gglwe_to_ggsw_key.rs new file mode 100644 index 0000000..d63dca6 --- /dev/null +++ b/poulpy-core/src/layouts/prepared/gglwe_to_ggsw_key.rs @@ -0,0 +1,252 @@ +use poulpy_hal::layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}; + +use crate::layouts::{ + Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWEPrepared, GGLWEPreparedFactory, GGLWEPreparedToMut, GGLWEPreparedToRef, + GGLWEToGGSWKey, GGLWEToGGSWKeyToRef, GLWEInfos, LWEInfos, Rank, TorusPrecision, +}; + +pub struct GGLWEToGGSWKeyPrepared { + pub(crate) keys: Vec>, +} + +impl LWEInfos for GGLWEToGGSWKeyPrepared { + fn n(&self) -> Degree { + self.keys[0].n() + } + + fn base2k(&self) -> Base2K { + self.keys[0].base2k() + } + + fn k(&self) -> TorusPrecision { + self.keys[0].k() + } + + fn size(&self) -> usize { + self.keys[0].size() + } +} + +impl GLWEInfos for GGLWEToGGSWKeyPrepared { + fn rank(&self) -> Rank { + self.keys[0].rank_out() + } +} + +impl GGLWEInfos for GGLWEToGGSWKeyPrepared { + fn rank_in(&self) -> Rank { + self.rank_out() + } + + fn rank_out(&self) -> Rank { + self.keys[0].rank_out() + } + + fn dsize(&self) -> Dsize { + self.keys[0].dsize() + } + + fn dnum(&self) -> Dnum { + self.keys[0].dnum() + } +} + +pub trait GGLWEToGGSWKeyPreparedFactory { + fn alloc_gglwe_to_ggsw_key_prepared_from_infos(&self, infos: &A) -> GGLWEToGGSWKeyPrepared, BE> + where + A: GGLWEInfos; + + fn alloc_gglwe_to_ggsw_key_prepared( + &self, + base2k: Base2K, + k: TorusPrecision, + rank: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> GGLWEToGGSWKeyPrepared, BE>; + + fn bytes_of_gglwe_to_ggsw_from_infos(&self, infos: &A) -> usize + where + A: GGLWEInfos; + + fn bytes_of_gglwe_to_ggsw(&self, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize; + + fn prepare_gglwe_to_ggsw_key_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos; + + fn prepare_gglwe_to_ggsw_key(&self, res: &mut R, other: &O, scratch: &mut Scratch) + where + R: GGLWEToGGSWKeyPreparedToMut, + O: GGLWEToGGSWKeyToRef; +} + +impl GGLWEToGGSWKeyPreparedFactory for Module +where + Self: GGLWEPreparedFactory, +{ + fn alloc_gglwe_to_ggsw_key_prepared_from_infos(&self, infos: &A) -> GGLWEToGGSWKeyPrepared, BE> + where + A: GGLWEInfos, + { + assert_eq!( + infos.rank_in(), + infos.rank_out(), + "rank_in != rank_out is not supported for GGLWEToGGSWKeyPrepared" + ); + self.alloc_gglwe_to_ggsw_key_prepared( + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) + } + + fn alloc_gglwe_to_ggsw_key_prepared( + &self, + base2k: Base2K, + k: TorusPrecision, + rank: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> GGLWEToGGSWKeyPrepared, BE> { + GGLWEToGGSWKeyPrepared { + keys: (0..rank.as_usize()) + .map(|_| self.alloc_gglwe_prepared(base2k, k, rank, rank, dnum, dsize)) + .collect(), + } + } + + fn bytes_of_gglwe_to_ggsw_from_infos(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + assert_eq!( + infos.rank_in(), + infos.rank_out(), + "rank_in != rank_out is not supported for GGLWEToGGSWKeyPrepared" + ); + self.bytes_of_gglwe_to_ggsw( + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) + } + + fn bytes_of_gglwe_to_ggsw(&self, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { + rank.as_usize() * self.bytes_of_gglwe_prepared(base2k, k, rank, rank, dnum, dsize) + } + + fn prepare_gglwe_to_ggsw_key_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + self.prepare_gglwe_tmp_bytes(infos) + } + + fn prepare_gglwe_to_ggsw_key(&self, res: &mut R, other: &O, scratch: &mut Scratch) + where + R: GGLWEToGGSWKeyPreparedToMut, + O: GGLWEToGGSWKeyToRef, + { + let res: &mut GGLWEToGGSWKeyPrepared<&mut [u8], BE> = &mut res.to_mut(); + let other: &GGLWEToGGSWKey<&[u8]> = &other.to_ref(); + + assert_eq!(res.keys.len(), other.keys.len()); + + for (a, b) in res.keys.iter_mut().zip(other.keys.iter()) { + self.prepare_gglwe(a, b, scratch); + } + } +} + +impl GGLWEToGGSWKeyPrepared, BE> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GGLWEInfos, + M: GGLWEToGGSWKeyPreparedFactory, + { + module.alloc_gglwe_to_ggsw_key_prepared_from_infos(infos) + } + + pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self + where + M: GGLWEToGGSWKeyPreparedFactory, + { + module.alloc_gglwe_to_ggsw_key_prepared(base2k, k, rank, dnum, dsize) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: GGLWEToGGSWKeyPreparedFactory, + { + module.bytes_of_gglwe_to_ggsw_from_infos(infos) + } + + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize + where + M: GGLWEToGGSWKeyPreparedFactory, + { + module.bytes_of_gglwe_to_ggsw(base2k, k, rank, dnum, dsize) + } +} + +impl GGLWEToGGSWKeyPrepared { + pub fn prepare(&mut self, module: &M, other: &O, scratch: &mut Scratch) + where + M: GGLWEToGGSWKeyPreparedFactory, + O: GGLWEToGGSWKeyToRef, + { + module.prepare_gglwe_to_ggsw_key(self, other, scratch); + } +} + +impl GGLWEToGGSWKeyPrepared { + // Returns a mutable reference to GGLWEPrepared_{s}([s[i]*s[0], s[i]*s[1], ..., s[i]*s[rank]]) + pub fn at_mut(&mut self, i: usize) -> &mut GGLWEPrepared { + assert!((i as u32) < self.rank()); + &mut self.keys[i] + } +} + +impl GGLWEToGGSWKeyPrepared { + // Returns a reference to GGLWEPrepared_{s}([s[i]*s[0], s[i]*s[1], ..., s[i]*s[rank]]) + pub fn at(&self, i: usize) -> &GGLWEPrepared { + assert!((i as u32) < self.rank()); + &self.keys[i] + } +} + +pub trait GGLWEToGGSWKeyPreparedToRef { + fn to_ref(&self) -> GGLWEToGGSWKeyPrepared<&[u8], BE>; +} + +impl GGLWEToGGSWKeyPreparedToRef for GGLWEToGGSWKeyPrepared +where + GGLWEPrepared: GGLWEPreparedToRef, +{ + fn to_ref(&self) -> GGLWEToGGSWKeyPrepared<&[u8], BE> { + GGLWEToGGSWKeyPrepared { + keys: self.keys.iter().map(|c| c.to_ref()).collect(), + } + } +} + +pub trait GGLWEToGGSWKeyPreparedToMut { + fn to_mut(&mut self) -> GGLWEToGGSWKeyPrepared<&mut [u8], BE>; +} + +impl GGLWEToGGSWKeyPreparedToMut for GGLWEToGGSWKeyPrepared +where + GGLWEPrepared: GGLWEPreparedToMut, +{ + fn to_mut(&mut self) -> GGLWEToGGSWKeyPrepared<&mut [u8], BE> { + GGLWEToGGSWKeyPrepared { + keys: self.keys.iter_mut().map(|c| c.to_mut()).collect(), + } + } +} diff --git a/poulpy-core/src/layouts/prepared/glwe_switching_key.rs b/poulpy-core/src/layouts/prepared/glwe_switching_key.rs index d73d17d..f73299b 100644 --- a/poulpy-core/src/layouts/prepared/glwe_switching_key.rs +++ b/poulpy-core/src/layouts/prepared/glwe_switching_key.rs @@ -109,7 +109,7 @@ where ) } - fn bytes_of_glwe_switching_key_prepared( + fn bytes_of_glwe_key_prepared( &self, base2k: Base2K, k: TorusPrecision, @@ -125,7 +125,7 @@ where where A: GGLWEInfos, { - self.bytes_of_glwe_switching_key_prepared( + self.bytes_of_glwe_key_prepared( infos.base2k(), infos.k(), infos.rank_in(), @@ -199,7 +199,7 @@ impl GLWESwitchingKeyPrepared, B> { where M: GLWESwitchingKeyPreparedFactory, { - module.bytes_of_glwe_switching_key_prepared(base2k, k, rank_in, rank_out, dnum, dsize) + module.bytes_of_glwe_key_prepared(base2k, k, rank_in, rank_out, dnum, dsize) } } diff --git a/poulpy-core/src/layouts/prepared/glwe_tensor_key.rs b/poulpy-core/src/layouts/prepared/glwe_tensor_key.rs index bd63c75..0304b37 100644 --- a/poulpy-core/src/layouts/prepared/glwe_tensor_key.rs +++ b/poulpy-core/src/layouts/prepared/glwe_tensor_key.rs @@ -2,29 +2,27 @@ use poulpy_hal::layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}; use crate::layouts::{ Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWEPrepared, GGLWEPreparedFactory, GGLWEPreparedToMut, GGLWEPreparedToRef, - GLWEInfos, GLWETensorKey, GLWETensorKeyToRef, LWEInfos, Rank, TorusPrecision, + GGLWEToRef, GLWEInfos, LWEInfos, Rank, TorusPrecision, }; #[derive(PartialEq, Eq)] -pub struct GLWETensorKeyPrepared { - pub(crate) keys: Vec>, -} +pub struct GLWETensorKeyPrepared(pub(crate) GGLWEPrepared); impl LWEInfos for GLWETensorKeyPrepared { fn n(&self) -> Degree { - self.keys[0].n() + self.0.n() } fn base2k(&self) -> Base2K { - self.keys[0].base2k() + self.0.base2k() } fn k(&self) -> TorusPrecision { - self.keys[0].k() + self.0.k() } fn size(&self) -> usize { - self.keys[0].size() + self.0.size() } } @@ -40,15 +38,15 @@ impl GGLWEInfos for GLWETensorKeyPrepared { } fn rank_out(&self) -> Rank { - self.keys[0].rank_out() + self.0.rank_out() } fn dsize(&self) -> Dsize { - self.keys[0].dsize() + self.0.dsize() } fn dnum(&self) -> Dnum { - self.keys[0].dnum() + self.0.dnum() } } @@ -65,11 +63,7 @@ where rank: Rank, ) -> GLWETensorKeyPrepared, B> { let pairs: u32 = (((rank.as_u32() + 1) * rank.as_u32()) >> 1).max(1); - GLWETensorKeyPrepared { - keys: (0..pairs) - .map(|_| self.alloc_gglwe_prepared(base2k, k, Rank(1), rank, dnum, dsize)) - .collect(), - } + GLWETensorKeyPrepared(self.alloc_gglwe_prepared(base2k, k, Rank(pairs), rank, dnum, dsize)) } fn alloc_tensor_key_prepared_from_infos(&self, infos: &A) -> GLWETensorKeyPrepared, B> @@ -91,8 +85,8 @@ where } fn bytes_of_tensor_key_prepared(&self, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { - let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize; - pairs * self.bytes_of_gglwe_prepared(base2k, k, Rank(1), rank, dnum, dsize) + let pairs: u32 = (((rank.as_u32() + 1) * rank.as_u32()) >> 1).max(1); + self.bytes_of_gglwe_prepared(base2k, k, Rank(pairs), rank, dnum, dsize) } fn bytes_of_tensor_key_prepared_from_infos(&self, infos: &A) -> usize @@ -117,17 +111,10 @@ where fn prepare_tensor_key(&self, res: &mut R, other: &O, scratch: &mut Scratch) where - R: GLWETensorKeyPreparedToMut, - O: GLWETensorKeyToRef, + R: GGLWEPreparedToMut, + O: GGLWEToRef, { - let mut res: GLWETensorKeyPrepared<&mut [u8], B> = res.to_mut(); - let other: GLWETensorKey<&[u8]> = other.to_ref(); - - assert_eq!(res.keys.len(), other.keys.len()); - - for (a, b) in res.keys.iter_mut().zip(other.keys.iter()) { - self.prepare_gglwe(a, b, scratch); - } + self.prepare_gglwe(res, other, scratch); } } @@ -165,28 +152,6 @@ impl GLWETensorKeyPrepared, B> { } } -impl GLWETensorKeyPrepared { - // Returns a mutable reference to GGLWE_{s}(s[i] * s[j]) - pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GGLWEPrepared { - if i > j { - std::mem::swap(&mut i, &mut j); - }; - let rank: usize = self.rank_out().into(); - &mut self.keys[i * rank + j - (i * (i + 1) / 2)] - } -} - -impl GLWETensorKeyPrepared { - // Returns a reference to GGLWE_{s}(s[i] * s[j]) - pub fn at(&self, mut i: usize, mut j: usize) -> &GGLWEPrepared { - if i > j { - std::mem::swap(&mut i, &mut j); - }; - let rank: usize = self.rank_out().into(); - &self.keys[i * rank + j - (i * (i + 1) / 2)] - } -} - impl GLWETensorKeyPrepared, B> { pub fn prepare_tmp_bytes(&self, module: &M, infos: &A) -> usize where @@ -200,39 +165,27 @@ impl GLWETensorKeyPrepared, B> { impl GLWETensorKeyPrepared { pub fn prepare(&mut self, module: &M, other: &O, scratch: &mut Scratch) where - O: GLWETensorKeyToRef, + O: GGLWEToRef, M: GLWETensorKeyPreparedFactory, { module.prepare_tensor_key(self, other, scratch); } } -pub trait GLWETensorKeyPreparedToMut { - fn to_mut(&mut self) -> GLWETensorKeyPrepared<&mut [u8], B>; -} - -impl GLWETensorKeyPreparedToMut for GLWETensorKeyPrepared +impl GGLWEPreparedToMut for GLWETensorKeyPrepared where GGLWEPrepared: GGLWEPreparedToMut, { - fn to_mut(&mut self) -> GLWETensorKeyPrepared<&mut [u8], B> { - GLWETensorKeyPrepared { - keys: self.keys.iter_mut().map(|c| c.to_mut()).collect(), - } + fn to_mut(&mut self) -> GGLWEPrepared<&mut [u8], B> { + self.0.to_mut() } } -pub trait GLWETensorKeyPreparedToRef { - fn to_ref(&self) -> GLWETensorKeyPrepared<&[u8], B>; -} - -impl GLWETensorKeyPreparedToRef for GLWETensorKeyPrepared +impl GGLWEPreparedToRef for GLWETensorKeyPrepared where GGLWEPrepared: GGLWEPreparedToRef, { - fn to_ref(&self) -> GLWETensorKeyPrepared<&[u8], B> { - GLWETensorKeyPrepared { - keys: self.keys.iter().map(|c| c.to_ref()).collect(), - } + fn to_ref(&self) -> GGLWEPrepared<&[u8], B> { + self.0.to_ref() } } diff --git a/poulpy-core/src/layouts/prepared/glwe_to_lwe_switching_key.rs b/poulpy-core/src/layouts/prepared/glwe_to_lwe_key.rs similarity index 54% rename from poulpy-core/src/layouts/prepared/glwe_to_lwe_switching_key.rs rename to poulpy-core/src/layouts/prepared/glwe_to_lwe_key.rs index 6edac5e..675a73f 100644 --- a/poulpy-core/src/layouts/prepared/glwe_to_lwe_switching_key.rs +++ b/poulpy-core/src/layouts/prepared/glwe_to_lwe_key.rs @@ -7,9 +7,9 @@ use crate::layouts::{ }; #[derive(PartialEq, Eq)] -pub struct GLWEToLWESwitchingKeyPrepared(pub(crate) GLWESwitchingKeyPrepared); +pub struct GLWEToLWEKeyPrepared(pub(crate) GLWESwitchingKeyPrepared); -impl LWEInfos for GLWEToLWESwitchingKeyPrepared { +impl LWEInfos for GLWEToLWEKeyPrepared { fn base2k(&self) -> Base2K { self.0.base2k() } @@ -27,13 +27,13 @@ impl LWEInfos for GLWEToLWESwitchingKeyPrepared { } } -impl GLWEInfos for GLWEToLWESwitchingKeyPrepared { +impl GLWEInfos for GLWEToLWEKeyPrepared { fn rank(&self) -> Rank { self.rank_out() } } -impl GGLWEInfos for GLWEToLWESwitchingKeyPrepared { +impl GGLWEInfos for GLWEToLWEKeyPrepared { fn rank_in(&self) -> Rank { self.0.rank_in() } @@ -51,65 +51,65 @@ impl GGLWEInfos for GLWEToLWESwitchingKeyPrepared { } } -pub trait GLWEToLWESwitchingKeyPreparedFactory +pub trait GLWEToLWEKeyPreparedFactory where Self: GLWESwitchingKeyPreparedFactory, { - fn alloc_glwe_to_lwe_switching_key_prepared( + fn alloc_glwe_to_lwe_key_prepared( &self, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum, - ) -> GLWEToLWESwitchingKeyPrepared, B> { - GLWEToLWESwitchingKeyPrepared(self.alloc_glwe_switching_key_prepared(base2k, k, rank_in, Rank(1), dnum, Dsize(1))) + ) -> GLWEToLWEKeyPrepared, B> { + GLWEToLWEKeyPrepared(self.alloc_glwe_switching_key_prepared(base2k, k, rank_in, Rank(1), dnum, Dsize(1))) } - fn alloc_glwe_to_lwe_switching_key_prepared_from_infos(&self, infos: &A) -> GLWEToLWESwitchingKeyPrepared, B> + fn alloc_glwe_to_lwe_key_prepared_from_infos(&self, infos: &A) -> GLWEToLWEKeyPrepared, B> where A: GGLWEInfos, { debug_assert_eq!( infos.rank_out().0, 1, - "rank_out > 1 is not supported for GLWEToLWESwitchingKeyPrepared" + "rank_out > 1 is not supported for GLWEToLWEKeyPrepared" ); debug_assert_eq!( infos.dsize().0, 1, - "dsize > 1 is not supported for GLWEToLWESwitchingKeyPrepared" + "dsize > 1 is not supported for GLWEToLWEKeyPrepared" ); - self.alloc_glwe_to_lwe_switching_key_prepared(infos.base2k(), infos.k(), infos.rank_in(), infos.dnum()) + self.alloc_glwe_to_lwe_key_prepared(infos.base2k(), infos.k(), infos.rank_in(), infos.dnum()) } - fn bytes_of_glwe_to_lwe_switching_key_prepared(&self, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> usize { - self.bytes_of_glwe_switching_key_prepared(base2k, k, rank_in, Rank(1), dnum, Dsize(1)) + fn bytes_of_glwe_to_lwe_key_prepared(&self, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> usize { + self.bytes_of_glwe_key_prepared(base2k, k, rank_in, Rank(1), dnum, Dsize(1)) } - fn bytes_of_glwe_to_lwe_switching_key_prepared_from_infos(&self, infos: &A) -> usize + fn bytes_of_glwe_to_lwe_key_prepared_from_infos(&self, infos: &A) -> usize where A: GGLWEInfos, { debug_assert_eq!( infos.rank_out().0, 1, - "rank_out > 1 is not supported for GLWEToLWESwitchingKeyPrepared" + "rank_out > 1 is not supported for GLWEToLWEKeyPrepared" ); debug_assert_eq!( infos.dsize().0, 1, - "dsize > 1 is not supported for GLWEToLWESwitchingKeyPrepared" + "dsize > 1 is not supported for GLWEToLWEKeyPrepared" ); - self.bytes_of_glwe_to_lwe_switching_key_prepared(infos.base2k(), infos.k(), infos.rank_in(), infos.dnum()) + self.bytes_of_glwe_to_lwe_key_prepared(infos.base2k(), infos.k(), infos.rank_in(), infos.dnum()) } - fn prepare_glwe_to_lwe_switching_key_tmp_bytes(&self, infos: &A) -> usize + fn prepare_glwe_to_lwe_key_tmp_bytes(&self, infos: &A) -> usize where A: GGLWEInfos, { self.prepare_glwe_switching_key_tmp_bytes(infos) } - fn prepare_glwe_to_lwe_switching_key(&self, res: &mut R, other: &O, scratch: &mut Scratch) + fn prepare_glwe_to_lwe_key(&self, res: &mut R, other: &O, scratch: &mut Scratch) where R: GGLWEPreparedToMut + GLWESwitchingKeyDegreesMut, O: GGLWEToRef + GLWESwitchingKeyDegrees, @@ -118,61 +118,61 @@ where } } -impl GLWEToLWESwitchingKeyPreparedFactory for Module where Self: GLWESwitchingKeyPreparedFactory {} +impl GLWEToLWEKeyPreparedFactory for Module where Self: GLWESwitchingKeyPreparedFactory {} -impl GLWEToLWESwitchingKeyPrepared, B> { +impl GLWEToLWEKeyPrepared, B> { pub fn alloc_from_infos(module: &M, infos: &A) -> Self where A: GGLWEInfos, - M: GLWEToLWESwitchingKeyPreparedFactory, + M: GLWEToLWEKeyPreparedFactory, { - module.alloc_glwe_to_lwe_switching_key_prepared_from_infos(infos) + module.alloc_glwe_to_lwe_key_prepared_from_infos(infos) } pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> Self where - M: GLWEToLWESwitchingKeyPreparedFactory, + M: GLWEToLWEKeyPreparedFactory, { - module.alloc_glwe_to_lwe_switching_key_prepared(base2k, k, rank_in, dnum) + module.alloc_glwe_to_lwe_key_prepared(base2k, k, rank_in, dnum) } pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize where A: GGLWEInfos, - M: GLWEToLWESwitchingKeyPreparedFactory, + M: GLWEToLWEKeyPreparedFactory, { - module.bytes_of_glwe_to_lwe_switching_key_prepared_from_infos(infos) + module.bytes_of_glwe_to_lwe_key_prepared_from_infos(infos) } pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> usize where - M: GLWEToLWESwitchingKeyPreparedFactory, + M: GLWEToLWEKeyPreparedFactory, { - module.bytes_of_glwe_to_lwe_switching_key_prepared(base2k, k, rank_in, dnum) + module.bytes_of_glwe_to_lwe_key_prepared(base2k, k, rank_in, dnum) } } -impl GLWEToLWESwitchingKeyPrepared, B> { +impl GLWEToLWEKeyPrepared, B> { pub fn prepare_tmp_bytes(&self, module: &M, infos: &A) where A: GGLWEInfos, - M: GLWEToLWESwitchingKeyPreparedFactory, + M: GLWEToLWEKeyPreparedFactory, { - module.prepare_glwe_to_lwe_switching_key_tmp_bytes(infos); + module.prepare_glwe_to_lwe_key_tmp_bytes(infos); } } -impl GLWEToLWESwitchingKeyPrepared { +impl GLWEToLWEKeyPrepared { pub fn prepare(&mut self, module: &M, other: &O, scratch: &mut Scratch) where O: GGLWEToRef + GLWESwitchingKeyDegrees, - M: GLWEToLWESwitchingKeyPreparedFactory, + M: GLWEToLWEKeyPreparedFactory, { - module.prepare_glwe_to_lwe_switching_key(self, other, scratch); + module.prepare_glwe_to_lwe_key(self, other, scratch); } } -impl GGLWEPreparedToRef for GLWEToLWESwitchingKeyPrepared +impl GGLWEPreparedToRef for GLWEToLWEKeyPrepared where GLWESwitchingKeyPrepared: GGLWEPreparedToRef, { @@ -181,7 +181,7 @@ where } } -impl GGLWEPreparedToMut for GLWEToLWESwitchingKeyPrepared +impl GGLWEPreparedToMut for GLWEToLWEKeyPrepared where GLWESwitchingKeyPrepared: GGLWEPreparedToRef, { @@ -190,7 +190,7 @@ where } } -impl GLWESwitchingKeyDegreesMut for GLWEToLWESwitchingKeyPrepared { +impl GLWESwitchingKeyDegreesMut for GLWEToLWEKeyPrepared { fn input_degree(&mut self) -> &mut Degree { &mut self.0.input_degree } @@ -200,7 +200,7 @@ impl GLWESwitchingKeyDegreesMut for GLWEToLWESwitchingKe } } -impl GLWESwitchingKeyDegrees for GLWEToLWESwitchingKeyPrepared { +impl GLWESwitchingKeyDegrees for GLWEToLWEKeyPrepared { fn input_degree(&self) -> &Degree { &self.0.input_degree } diff --git a/poulpy-core/src/layouts/prepared/lwe_switching_key.rs b/poulpy-core/src/layouts/prepared/lwe_switching_key.rs index 327d001..16f77eb 100644 --- a/poulpy-core/src/layouts/prepared/lwe_switching_key.rs +++ b/poulpy-core/src/layouts/prepared/lwe_switching_key.rs @@ -86,7 +86,7 @@ where } fn bytes_of_lwe_switching_key_prepared(&self, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> usize { - self.bytes_of_glwe_switching_key_prepared(base2k, k, Rank(1), Rank(1), dnum, Dsize(1)) + self.bytes_of_glwe_key_prepared(base2k, k, Rank(1), Rank(1), dnum, Dsize(1)) } fn bytes_of_lwe_switching_key_prepared_from_infos(&self, infos: &A) -> usize diff --git a/poulpy-core/src/layouts/prepared/lwe_to_glwe_switching_key.rs b/poulpy-core/src/layouts/prepared/lwe_to_glwe_key.rs similarity index 53% rename from poulpy-core/src/layouts/prepared/lwe_to_glwe_switching_key.rs rename to poulpy-core/src/layouts/prepared/lwe_to_glwe_key.rs index 30ed131..25f08f8 100644 --- a/poulpy-core/src/layouts/prepared/lwe_to_glwe_switching_key.rs +++ b/poulpy-core/src/layouts/prepared/lwe_to_glwe_key.rs @@ -8,9 +8,9 @@ use crate::layouts::{ /// A special [GLWESwitchingKey] required to for the conversion from [LWE] to [GLWE]. #[derive(PartialEq, Eq)] -pub struct LWEToGLWESwitchingKeyPrepared(pub(crate) GLWESwitchingKeyPrepared); +pub struct LWEToGLWEKeyPrepared(pub(crate) GLWESwitchingKeyPrepared); -impl LWEInfos for LWEToGLWESwitchingKeyPrepared { +impl LWEInfos for LWEToGLWEKeyPrepared { fn base2k(&self) -> Base2K { self.0.base2k() } @@ -28,13 +28,13 @@ impl LWEInfos for LWEToGLWESwitchingKeyPrepared { } } -impl GLWEInfos for LWEToGLWESwitchingKeyPrepared { +impl GLWEInfos for LWEToGLWEKeyPrepared { fn rank(&self) -> Rank { self.rank_out() } } -impl GGLWEInfos for LWEToGLWESwitchingKeyPrepared { +impl GGLWEInfos for LWEToGLWEKeyPrepared { fn dsize(&self) -> Dsize { self.0.dsize() } @@ -52,71 +52,65 @@ impl GGLWEInfos for LWEToGLWESwitchingKeyPrepared { } } -pub trait LWEToGLWESwitchingKeyPreparedFactory +pub trait LWEToGLWEKeyPreparedFactory where Self: GLWESwitchingKeyPreparedFactory, { - fn alloc_lwe_to_glwe_switching_key_prepared( + fn alloc_lwe_to_glwe_key_prepared( &self, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum, - ) -> LWEToGLWESwitchingKeyPrepared, B> { - LWEToGLWESwitchingKeyPrepared(self.alloc_glwe_switching_key_prepared(base2k, k, Rank(1), rank_out, dnum, Dsize(1))) + ) -> LWEToGLWEKeyPrepared, B> { + LWEToGLWEKeyPrepared(self.alloc_glwe_switching_key_prepared(base2k, k, Rank(1), rank_out, dnum, Dsize(1))) } - fn alloc_lwe_to_glwe_switching_key_prepared_from_infos(&self, infos: &A) -> LWEToGLWESwitchingKeyPrepared, B> + fn alloc_lwe_to_glwe_key_prepared_from_infos(&self, infos: &A) -> LWEToGLWEKeyPrepared, B> where A: GGLWEInfos, { debug_assert_eq!( infos.rank_in().0, 1, - "rank_in > 1 is not supported for LWEToGLWESwitchingKey" + "rank_in > 1 is not supported for LWEToGLWEKey" ); debug_assert_eq!( infos.dsize().0, 1, - "dsize > 1 is not supported for LWEToGLWESwitchingKey" + "dsize > 1 is not supported for LWEToGLWEKey" ); - self.alloc_lwe_to_glwe_switching_key_prepared(infos.base2k(), infos.k(), infos.rank_out(), infos.dnum()) + self.alloc_lwe_to_glwe_key_prepared(infos.base2k(), infos.k(), infos.rank_out(), infos.dnum()) } - fn bytes_of_lwe_to_glwe_switching_key_prepared( - &self, - base2k: Base2K, - k: TorusPrecision, - rank_out: Rank, - dnum: Dnum, - ) -> usize { - self.bytes_of_glwe_switching_key_prepared(base2k, k, Rank(1), rank_out, dnum, Dsize(1)) + fn bytes_of_lwe_to_glwe_key_prepared(&self, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> usize { + self.bytes_of_glwe_key_prepared(base2k, k, Rank(1), rank_out, dnum, Dsize(1)) } - fn bytes_of_lwe_to_glwe_switching_key_prepared_from_infos(&self, infos: &A) -> usize + fn bytes_of_lwe_to_glwe_key_prepared_from_infos(&self, infos: &A) -> usize where A: GGLWEInfos, { debug_assert_eq!( infos.rank_in().0, 1, - "rank_in > 1 is not supported for LWEToGLWESwitchingKey" + "rank_in > 1 is not supported for LWEToGLWEKey" ); debug_assert_eq!( infos.dsize().0, 1, - "dsize > 1 is not supported for LWEToGLWESwitchingKey" + "dsize > 1 is not supported for LWEToGLWEKey" ); - self.bytes_of_lwe_to_glwe_switching_key_prepared(infos.base2k(), infos.k(), infos.rank_out(), infos.dnum()) + self.bytes_of_lwe_to_glwe_key_prepared(infos.base2k(), infos.k(), infos.rank_out(), infos.dnum()) } - fn prepare_lwe_to_glwe_switching_key_tmp_bytes(&self, infos: &A) + fn prepare_lwe_to_glwe_key_tmp_bytes(&self, infos: &A) where A: GGLWEInfos, { self.prepare_glwe_switching_key_tmp_bytes(infos); } - fn prepare_lwe_to_glwe_switching_key(&self, res: &mut R, other: &O, scratch: &mut Scratch) + fn prepare_lwe_to_glwe_key(&self, res: &mut R, other: &O, scratch: &mut Scratch) where R: GGLWEPreparedToMut + GLWESwitchingKeyDegreesMut, O: GGLWEToRef + GLWESwitchingKeyDegrees, @@ -125,61 +119,61 @@ where } } -impl LWEToGLWESwitchingKeyPreparedFactory for Module where Self: GLWESwitchingKeyPreparedFactory {} +impl LWEToGLWEKeyPreparedFactory for Module where Self: GLWESwitchingKeyPreparedFactory {} -impl LWEToGLWESwitchingKeyPrepared, B> { +impl LWEToGLWEKeyPrepared, B> { pub fn alloc_from_infos(module: &M, infos: &A) -> Self where A: GGLWEInfos, - M: LWEToGLWESwitchingKeyPreparedFactory, + M: LWEToGLWEKeyPreparedFactory, { - module.alloc_lwe_to_glwe_switching_key_prepared_from_infos(infos) + module.alloc_lwe_to_glwe_key_prepared_from_infos(infos) } pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self where - M: LWEToGLWESwitchingKeyPreparedFactory, + M: LWEToGLWEKeyPreparedFactory, { - module.alloc_lwe_to_glwe_switching_key_prepared(base2k, k, rank_out, dnum) + module.alloc_lwe_to_glwe_key_prepared(base2k, k, rank_out, dnum) } pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize where A: GGLWEInfos, - M: LWEToGLWESwitchingKeyPreparedFactory, + M: LWEToGLWEKeyPreparedFactory, { - module.bytes_of_lwe_to_glwe_switching_key_prepared_from_infos(infos) + module.bytes_of_lwe_to_glwe_key_prepared_from_infos(infos) } pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> usize where - M: LWEToGLWESwitchingKeyPreparedFactory, + M: LWEToGLWEKeyPreparedFactory, { - module.bytes_of_lwe_to_glwe_switching_key_prepared(base2k, k, rank_out, dnum) + module.bytes_of_lwe_to_glwe_key_prepared(base2k, k, rank_out, dnum) } } -impl LWEToGLWESwitchingKeyPrepared, B> { +impl LWEToGLWEKeyPrepared, B> { pub fn prepare_tmp_bytes(&self, module: &M, infos: &A) where A: GGLWEInfos, - M: LWEToGLWESwitchingKeyPreparedFactory, + M: LWEToGLWEKeyPreparedFactory, { - module.prepare_lwe_to_glwe_switching_key_tmp_bytes(infos); + module.prepare_lwe_to_glwe_key_tmp_bytes(infos); } } -impl LWEToGLWESwitchingKeyPrepared { +impl LWEToGLWEKeyPrepared { pub fn prepare(&mut self, module: &M, other: &O, scratch: &mut Scratch) where O: GGLWEToRef + GLWESwitchingKeyDegrees, - M: LWEToGLWESwitchingKeyPreparedFactory, + M: LWEToGLWEKeyPreparedFactory, { - module.prepare_lwe_to_glwe_switching_key(self, other, scratch); + module.prepare_lwe_to_glwe_key(self, other, scratch); } } -impl GGLWEPreparedToRef for LWEToGLWESwitchingKeyPrepared +impl GGLWEPreparedToRef for LWEToGLWEKeyPrepared where GLWESwitchingKeyPrepared: GGLWEPreparedToRef, { @@ -188,7 +182,7 @@ where } } -impl GGLWEPreparedToMut for LWEToGLWESwitchingKeyPrepared +impl GGLWEPreparedToMut for LWEToGLWEKeyPrepared where GLWESwitchingKeyPrepared: GGLWEPreparedToMut, { @@ -197,7 +191,7 @@ where } } -impl GLWESwitchingKeyDegreesMut for LWEToGLWESwitchingKeyPrepared { +impl GLWESwitchingKeyDegreesMut for LWEToGLWEKeyPrepared { fn input_degree(&mut self) -> &mut Degree { &mut self.0.input_degree } diff --git a/poulpy-core/src/layouts/prepared/mod.rs b/poulpy-core/src/layouts/prepared/mod.rs index 8944b97..4d76cfb 100644 --- a/poulpy-core/src/layouts/prepared/mod.rs +++ b/poulpy-core/src/layouts/prepared/mod.rs @@ -1,4 +1,5 @@ mod gglwe; +mod gglwe_to_ggsw_key; mod ggsw; mod glwe; mod glwe_automorphism_key; @@ -6,11 +7,12 @@ mod glwe_public_key; mod glwe_secret; mod glwe_switching_key; mod glwe_tensor_key; -mod glwe_to_lwe_switching_key; +mod glwe_to_lwe_key; mod lwe_switching_key; -mod lwe_to_glwe_switching_key; +mod lwe_to_glwe_key; pub use gglwe::*; +pub use gglwe_to_ggsw_key::*; pub use ggsw::*; pub use glwe::*; pub use glwe_automorphism_key::*; @@ -18,6 +20,6 @@ pub use glwe_public_key::*; pub use glwe_secret::*; pub use glwe_switching_key::*; pub use glwe_tensor_key::*; -pub use glwe_to_lwe_switching_key::*; +pub use glwe_to_lwe_key::*; pub use lwe_switching_key::*; -pub use lwe_to_glwe_switching_key::*; +pub use lwe_to_glwe_key::*; diff --git a/poulpy-core/src/lib.rs b/poulpy-core/src/lib.rs index ccad084..e9c6499 100644 --- a/poulpy-core/src/lib.rs +++ b/poulpy-core/src/lib.rs @@ -4,6 +4,7 @@ mod decryption; mod dist; mod encryption; mod external_product; +mod glwe_packer; mod glwe_packing; mod glwe_trace; mod keyswitching; @@ -20,6 +21,7 @@ pub use decryption::*; pub use dist::*; pub use encryption::*; pub use external_product::*; +pub use glwe_packer::*; pub use glwe_packing::*; pub use glwe_trace::*; pub use keyswitching::*; diff --git a/poulpy-core/src/noise/gglwe.rs b/poulpy-core/src/noise/gglwe.rs index dc32d57..c6dd278 100644 --- a/poulpy-core/src/noise/gglwe.rs +++ b/poulpy-core/src/noise/gglwe.rs @@ -62,7 +62,7 @@ where let noise_have: f64 = pt.data.std(base2k, 0).log2(); - // println!("noise_have: {noise_have}"); + println!("noise_have: {noise_have}"); assert!( noise_have <= max_noise, diff --git a/poulpy-core/src/operations/glwe.rs b/poulpy-core/src/operations/glwe.rs index 492b611..9802c14 100644 --- a/poulpy-core/src/operations/glwe.rs +++ b/poulpy-core/src/operations/glwe.rs @@ -67,6 +67,14 @@ where ); } } + + // fn glwe_relinearize(&self, res: &mut R, a: &A, tsk: &T, scratch: &mut Scratch) + // where + // R: GLWEToRef, + // A: GLWETensorToRef, + // T: GLWETensorKeyPreparedToRef, + // { + // } } pub trait GLWEAdd diff --git a/poulpy-core/src/scratch.rs b/poulpy-core/src/scratch.rs index 2220dc4..944fbd7 100644 --- a/poulpy-core/src/scratch.rs +++ b/poulpy-core/src/scratch.rs @@ -7,7 +7,7 @@ use crate::{ dist::Distribution, layouts::{ Degree, GGLWE, GGLWEInfos, GGLWELayout, GGSW, GGSWInfos, GLWE, GLWEAutomorphismKey, GLWEInfos, GLWEPlaintext, - GLWEPrepared, GLWEPublicKey, GLWESecret, GLWESwitchingKey, GLWETensorKey, Rank, + GLWEPrepared, GLWEPublicKey, GLWESecret, GLWESecretTensor, GLWESwitchingKey, GLWETensorKey, Rank, prepared::{ GGLWEPrepared, GGSWPrepared, GLWEAutomorphismKeyPrepared, GLWEPublicKeyPrepared, GLWESecretPrepared, GLWESwitchingKeyPrepared, GLWETensorKeyPrepared, @@ -232,6 +232,18 @@ where ) } + fn take_glwe_secret_tensor(&mut self, n: Degree, rank: Rank) -> (GLWESecretTensor<&mut [u8]>, &mut Self) { + let (data, scratch) = self.take_scalar_znx(n.into(), GLWESecretTensor::pairs(rank.into())); + ( + GLWESecretTensor { + data, + rank, + dist: Distribution::NONE, + }, + scratch, + ) + } + fn take_glwe_secret_prepared(&mut self, module: &M, rank: Rank) -> (GLWESecretPrepared<&mut [u8], B>, &mut Self) where M: ModuleN + SvpPPolBytesOf, @@ -313,25 +325,12 @@ where infos.rank_out(), "rank_in != rank_out is not supported for GLWETensorKey" ); - let mut keys: Vec> = Vec::new(); - let pairs: usize = (((infos.rank_out().0 + 1) * infos.rank_out().0) >> 1).max(1) as usize; - - let mut scratch: &mut Self = self; + let pairs: u32 = (((infos.rank_out().0 + 1) * infos.rank_out().0) >> 1).max(1); let mut ksk_infos: GGLWELayout = infos.gglwe_layout(); - ksk_infos.rank_in = Rank(1); - - if pairs != 0 { - let (gglwe, s) = scratch.take_gglwe(&ksk_infos); - scratch = s; - keys.push(gglwe); - } - for _ in 1..pairs { - let (gglwe, s) = scratch.take_gglwe(&ksk_infos); - scratch = s; - keys.push(gglwe); - } - (GLWETensorKey { keys }, scratch) + ksk_infos.rank_in = Rank(pairs); + let (data, scratch) = self.take_gglwe(infos); + (GLWETensorKey(data), scratch) } fn take_glwe_tensor_key_prepared(&mut self, module: &M, infos: &A) -> (GLWETensorKeyPrepared<&mut [u8], B>, &mut Self) @@ -346,25 +345,11 @@ where "rank_in != rank_out is not supported for GGLWETensorKeyPrepared" ); - let mut keys: Vec> = Vec::new(); - let pairs: usize = (((infos.rank_out().0 + 1) * infos.rank_out().0) >> 1).max(1) as usize; - - let mut scratch: &mut Self = self; - + let pairs: u32 = (((infos.rank_out().0 + 1) * infos.rank_out().0) >> 1).max(1); let mut ksk_infos: GGLWELayout = infos.gglwe_layout(); - ksk_infos.rank_in = Rank(1); - - if pairs != 0 { - let (gglwe, s) = scratch.take_gglwe_prepared(module, &ksk_infos); - scratch = s; - keys.push(gglwe); - } - for _ in 1..pairs { - let (gglwe, s) = scratch.take_gglwe_prepared(module, &ksk_infos); - scratch = s; - keys.push(gglwe); - } - (GLWETensorKeyPrepared { keys }, scratch) + ksk_infos.rank_in = Rank(pairs); + let (data, scratch) = self.take_gglwe_prepared(module, infos); + (GLWETensorKeyPrepared(data), scratch) } } diff --git a/poulpy-core/src/tests/mod.rs b/poulpy-core/src/tests/mod.rs index dd16db0..aab0ec9 100644 --- a/poulpy-core/src/tests/mod.rs +++ b/poulpy-core/src/tests/mod.rs @@ -36,6 +36,7 @@ gglwe_automorphism_key_encrypt_sk => crate::tests::test_suite::encryption::test_ gglwe_automorphism_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_automorphism_key_compressed_encrypt_sk, gglwe_tensor_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_encrypt_sk, gglwe_tensor_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_compressed_encrypt_sk, +gglwe_to_ggsw_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_to_ggsw_key_encrypt_sk, // GGLWE Keyswitching gglwe_switching_key_keyswitch => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch, gglwe_switching_key_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch_inplace, @@ -75,7 +76,7 @@ backend_test_suite!( glwe_encrypt_zero_sk => crate::tests::test_suite::encryption::test_glwe_encrypt_zero_sk, glwe_encrypt_pk => crate::tests::test_suite::encryption::test_glwe_encrypt_pk, // GLWE Keyswitch - glwe_keyswitch => crate::tests::test_suite::keyswitch::test_glwe_keyswitch, +glwe_keyswitch => crate::tests::test_suite::keyswitch::test_glwe_keyswitch, glwe_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_glwe_keyswitch_inplace, // GLWE Automorphism glwe_automorphism => crate::tests::test_suite::automorphism::test_glwe_automorphism, @@ -93,6 +94,7 @@ gglwe_automorphism_key_encrypt_sk => crate::tests::test_suite::encryption::test_ gglwe_automorphism_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_automorphism_key_compressed_encrypt_sk, gglwe_tensor_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_encrypt_sk, gglwe_tensor_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_compressed_encrypt_sk, +gglwe_to_ggsw_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_to_ggsw_key_encrypt_sk, // GGLWE Keyswitching gglwe_switching_key_keyswitch => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch, gglwe_switching_key_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch_inplace, diff --git a/poulpy-core/src/tests/serialization.rs b/poulpy-core/src/tests/serialization.rs index c67d87d..14e62bb 100644 --- a/poulpy-core/src/tests/serialization.rs +++ b/poulpy-core/src/tests/serialization.rs @@ -1,12 +1,12 @@ use poulpy_hal::test_suite::serialization::test_reader_writer_interface; use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWE, GGSW, GLWE, GLWEAutomorphismKey, GLWESwitchingKey, GLWETensorKey, GLWEToLWESwitchingKey, - LWE, LWESwitchingKey, LWEToGLWESwitchingKey, Rank, TorusPrecision, + Base2K, Degree, Dnum, Dsize, GGLWE, GGSW, GLWE, GLWEAutomorphismKey, GLWESwitchingKey, GLWETensorKey, GLWEToLWEKey, LWE, + LWESwitchingKey, LWEToGLWEKey, Rank, TorusPrecision, compressed::{ GGLWECompressed, GGSWCompressed, GLWEAutomorphismKeyCompressed, GLWECompressed, GLWESwitchingKeyCompressed, GLWETensorKeyCompressed, GLWEToLWESwitchingKeyCompressed, LWECompressed, LWESwitchingKeyCompressed, - LWEToGLWESwitchingKeyCompressed, + LWEToGLWEKeyCompressed, }, }; @@ -93,28 +93,27 @@ fn test_tensor_key_compressed_serialization() { } #[test] -fn glwe_to_lwe_switching_key_serialization() { - let original: GLWEToLWESwitchingKey> = GLWEToLWESwitchingKey::alloc(N_GLWE, BASE2K, K, RANK, DNUM); +fn glwe_to_lwe_key_serialization() { + let original: GLWEToLWEKey> = GLWEToLWEKey::alloc(N_GLWE, BASE2K, K, RANK, DNUM); test_reader_writer_interface(original); } #[test] -fn glwe_to_lwe_switching_key_compressed_serialization() { +fn glwe_to_lwe_key_compressed_serialization() { let original: GLWEToLWESwitchingKeyCompressed> = GLWEToLWESwitchingKeyCompressed::alloc(N_GLWE, BASE2K, K, RANK, DNUM); test_reader_writer_interface(original); } #[test] -fn lwe_to_glwe_switching_key_serialization() { - let original: LWEToGLWESwitchingKey> = LWEToGLWESwitchingKey::alloc(N_GLWE, BASE2K, K, RANK, DNUM); +fn lwe_to_glwe_key_serialization() { + let original: LWEToGLWEKey> = LWEToGLWEKey::alloc(N_GLWE, BASE2K, K, RANK, DNUM); test_reader_writer_interface(original); } #[test] -fn lwe_to_glwe_switching_key_compressed_serialization() { - let original: LWEToGLWESwitchingKeyCompressed> = - LWEToGLWESwitchingKeyCompressed::alloc(N_GLWE, BASE2K, K, RANK, DNUM); +fn lwe_to_glwe_key_compressed_serialization() { + let original: LWEToGLWEKeyCompressed> = LWEToGLWEKeyCompressed::alloc(N_GLWE, BASE2K, K, RANK, DNUM); test_reader_writer_interface(original); } diff --git a/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs index b978a9d..6e2a226 100644 --- a/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs +++ b/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs @@ -5,12 +5,12 @@ use poulpy_hal::{ }; use crate::{ - GGSWAutomorphism, GGSWEncryptSk, GGSWNoise, GLWEAutomorphismKeyEncryptSk, GLWETensorKeyEncryptSk, ScratchTakeCore, + GGLWEToGGSWKeyEncryptSk, GGSWAutomorphism, GGSWEncryptSk, GGSWNoise, GLWEAutomorphismKeyEncryptSk, ScratchTakeCore, encryption::SIGMA, layouts::{ - GGSW, GGSWLayout, GLWEAutomorphismKey, GLWEAutomorphismKeyPreparedFactory, GLWESecret, GLWESecretPreparedFactory, - GLWETensorKey, GLWETensorKeyLayout, GLWETensorKeyPreparedFactory, - prepared::{GLWEAutomorphismKeyPrepared, GLWESecretPrepared, GLWETensorKeyPrepared}, + GGLWEToGGSWKey, GGLWEToGGSWKeyLayout, GGLWEToGGSWKeyPreparedFactory, GGSW, GGSWLayout, GLWEAutomorphismKey, + GLWEAutomorphismKeyPreparedFactory, GLWESecret, GLWESecretPreparedFactory, + prepared::{GGLWEToGGSWKeyPrepared, GLWEAutomorphismKeyPrepared, GLWESecretPrepared}, }, noise::noise_ggsw_keyswitch, }; @@ -21,8 +21,8 @@ where + GLWEAutomorphismKeyEncryptSk + GLWEAutomorphismKeyPreparedFactory + GGSWAutomorphism - + GLWETensorKeyPreparedFactory - + GLWETensorKeyEncryptSk + + GGLWEToGGSWKeyPreparedFactory + + GGLWEToGGSWKeyEncryptSk + GLWESecretPreparedFactory + VecZnxAutomorphismInplace + GGSWNoise, @@ -64,7 +64,7 @@ where rank: rank.into(), }; - let tensor_key_layout: GLWETensorKeyLayout = GLWETensorKeyLayout { + let tsk_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout { n: n.into(), base2k: base2k.into(), k: k_tsk.into(), @@ -73,7 +73,7 @@ where rank: rank.into(), }; - let auto_key_layout: GLWETensorKeyLayout = GLWETensorKeyLayout { + let auto_key_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout { n: n.into(), base2k: base2k.into(), k: k_ksk.into(), @@ -84,7 +84,7 @@ where let mut ct_in: GGSW> = GGSW::alloc_from_infos(&ggsw_in_layout); let mut ct_out: GGSW> = GGSW::alloc_from_infos(&ggsw_out_layout); - let mut tensor_key: GLWETensorKey> = GLWETensorKey::alloc_from_infos(&tensor_key_layout); + let mut tsk: GGLWEToGGSWKey> = GGLWEToGGSWKey::alloc_from_infos(&tsk_layout); let mut auto_key: GLWEAutomorphismKey> = GLWEAutomorphismKey::alloc_from_infos(&auto_key_layout); let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); @@ -95,8 +95,8 @@ where let mut scratch: ScratchOwned = ScratchOwned::alloc( GGSW::encrypt_sk_tmp_bytes(module, &ct_in) | GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key) - | GLWETensorKey::encrypt_sk_tmp_bytes(module, &tensor_key) - | GGSW::automorphism_tmp_bytes(module, &ct_out, &ct_in, &auto_key, &tensor_key), + | GGLWEToGGSWKey::encrypt_sk_tmp_bytes(module, &tsk) + | GGSW::automorphism_tmp_bytes(module, &ct_out, &ct_in, &auto_key, &tsk), ); let var_xs: f64 = 0.5; @@ -115,7 +115,7 @@ where &mut source_xe, scratch.borrow(), ); - tensor_key.encrypt_sk( + tsk.encrypt_sk( module, &sk, &mut source_xa, @@ -138,9 +138,8 @@ where GLWEAutomorphismKeyPrepared::alloc_from_infos(module, &auto_key_layout); auto_key_prepared.prepare(module, &auto_key, scratch.borrow()); - let mut tsk_prepared: GLWETensorKeyPrepared, BE> = - GLWETensorKeyPrepared::alloc_from_infos(module, &tensor_key_layout); - tsk_prepared.prepare(module, &tensor_key, scratch.borrow()); + let mut tsk_prepared: GGLWEToGGSWKeyPrepared, BE> = GGLWEToGGSWKeyPrepared::alloc_from_infos(module, &tsk); + tsk_prepared.prepare(module, &tsk, scratch.borrow()); ct_out.automorphism( module, @@ -180,8 +179,8 @@ where + GLWEAutomorphismKeyEncryptSk + GLWEAutomorphismKeyPreparedFactory + GGSWAutomorphism - + GLWETensorKeyPreparedFactory - + GLWETensorKeyEncryptSk + + GGLWEToGGSWKeyPreparedFactory + + GGLWEToGGSWKeyEncryptSk + GLWESecretPreparedFactory + VecZnxAutomorphismInplace + GGSWNoise, @@ -211,7 +210,7 @@ where rank: rank.into(), }; - let tensor_key_layout: GLWETensorKeyLayout = GLWETensorKeyLayout { + let tsk_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout { n: n.into(), base2k: base2k.into(), k: k_tsk.into(), @@ -220,7 +219,7 @@ where rank: rank.into(), }; - let auto_key_layout: GLWETensorKeyLayout = GLWETensorKeyLayout { + let auto_key_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout { n: n.into(), base2k: base2k.into(), k: k_ksk.into(), @@ -230,7 +229,7 @@ where }; let mut ct: GGSW> = GGSW::alloc_from_infos(&ggsw_out_layout); - let mut tensor_key: GLWETensorKey> = GLWETensorKey::alloc_from_infos(&tensor_key_layout); + let mut tsk: GGLWEToGGSWKey> = GGLWEToGGSWKey::alloc_from_infos(&tsk_layout); let mut auto_key: GLWEAutomorphismKey> = GLWEAutomorphismKey::alloc_from_infos(&auto_key_layout); let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); @@ -241,8 +240,8 @@ where let mut scratch: ScratchOwned = ScratchOwned::alloc( GGSW::encrypt_sk_tmp_bytes(module, &ct) | GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key) - | GLWETensorKey::encrypt_sk_tmp_bytes(module, &tensor_key) - | GGSW::automorphism_tmp_bytes(module, &ct, &ct, &auto_key, &tensor_key), + | GGLWEToGGSWKey::encrypt_sk_tmp_bytes(module, &tsk) + | GGSW::automorphism_tmp_bytes(module, &ct, &ct, &auto_key, &tsk), ); let var_xs: f64 = 0.5; @@ -261,7 +260,7 @@ where &mut source_xe, scratch.borrow(), ); - tensor_key.encrypt_sk( + tsk.encrypt_sk( module, &sk, &mut source_xa, @@ -284,9 +283,8 @@ where GLWEAutomorphismKeyPrepared::alloc_from_infos(module, &auto_key_layout); auto_key_prepared.prepare(module, &auto_key, scratch.borrow()); - let mut tsk_prepared: GLWETensorKeyPrepared, BE> = - GLWETensorKeyPrepared::alloc_from_infos(module, &tensor_key_layout); - tsk_prepared.prepare(module, &tensor_key, scratch.borrow()); + let mut tsk_prepared: GGLWEToGGSWKeyPrepared, BE> = GGLWEToGGSWKeyPrepared::alloc_from_infos(module, &tsk); + tsk_prepared.prepare(module, &tsk, scratch.borrow()); ct.automorphism_inplace(module, &auto_key_prepared, &tsk_prepared, scratch.borrow()); diff --git a/poulpy-core/src/tests/test_suite/conversion.rs b/poulpy-core/src/tests/test_suite/conversion.rs index c6e7d00..2412411 100644 --- a/poulpy-core/src/tests/test_suite/conversion.rs +++ b/poulpy-core/src/tests/test_suite/conversion.rs @@ -8,10 +8,10 @@ use crate::{ GLWEDecrypt, GLWEEncryptSk, GLWEFromLWE, GLWEToLWESwitchingKeyEncryptSk, LWEDecrypt, LWEEncryptSk, LWEToGLWESwitchingKeyEncryptSk, ScratchTakeCore, layouts::{ - Base2K, Degree, Dnum, GLWE, GLWELayout, GLWEPlaintext, GLWESecret, GLWESecretPreparedFactory, GLWEToLWEKeyLayout, - GLWEToLWESwitchingKey, GLWEToLWESwitchingKeyPreparedFactory, LWE, LWELayout, LWEPlaintext, LWESecret, - LWEToGLWESwitchingKey, LWEToGLWESwitchingKeyLayout, LWEToGLWESwitchingKeyPreparedFactory, Rank, TorusPrecision, - prepared::{GLWESecretPrepared, GLWEToLWESwitchingKeyPrepared, LWEToGLWESwitchingKeyPrepared}, + Base2K, Degree, Dnum, GLWE, GLWELayout, GLWEPlaintext, GLWESecret, GLWESecretPreparedFactory, GLWEToLWEKey, + GLWEToLWEKeyLayout, GLWEToLWEKeyPrepared, GLWEToLWEKeyPreparedFactory, LWE, LWELayout, LWEPlaintext, LWESecret, + LWEToGLWEKey, LWEToGLWEKeyLayout, LWEToGLWEKeyPrepared, LWEToGLWEKeyPreparedFactory, Rank, TorusPrecision, + prepared::GLWESecretPrepared, }, }; @@ -22,7 +22,7 @@ where + GLWEDecrypt + GLWESecretPreparedFactory + LWEEncryptSk - + LWEToGLWESwitchingKeyPreparedFactory, + + LWEToGLWEKeyPreparedFactory, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { @@ -36,7 +36,7 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); - let lwe_to_glwe_infos: LWEToGLWESwitchingKeyLayout = LWEToGLWESwitchingKeyLayout { + let lwe_to_glwe_infos: LWEToGLWEKeyLayout = LWEToGLWEKeyLayout { n: n_glwe, base2k: Base2K(17), k: TorusPrecision(51), @@ -58,7 +58,7 @@ where }; let mut scratch: ScratchOwned = ScratchOwned::alloc( - LWEToGLWESwitchingKey::encrypt_sk_tmp_bytes(module, &lwe_to_glwe_infos) + LWEToGLWEKey::encrypt_sk_tmp_bytes(module, &lwe_to_glwe_infos) | GLWE::from_lwe_tmp_bytes(module, &glwe_infos, &lwe_infos, &lwe_to_glwe_infos) | GLWE::decrypt_tmp_bytes(module, &glwe_infos), ); @@ -80,7 +80,7 @@ where let mut lwe_ct: LWE> = LWE::alloc_from_infos(&lwe_infos); lwe_ct.encrypt_sk(module, &lwe_pt, &sk_lwe, &mut source_xa, &mut source_xe); - let mut ksk: LWEToGLWESwitchingKey> = LWEToGLWESwitchingKey::alloc_from_infos(&lwe_to_glwe_infos); + let mut ksk: LWEToGLWEKey> = LWEToGLWEKey::alloc_from_infos(&lwe_to_glwe_infos); ksk.encrypt_sk( module, @@ -93,8 +93,7 @@ where let mut glwe_ct: GLWE> = GLWE::alloc_from_infos(&glwe_infos); - let mut ksk_prepared: LWEToGLWESwitchingKeyPrepared, BE> = - LWEToGLWESwitchingKeyPrepared::alloc_from_infos(module, &ksk); + let mut ksk_prepared: LWEToGLWEKeyPrepared, BE> = LWEToGLWEKeyPrepared::alloc_from_infos(module, &ksk); ksk_prepared.prepare(module, &ksk, scratch.borrow()); glwe_ct.from_lwe(module, &lwe_ct, &ksk_prepared, scratch.borrow()); @@ -114,7 +113,7 @@ where + GLWEDecrypt + GLWESecretPreparedFactory + GLWEToLWESwitchingKeyEncryptSk - + GLWEToLWESwitchingKeyPreparedFactory, + + GLWEToLWEKeyPreparedFactory, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { @@ -150,7 +149,7 @@ where let mut source_xe: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWEToLWESwitchingKey::encrypt_sk_tmp_bytes(module, &glwe_to_lwe_infos) + GLWEToLWEKey::encrypt_sk_tmp_bytes(module, &glwe_to_lwe_infos) | LWE::from_glwe_tmp_bytes(module, &lwe_infos, &glwe_infos, &glwe_to_lwe_infos) | GLWE::decrypt_tmp_bytes(module, &glwe_infos), ); @@ -178,7 +177,7 @@ where scratch.borrow(), ); - let mut ksk: GLWEToLWESwitchingKey> = GLWEToLWESwitchingKey::alloc_from_infos(&glwe_to_lwe_infos); + let mut ksk: GLWEToLWEKey> = GLWEToLWEKey::alloc_from_infos(&glwe_to_lwe_infos); ksk.encrypt_sk( module, @@ -191,8 +190,7 @@ where let mut lwe_ct: LWE> = LWE::alloc_from_infos(&lwe_infos); - let mut ksk_prepared: GLWEToLWESwitchingKeyPrepared, BE> = - GLWEToLWESwitchingKeyPrepared::alloc_from_infos(module, &ksk); + let mut ksk_prepared: GLWEToLWEKeyPrepared, BE> = GLWEToLWEKeyPrepared::alloc_from_infos(module, &ksk); ksk_prepared.prepare(module, &ksk, scratch.borrow()); lwe_ct.from_glwe(module, &glwe_ct, &ksk_prepared, scratch.borrow()); diff --git a/poulpy-core/src/tests/test_suite/encryption/gglwe_to_ggsw_key.rs b/poulpy-core/src/tests/test_suite/encryption/gglwe_to_ggsw_key.rs new file mode 100644 index 0000000..884e21a --- /dev/null +++ b/poulpy-core/src/tests/test_suite/encryption/gglwe_to_ggsw_key.rs @@ -0,0 +1,144 @@ +use poulpy_hal::{ + api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxCopy}, + layouts::{Backend, Module, ScalarZnx, Scratch, ScratchOwned}, + source::Source, +}; + +use crate::{ + GGLWENoise, GGLWEToGGSWKeyCompressedEncryptSk, GGLWEToGGSWKeyEncryptSk, ScratchTakeCore, + decryption::GLWEDecrypt, + encryption::SIGMA, + layouts::{ + Dsize, GGLWEDecompress, GGLWEToGGSWKey, GGLWEToGGSWKeyCompressed, GGLWEToGGSWKeyDecompress, GGLWEToGGSWKeyLayout, + GLWESecret, GLWESecretPreparedFactory, GLWESecretTensor, GLWESecretTensorFactory, LWEInfos, prepared::GLWESecretPrepared, + }, +}; + +pub fn test_gglwe_to_ggsw_key_encrypt_sk(module: &Module) +where + Module: GGLWEToGGSWKeyEncryptSk + + GLWESecretTensorFactory + + GLWESecretPreparedFactory + + GLWEDecrypt + + GGLWENoise + + VecZnxCopy, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, +{ + let base2k: usize = 8; + let k: usize = 54; + + for rank in 2_usize..3 { + let n: usize = module.n(); + let dnum: usize = k / base2k; + + let key_infos: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout { + n: n.into(), + base2k: base2k.into(), + k: k.into(), + dnum: dnum.into(), + dsize: Dsize(1), + rank: rank.into(), + }; + + let mut key: GGLWEToGGSWKey> = GGLWEToGGSWKey::alloc_from_infos(&key_infos); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWEToGGSWKey::encrypt_sk_tmp_bytes(module, &key_infos)); + + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&key_infos); + sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank.into()); + sk_prepared.prepare(module, &sk); + + key.encrypt_sk( + module, + &sk, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + let mut sk_tensor: GLWESecretTensor> = GLWESecretTensor::alloc_from_infos(&sk); + sk_tensor.prepare(module, &sk, scratch.borrow()); + + let max_noise = SIGMA.log2() + 0.5 - (key.k().as_u32() as f64); + + let mut pt_want: ScalarZnx> = ScalarZnx::alloc(module.n(), rank); + + for i in 0..rank { + for j in 0..rank { + module.vec_znx_copy( + &mut pt_want.as_vec_znx_mut(), + j, + &sk_tensor.at(i, j).as_vec_znx(), + 0, + ); + } + + println!("pt_want: {}", pt_want.as_vec_znx()); + + module.gglwe_assert_noise(key.at(i), &sk_prepared, &pt_want, max_noise); + } + } +} + +pub fn test_gglwe_to_ggsw_compressed_encrypt_sk(module: &Module) +where + Module: GGLWEToGGSWKeyCompressedEncryptSk + + GLWESecretPreparedFactory + + GLWEDecrypt + + GLWESecretTensorFactory + + GGLWENoise + + GGLWEDecompress + + GGLWEToGGSWKeyDecompress, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, +{ + let base2k = 8; + let k = 54; + for rank in 1_usize..3 { + let n: usize = module.n(); + let dnum: usize = k / base2k; + + let key_infos: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout { + n: n.into(), + base2k: base2k.into(), + k: k.into(), + dnum: dnum.into(), + dsize: Dsize(1), + rank: rank.into(), + }; + + let mut key_compressed: GGLWEToGGSWKeyCompressed> = GGLWEToGGSWKeyCompressed::alloc_from_infos(&key_infos); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWEToGGSWKeyCompressed::encrypt_sk_tmp_bytes( + module, &key_infos, + )); + + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&key_infos); + sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank.into()); + sk_prepared.prepare(module, &sk); + + let seed_xa: [u8; 32] = [1u8; 32]; + + key_compressed.encrypt_sk(module, &sk, seed_xa, &mut source_xe, scratch.borrow()); + + let mut key: GGLWEToGGSWKey> = GGLWEToGGSWKey::alloc_from_infos(&key_infos); + key.decompress(module, &key_compressed); + + let mut sk_tensor: GLWESecretTensor> = GLWESecretTensor::alloc_from_infos(&sk); + sk_tensor.prepare(module, &sk, scratch.borrow()); + + for i in 0..rank { + module.gglwe_assert_noise(key.at(i), &sk_prepared, &sk_tensor.data, SIGMA + 0.5); + } + } +} diff --git a/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs b/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs index 940f917..26baa92 100644 --- a/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs +++ b/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs @@ -1,20 +1,16 @@ use poulpy_hal::{ - api::{ - ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, VecZnxBigAlloc, VecZnxBigNormalize, - VecZnxCopy, VecZnxDftAlloc, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyTmpA, VecZnxSubScalarInplace, - VecZnxSwitchRing, - }, - layouts::{Backend, Module, Scratch, ScratchOwned, VecZnxBig, VecZnxDft}, + api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow}, + layouts::{Backend, Module, Scratch, ScratchOwned}, source::Source, }; use crate::{ - GLWETensorKeyCompressedEncryptSk, GLWETensorKeyEncryptSk, ScratchTakeCore, + GGLWENoise, GLWETensorKeyCompressedEncryptSk, GLWETensorKeyEncryptSk, ScratchTakeCore, decryption::GLWEDecrypt, encryption::SIGMA, layouts::{ - Dsize, GLWEPlaintext, GLWESecret, GLWESecretPreparedFactory, GLWETensorKey, GLWETensorKeyCompressed, GLWETensorKeyLayout, - prepared::GLWESecretPrepared, + Dsize, GGLWEDecompress, GLWESecret, GLWESecretPreparedFactory, GLWESecretTensor, GLWESecretTensorFactory, GLWETensorKey, + GLWETensorKeyCompressed, GLWETensorKeyLayout, prepared::GLWESecretPrepared, }, }; @@ -23,20 +19,15 @@ where Module: GLWETensorKeyEncryptSk + GLWESecretPreparedFactory + GLWEDecrypt - + VecZnxDftAlloc - + VecZnxBigAlloc - + VecZnxDftApply - + SvpApplyDftToDft - + VecZnxIdftApplyTmpA - + VecZnxBigNormalize - + VecZnxSubScalarInplace, + + GLWESecretTensorFactory + + GGLWENoise, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { let base2k: usize = 8; let k: usize = 54; - for rank in 1_usize..3 { + for rank in 2_usize..3 { let n: usize = module.n(); let dnum: usize = k / base2k; @@ -73,42 +64,10 @@ where scratch.borrow(), ); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&tensor_key_infos); + let mut sk_tensor: GLWESecretTensor> = GLWESecretTensor::alloc_from_infos(&sk); + sk_tensor.prepare(module, &sk, scratch.borrow()); - let mut sk_ij_dft: VecZnxDft, BE> = module.vec_znx_dft_alloc(1, 1); - let mut sk_ij_big: VecZnxBig, BE> = module.vec_znx_big_alloc(1, 1); - let mut sk_ij: GLWESecret> = GLWESecret::alloc(n.into(), 1_u32.into()); - let mut sk_dft: VecZnxDft, BE> = module.vec_znx_dft_alloc(rank, 1); - - for i in 0..rank { - module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); - } - - for i in 0..rank { - for j in 0..rank { - module.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_prepared.data, j, &sk_dft, i); - module.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0); - module.vec_znx_big_normalize( - base2k, - &mut sk_ij.data.as_vec_znx_mut(), - 0, - base2k, - &sk_ij_big, - 0, - scratch.borrow(), - ); - for row_i in 0..dnum { - let ct = tensor_key.at(i, j).at(row_i, 0); - - ct.decrypt(module, &mut pt, &sk_prepared, scratch.borrow()); - - module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk_ij.data, 0); - - let std_pt: f64 = pt.data.std(base2k, 0) * (k as f64).exp2(); - assert!((SIGMA - std_pt).abs() <= 0.5, "{SIGMA} {std_pt}"); - } - } - } + module.gglwe_assert_noise(&tensor_key, &sk_prepared, &sk_tensor.data, SIGMA + 0.5); } } @@ -118,15 +77,9 @@ where + GLWESecretPreparedFactory + GLWETensorKeyCompressedEncryptSk + GLWEDecrypt - + VecZnxDftAlloc - + VecZnxBigAlloc - + VecZnxDftApply - + SvpApplyDftToDft - + VecZnxIdftApplyTmpA - + VecZnxSubScalarInplace - + VecZnxFillUniform - + VecZnxCopy - + VecZnxSwitchRing, + + GLWESecretTensorFactory + + GGLWENoise + + GGLWEDecompress, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { @@ -168,42 +121,9 @@ where let mut tensor_key: GLWETensorKey> = GLWETensorKey::alloc_from_infos(&tensor_key_infos); tensor_key.decompress(module, &tensor_key_compressed); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&tensor_key_infos); + let mut sk_tensor: GLWESecretTensor> = GLWESecretTensor::alloc_from_infos(&sk); + sk_tensor.prepare(module, &sk, scratch.borrow()); - let mut sk_ij_dft: VecZnxDft, BE> = module.vec_znx_dft_alloc(1, 1); - let mut sk_ij_big: VecZnxBig, BE> = module.vec_znx_big_alloc(1, 1); - let mut sk_ij: GLWESecret> = GLWESecret::alloc(n.into(), 1_u32.into()); - let mut sk_dft: VecZnxDft, BE> = module.vec_znx_dft_alloc(rank, 1); - - for i in 0..rank { - module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); - } - - for i in 0..rank { - for j in 0..rank { - module.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_prepared.data, j, &sk_dft, i); - module.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0); - module.vec_znx_big_normalize( - base2k, - &mut sk_ij.data.as_vec_znx_mut(), - 0, - base2k, - &sk_ij_big, - 0, - scratch.borrow(), - ); - for row_i in 0..dnum { - tensor_key - .at(i, j) - .at(row_i, 0) - .decrypt(module, &mut pt, &sk_prepared, scratch.borrow()); - - module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk_ij.data, 0); - - let std_pt: f64 = pt.data.std(base2k, 0) * (k as f64).exp2(); - assert!((SIGMA - std_pt).abs() <= 0.5, "{SIGMA} {std_pt}"); - } - } - } + module.gglwe_assert_noise(&tensor_key, &sk_prepared, &sk_tensor.data, SIGMA + 0.5); } } diff --git a/poulpy-core/src/tests/test_suite/encryption/mod.rs b/poulpy-core/src/tests/test_suite/encryption/mod.rs index d871177..0fe0f49 100644 --- a/poulpy-core/src/tests/test_suite/encryption/mod.rs +++ b/poulpy-core/src/tests/test_suite/encryption/mod.rs @@ -1,11 +1,13 @@ mod gglwe_atk; mod gglwe_ct; +mod gglwe_to_ggsw_key; mod ggsw_ct; mod glwe_ct; mod glwe_tsk; pub use gglwe_atk::*; pub use gglwe_ct::*; +pub use gglwe_to_ggsw_key::*; pub use ggsw_ct::*; pub use glwe_ct::*; pub use glwe_tsk::*; diff --git a/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs index 28e71a5..c4191fa 100644 --- a/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs +++ b/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs @@ -5,12 +5,13 @@ use poulpy_hal::{ }; use crate::{ - GGSWEncryptSk, GGSWKeyswitch, GGSWNoise, GLWESwitchingKeyEncryptSk, GLWETensorKeyEncryptSk, ScratchTakeCore, + GGLWEToGGSWKeyEncryptSk, GGSWEncryptSk, GGSWKeyswitch, GGSWNoise, GLWESwitchingKeyEncryptSk, ScratchTakeCore, encryption::SIGMA, layouts::{ - GGSW, GGSWLayout, GLWESecret, GLWESecretPreparedFactory, GLWESwitchingKey, GLWESwitchingKeyLayout, - GLWESwitchingKeyPreparedFactory, GLWETensorKey, GLWETensorKeyLayout, GLWETensorKeyPreparedFactory, - prepared::{GLWESecretPrepared, GLWESwitchingKeyPrepared, GLWETensorKeyPrepared}, + GGLWEToGGSWKey, GGLWEToGGSWKeyPrepared, GGLWEToGGSWKeyPreparedFactory, GGSW, GGSWLayout, GLWESecret, + GLWESecretPreparedFactory, GLWESwitchingKey, GLWESwitchingKeyLayout, GLWESwitchingKeyPreparedFactory, + GLWETensorKeyLayout, + prepared::{GLWESecretPrepared, GLWESwitchingKeyPrepared}, }, noise::noise_ggsw_keyswitch, }; @@ -20,10 +21,10 @@ pub fn test_ggsw_keyswitch(module: &Module) where Module: GGSWEncryptSk + GLWESwitchingKeyEncryptSk - + GLWETensorKeyEncryptSk + + GGLWEToGGSWKeyEncryptSk + GGSWKeyswitch + GLWESecretPreparedFactory - + GLWETensorKeyPreparedFactory + + GGLWEToGGSWKeyPreparedFactory + GLWESwitchingKeyPreparedFactory + GGSWNoise, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, @@ -82,7 +83,7 @@ where let mut ggsw_in: GGSW> = GGSW::alloc_from_infos(&ggsw_in_infos); let mut ggsw_out: GGSW> = GGSW::alloc_from_infos(&ggsw_out_infos); - let mut tsk: GLWETensorKey> = GLWETensorKey::alloc_from_infos(&tsk_infos); + let mut tsk: GGLWEToGGSWKey> = GGLWEToGGSWKey::alloc_from_infos(&tsk_infos); let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&ksk_apply_infos); let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); @@ -93,7 +94,7 @@ where let mut scratch: ScratchOwned = ScratchOwned::alloc( GGSW::encrypt_sk_tmp_bytes(module, &ggsw_in_infos) | GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &ksk_apply_infos) - | GLWETensorKey::encrypt_sk_tmp_bytes(module, &tsk_infos) + | GGLWEToGGSWKey::encrypt_sk_tmp_bytes(module, &tsk_infos) | GGSW::keyswitch_tmp_bytes( module, &ggsw_out_infos, @@ -148,7 +149,7 @@ where GLWESwitchingKeyPrepared::alloc_from_infos(module, &ksk); ksk_prepared.prepare(module, &ksk, scratch.borrow()); - let mut tsk_prepared: GLWETensorKeyPrepared, BE> = GLWETensorKeyPrepared::alloc_from_infos(module, &tsk); + let mut tsk_prepared: GGLWEToGGSWKeyPrepared, BE> = GGLWEToGGSWKeyPrepared::alloc_from_infos(module, &tsk); tsk_prepared.prepare(module, &tsk, scratch.borrow()); ggsw_out.keyswitch( @@ -185,10 +186,10 @@ pub fn test_ggsw_keyswitch_inplace(module: &Module) where Module: GGSWEncryptSk + GLWESwitchingKeyEncryptSk - + GLWETensorKeyEncryptSk + + GGLWEToGGSWKeyEncryptSk + GGSWKeyswitch + GLWESecretPreparedFactory - + GLWETensorKeyPreparedFactory + + GGLWEToGGSWKeyPreparedFactory + GLWESwitchingKeyPreparedFactory + GGSWNoise, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, @@ -236,7 +237,7 @@ where }; let mut ggsw_out: GGSW> = GGSW::alloc_from_infos(&ggsw_out_infos); - let mut tsk: GLWETensorKey> = GLWETensorKey::alloc_from_infos(&tsk_infos); + let mut tsk: GGLWEToGGSWKey> = GGLWEToGGSWKey::alloc_from_infos(&tsk_infos); let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&ksk_apply_infos); let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); @@ -247,7 +248,7 @@ where let mut scratch: ScratchOwned = ScratchOwned::alloc( GGSW::encrypt_sk_tmp_bytes(module, &ggsw_out_infos) | GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &ksk_apply_infos) - | GLWETensorKey::encrypt_sk_tmp_bytes(module, &tsk_infos) + | GGLWEToGGSWKey::encrypt_sk_tmp_bytes(module, &tsk_infos) | GGSW::keyswitch_tmp_bytes( module, &ggsw_out_infos, @@ -302,7 +303,7 @@ where GLWESwitchingKeyPrepared::alloc_from_infos(module, &ksk); ksk_prepared.prepare(module, &ksk, scratch.borrow()); - let mut tsk_prepared: GLWETensorKeyPrepared, BE> = GLWETensorKeyPrepared::alloc_from_infos(module, &tsk); + let mut tsk_prepared: GGLWEToGGSWKeyPrepared, BE> = GGLWEToGGSWKeyPrepared::alloc_from_infos(module, &tsk); tsk_prepared.prepare(module, &tsk, scratch.borrow()); ggsw_out.keyswitch_inplace(module, &ksk_prepared, &tsk_prepared, scratch.borrow()); diff --git a/poulpy-core/src/tests/test_suite/packing.rs b/poulpy-core/src/tests/test_suite/packing.rs index 029e059..cf8284a 100644 --- a/poulpy-core/src/tests/test_suite/packing.rs +++ b/poulpy-core/src/tests/test_suite/packing.rs @@ -7,7 +7,7 @@ use poulpy_hal::{ }; use crate::{ - GLWEAutomorphismKeyEncryptSk, GLWEDecrypt, GLWEEncryptSk, GLWEPacker, GLWEPacking, GLWERotate, GLWESub, ScratchTakeCore, + GLWEAutomorphismKeyEncryptSk, GLWEDecrypt, GLWEEncryptSk, GLWEPacker, GLWEPackerOps, GLWERotate, GLWESub, ScratchTakeCore, layouts::{ GLWE, GLWEAutomorphismKey, GLWEAutomorphismKeyLayout, GLWEAutomorphismKeyPreparedFactory, GLWELayout, GLWEPlaintext, GLWESecret, GLWESecretPreparedFactory, @@ -20,7 +20,7 @@ where Module: GLWEEncryptSk + GLWEAutomorphismKeyEncryptSk + GLWEAutomorphismKeyPreparedFactory - + GLWEPacking + + GLWEPackerOps + GLWESecretPreparedFactory + GLWESub + GLWEDecrypt diff --git a/poulpy-hal/Cargo.toml b/poulpy-hal/Cargo.toml index 93325b3..f114681 100644 --- a/poulpy-hal/Cargo.toml +++ b/poulpy-hal/Cargo.toml @@ -19,7 +19,7 @@ rand_core = {workspace = true} byteorder = {workspace = true} once_cell = {workspace = true} rand_chacha = "0.9.0" -bytemuck = "1.23.2" +bytemuck = {workspace = true} [build-dependencies] diff --git a/poulpy-hal/src/api/convolution.rs b/poulpy-hal/src/api/convolution.rs index 10caf6b..d1c6c5e 100644 --- a/poulpy-hal/src/api/convolution.rs +++ b/poulpy-hal/src/api/convolution.rs @@ -78,6 +78,7 @@ where self.bytes_of_svp_ppol(1) + self.bytes_of_vec_znx_dft(1, b_size) } + #[allow(clippy::too_many_arguments)] /// Evaluates a bivariate convolution over Z[X, Y] / (X^N + 1) where Y = 2^-K over the /// selected columsn and stores the result on the selected column, scaled by 2^{k * Base2K} /// @@ -139,6 +140,7 @@ where } } + #[allow(clippy::too_many_arguments)] fn bivariate_convolution( &self, k: i64, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs index 373eb7c..6a541cf 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs @@ -10,8 +10,8 @@ use crate::tfhe::{ use poulpy_core::{ GLWEToLWESwitchingKeyEncryptSk, GetDistribution, LWEFromGLWE, ScratchTakeCore, layouts::{ - GGSWInfos, GGSWPreparedFactory, GLWEInfos, GLWESecretToRef, GLWEToLWEKeyLayout, GLWEToLWESwitchingKey, - GLWEToLWESwitchingKeyPreparedFactory, LWE, LWEInfos, LWESecretToRef, prepared::GLWEToLWESwitchingKeyPrepared, + GGSWInfos, GGSWPreparedFactory, GLWEInfos, GLWESecretToRef, GLWEToLWEKey, GLWEToLWEKeyLayout, + GLWEToLWEKeyPreparedFactory, LWE, LWEInfos, LWESecretToRef, prepared::GLWEToLWEKeyPrepared, }, }; use poulpy_hal::{ @@ -46,7 +46,7 @@ where BRA: BlindRotationAlgo, { cbt: CircuitBootstrappingKey, - ks: GLWEToLWESwitchingKey, + ks: GLWEToLWEKey, } impl BDDKey, BRA> @@ -59,7 +59,7 @@ where { Self { cbt: CircuitBootstrappingKey::alloc_from_infos(&infos.cbt_infos()), - ks: GLWEToLWESwitchingKey::alloc_from_infos(&infos.ks_infos()), + ks: GLWEToLWEKey::alloc_from_infos(&infos.ks_infos()), } } } @@ -130,12 +130,12 @@ where BE: Backend, { pub(crate) cbt: CircuitBootstrappingKeyPrepared, - pub(crate) ks: GLWEToLWESwitchingKeyPrepared, + pub(crate) ks: GLWEToLWEKeyPrepared, } pub trait BDDKeyPreparedFactory where - Self: Sized + CircuitBootstrappingKeyPreparedFactory + GLWEToLWESwitchingKeyPreparedFactory, + Self: Sized + CircuitBootstrappingKeyPreparedFactory + GLWEToLWEKeyPreparedFactory, { fn alloc_bdd_key_from_infos(&self, infos: &A) -> BDDKeyPrepared, BRA, BE> where @@ -143,7 +143,7 @@ where { BDDKeyPrepared { cbt: CircuitBootstrappingKeyPrepared::alloc_from_infos(self, &infos.cbt_infos()), - ks: GLWEToLWESwitchingKeyPrepared::alloc_from_infos(self, &infos.ks_infos()), + ks: GLWEToLWEKeyPrepared::alloc_from_infos(self, &infos.ks_infos()), } } @@ -152,7 +152,7 @@ where A: BDDKeyInfos, { self.circuit_bootstrapping_key_prepare_tmp_bytes(&infos.cbt_infos()) - .max(self.prepare_glwe_to_lwe_switching_key_tmp_bytes(&infos.ks_infos())) + .max(self.prepare_glwe_to_lwe_key_tmp_bytes(&infos.ks_infos())) } fn prepare_bdd_key(&self, res: &mut BDDKeyPrepared, other: &BDDKey, scratch: &mut Scratch) @@ -166,7 +166,7 @@ where } } impl BDDKeyPreparedFactory for Module where - Self: Sized + CircuitBootstrappingKeyPreparedFactory + GLWEToLWESwitchingKeyPreparedFactory + Self: Sized + CircuitBootstrappingKeyPreparedFactory + GLWEToLWEKeyPreparedFactory { } diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs index a627e0c..a5adc52 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs @@ -6,7 +6,7 @@ use poulpy_hal::{ }; use poulpy_core::{ - GGSWFromGGLWE, GLWEDecrypt, GLWEPacking, GLWETrace, ScratchTakeCore, + GGSWFromGGLWE, GLWEDecrypt, GLWEPacking, GLWERotate, GLWETrace, ScratchTakeCore, layouts::{ Dsize, GGLWELayout, GGSWInfos, GGSWToMut, GLWEInfos, GLWESecretPreparedFactory, GLWEToMut, GLWEToRef, LWEInfos, LWEToRef, }, @@ -115,7 +115,8 @@ where + GLWEPacking + GGSWFromGGLWE + GLWESecretPreparedFactory - + GLWEDecrypt, + + GLWEDecrypt + + GLWERotate, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchTakeCore, { @@ -216,7 +217,9 @@ pub fn circuit_bootstrap_core( + GLWEPacking + GGSWFromGGLWE + GLWESecretPreparedFactory - + GLWEDecrypt, + + GLWEDecrypt + + GLWERotate + + ModuleLogN, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchTakeCore, { @@ -332,7 +335,7 @@ fn post_process( ) where R: GLWEToMut, A: GLWEToRef, - M: ModuleLogN + GLWETrace + GLWEPacking, + M: ModuleLogN + GLWETrace + GLWEPacking + GLWERotate, Scratch: ScratchTakeCore, { let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs index c6b8adc..de57832 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs @@ -1,8 +1,8 @@ use poulpy_core::{ - Distribution, GLWEAutomorphismKeyEncryptSk, GLWETensorKeyEncryptSk, GetDistribution, ScratchTakeCore, + Distribution, GGLWEToGGSWKeyEncryptSk, GLWEAutomorphismKeyEncryptSk, GetDistribution, ScratchTakeCore, layouts::{ - GGLWEInfos, GGSWInfos, GLWEAutomorphismKey, GLWEAutomorphismKeyLayout, GLWEInfos, GLWESecretPreparedFactory, - GLWESecretToRef, GLWETensorKey, GLWETensorKeyLayout, LWEInfos, LWESecretToRef, prepared::GLWESecretPrepared, + GGLWEInfos, GGLWEToGGSWKey, GGSWInfos, GLWEAutomorphismKey, GLWEAutomorphismKeyLayout, GLWEInfos, + GLWESecretPreparedFactory, GLWESecretToRef, GLWETensorKeyLayout, LWEInfos, LWESecretToRef, prepared::GLWESecretPrepared, }, trace_galois_elements, }; @@ -81,14 +81,14 @@ impl CircuitBootstrappingKey, BRA> { (gal_el, key) }) .collect(), - tsk: GLWETensorKey::alloc_from_infos(trk_infos), + tsk: GGLWEToGGSWKey::alloc_from_infos(trk_infos), } } } pub struct CircuitBootstrappingKey { pub(crate) brk: BlindRotationKey, - pub(crate) tsk: GLWETensorKey>, + pub(crate) tsk: GGLWEToGGSWKey>, pub(crate) atk: HashMap>>, } @@ -112,7 +112,7 @@ impl CircuitBootstrappingKey { impl CircuitBootstrappingKeyEncryptSk for Module where - Self: GLWETensorKeyEncryptSk + Self: GGLWEToGGSWKeyEncryptSk + BlindRotationKeyEncryptSk + GLWEAutomorphismKeyEncryptSk + GLWESecretPreparedFactory, diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/key_prepared.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/key_prepared.rs index 6adca70..c611846 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/key_prepared.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/key_prepared.rs @@ -1,8 +1,8 @@ use poulpy_core::{ layouts::{ - GGLWEInfos, GGSWInfos, GLWEAutomorphismKeyLayout, GLWEAutomorphismKeyPreparedFactory, GLWEInfos, GLWETensorKeyLayout, - GLWETensorKeyPreparedFactory, LWEInfos, - prepared::{GLWEAutomorphismKeyPrepared, GLWETensorKeyPrepared}, + GGLWEInfos, GGLWEToGGSWKeyPrepared, GGLWEToGGSWKeyPreparedFactory, GGSWInfos, GLWEAutomorphismKeyLayout, + GLWEAutomorphismKeyPreparedFactory, GLWEInfos, GLWETensorKeyLayout, GLWETensorKeyPreparedFactory, LWEInfos, + prepared::GLWEAutomorphismKeyPrepared, }, trace_galois_elements, }; @@ -50,7 +50,7 @@ pub trait CircuitBootstrappingKeyPreparedFactory - + GLWETensorKeyPreparedFactory + + GGLWEToGGSWKeyPreparedFactory + GLWEAutomorphismKeyPreparedFactory, { fn circuit_bootstrapping_key_prepared_alloc_from_infos( @@ -65,7 +65,7 @@ where CircuitBootstrappingKeyPrepared { brk: BlindRotationKeyPrepared::alloc(self, &infos.brk_infos()), - tsk: GLWETensorKeyPrepared::alloc_from_infos(self, &infos.tsk_infos()), + tsk: GGLWEToGGSWKeyPrepared::alloc_from_infos(self, &infos.tsk_infos()), atk: gal_els .iter() .map(|&gal_el| { @@ -81,7 +81,7 @@ where A: CircuitBootstrappingKeyInfos, { self.blind_rotation_key_prepare_tmp_bytes(&infos.brk_infos()) - .max(self.prepare_tensor_key_tmp_bytes(&infos.tsk_infos())) + .max(self.prepare_gglwe_to_ggsw_key_tmp_bytes(&infos.tsk_infos())) .max(self.prepare_glwe_automorphism_key_tmp_bytes(&infos.atk_infos())) } @@ -105,7 +105,7 @@ where pub struct CircuitBootstrappingKeyPrepared { pub(crate) brk: BlindRotationKeyPrepared, - pub(crate) tsk: GLWETensorKeyPrepared, B>, + pub(crate) tsk: GGLWEToGGSWKeyPrepared, B>, pub(crate) atk: HashMap, B>>, }