diff --git a/poulpy-core/src/automorphism/glwe_ct.rs b/poulpy-core/src/automorphism/glwe_ct.rs index 8daa416..ff9dfcd 100644 --- a/poulpy-core/src/automorphism/glwe_ct.rs +++ b/poulpy-core/src/automorphism/glwe_ct.rs @@ -27,7 +27,7 @@ impl GLWE { pub fn automorphism(&mut self, module: &M, a: &A, key: &K, scratch: &mut Scratch) where M: GLWEAutomorphism, - A: GLWEToRef, + A: GLWEToRef + GLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, Scratch: ScratchTakeCore, { @@ -37,7 +37,7 @@ impl GLWE { pub fn automorphism_add(&mut self, module: &M, a: &A, key: &K, scratch: &mut Scratch) where M: GLWEAutomorphism, - A: GLWEToRef, + A: GLWEToRef + GLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, Scratch: ScratchTakeCore, { @@ -47,7 +47,7 @@ impl GLWE { pub fn automorphism_sub(&mut self, module: &M, a: &A, key: &K, scratch: &mut Scratch) where M: GLWEAutomorphism, - A: GLWEToRef, + A: GLWEToRef + GLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, Scratch: ScratchTakeCore, { @@ -57,7 +57,7 @@ impl GLWE { pub fn automorphism_sub_negate(&mut self, module: &M, a: &A, key: &K, scratch: &mut Scratch) where M: GLWEAutomorphism, - A: GLWEToRef, + A: GLWEToRef + GLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, Scratch: ScratchTakeCore, { @@ -110,46 +110,46 @@ pub trait GLWEAutomorphism { fn glwe_automorphism(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) where - R: GLWEToMut, - A: GLWEToRef, + R: GLWEToMut + GLWEInfos, + A: GLWEToRef + GLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos; fn glwe_automorphism_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) where - R: GLWEToMut, + R: GLWEToMut + GLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos; fn glwe_automorphism_add(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) where - R: GLWEToMut, - A: GLWEToRef, + R: GLWEToMut + GLWEInfos, + A: GLWEToRef + GLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos; fn glwe_automorphism_add_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) where - R: GLWEToMut, + R: GLWEToMut + GLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos; fn glwe_automorphism_sub(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) where - R: GLWEToMut, - A: GLWEToRef, + R: GLWEToMut + GLWEInfos, + A: GLWEToRef + GLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos; fn glwe_automorphism_sub_negate(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) where - R: GLWEToMut, - A: GLWEToRef, + R: GLWEToMut + GLWEInfos, + A: GLWEToRef + GLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos; fn glwe_automorphism_sub_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) where - R: GLWEToMut, + R: GLWEToMut + GLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos; fn glwe_automorphism_sub_negate_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) where - R: GLWEToMut, + R: GLWEToMut + GLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos; } @@ -179,8 +179,8 @@ where fn glwe_automorphism(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) where - R: GLWEToMut, - A: GLWEToRef, + R: GLWEToMut + GLWEInfos, + A: GLWEToRef + GLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, Scratch: ScratchTakeCore, { @@ -195,7 +195,7 @@ where fn glwe_automorphism_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) where - R: GLWEToMut, + R: GLWEToMut + GLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, Scratch: ScratchTakeCore, { @@ -210,8 +210,8 @@ where fn glwe_automorphism_add(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) where - R: GLWEToMut, - A: GLWEToRef, + R: GLWEToMut + GLWEInfos, + A: GLWEToRef + GLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, Scratch: ScratchTakeCore, { diff --git a/poulpy-core/src/external_product/glwe.rs b/poulpy-core/src/external_product/glwe.rs index 2c9fe12..ef85998 100644 --- a/poulpy-core/src/external_product/glwe.rs +++ b/poulpy-core/src/external_product/glwe.rs @@ -1,10 +1,10 @@ use poulpy_hal::{ api::{ - ModuleN, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftApply, - VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, + ModuleN, ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, - layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft}, + layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnxBig, VecZnxDft}, }; use crate::{ @@ -30,7 +30,7 @@ impl GLWE> { impl GLWE { pub fn external_product(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch) where - A: GLWEToRef, + A: GLWEToRef + GLWEInfos, B: GGSWPreparedToRef + GGSWInfos, M: GLWEExternalProduct, Scratch: ScratchTakeCore, @@ -57,20 +57,14 @@ pub trait GLWEExternalProduct { fn glwe_external_product_inplace(&self, res: &mut R, a: &D, scratch: &mut Scratch) where - R: GLWEToMut, + R: GLWEToMut + GLWEInfos, D: GGSWPreparedToRef + GGSWInfos, Scratch: ScratchTakeCore; fn glwe_external_product(&self, res: &mut R, lhs: &A, rhs: &D, scratch: &mut Scratch) where - R: GLWEToMut, - A: GLWEToRef, - D: GGSWPreparedToRef + GGSWInfos, - Scratch: ScratchTakeCore; - fn glwe_external_product_add(&self, res: &mut R, lhs: &A, rhs: &D, scratch: &mut Scratch) - where - R: GLWEToMut, - A: GLWEToRef, + R: GLWEToMut + GLWEInfos, + A: GLWEToRef + GLWEInfos, D: GGSWPreparedToRef + GGSWInfos, Scratch: ScratchTakeCore; } @@ -84,168 +78,113 @@ where + VecZnxBigAddSmallInplace + GLWENormalize, { - fn glwe_external_product_tmp_bytes(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize + fn glwe_external_product_tmp_bytes(&self, res: &R, a: &A, ggsw: &B) -> usize where R: GLWEInfos, A: GLWEInfos, B: GGSWInfos, { - let res_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), b_infos.size()); - res_dft - + self - .glwe_external_product_internal_tmp_bytes(res_infos, a_infos, b_infos) - .max(self.vec_znx_big_normalize_tmp_bytes()) + let cols: usize = res.rank().as_usize() + 1; + let size: usize = if a.base2k() != ggsw.base2k() { + let a_conv_infos = &GLWELayout { + n: a.n(), + base2k: ggsw.base2k(), + k: a.k(), + rank: a.rank(), + }; + self.glwe_external_product_internal_tmp_bytes(res, a_conv_infos, ggsw) + GLWE::bytes_of_from_infos(a_conv_infos) + } else { + self.glwe_external_product_internal_tmp_bytes(res, a, ggsw) + }; + + size.max(self.vec_znx_big_normalize_tmp_bytes()) + self.bytes_of_vec_znx_dft(cols, ggsw.size()) } - fn glwe_external_product_inplace(&self, res: &mut R, a: &D, scratch: &mut Scratch) + fn glwe_external_product_inplace(&self, res: &mut R, ggsw: &D, scratch: &mut Scratch) where - R: GLWEToMut, + R: GLWEToMut + GLWEInfos, D: GGSWPreparedToRef + GGSWInfos, Scratch: ScratchTakeCore, { - let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); - let rhs: &GGSWPrepared<&[u8], BE> = &a.to_ref(); + assert_eq!(ggsw.rank(), res.rank()); + assert_eq!(ggsw.n(), res.n()); + assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, res, ggsw)); - let basek_in: usize = res.base2k().into(); - let basek_ggsw: usize = rhs.base2k().into(); + let base2k_res: usize = res.base2k().as_usize(); + let base2k_ggsw: usize = ggsw.base2k().as_usize(); - #[cfg(debug_assertions)] - { - use poulpy_hal::api::ScratchAvailable; + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), ggsw.size()); // Todo optimise - assert_eq!(rhs.rank(), res.rank()); - assert_eq!(rhs.n(), res.n()); - assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, res, rhs)); - } - - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), a.size()); // Todo optimise - let res_big = self.glwe_external_product_internal(res_dft, res, a, scratch_1); - for j in 0..(res.rank() + 1).into() { - self.vec_znx_big_normalize( - basek_in, - &mut res.data, - j, - basek_ggsw, - &res_big, - j, - scratch_1, - ); - } - } - - fn glwe_external_product(&self, res: &mut R, lhs: &A, rhs: &D, scratch: &mut Scratch) - where - R: GLWEToMut, - A: GLWEToRef, - D: GGSWPreparedToRef, - Scratch: ScratchTakeCore, - { - let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); - let lhs: &GLWE<&[u8]> = &lhs.to_ref(); - - let rhs: &GGSWPrepared<&[u8], BE> = &rhs.to_ref(); - - let basek_ggsw: usize = rhs.base2k().into(); - let basek_out: usize = res.base2k().into(); - - #[cfg(debug_assertions)] - { - use poulpy_hal::api::ScratchAvailable; - - assert_eq!(rhs.rank(), lhs.rank()); - assert_eq!(rhs.rank(), res.rank()); - assert_eq!(rhs.n(), res.n()); - assert_eq!(lhs.n(), res.n()); - assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, lhs, rhs)); - } - - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), rhs.size()); // Todo optimise - let res_big = self.glwe_external_product_internal(res_dft, lhs, rhs, scratch_1); - - for j in 0..(res.rank() + 1).into() { - self.vec_znx_big_normalize( - basek_out, - &mut res.data, - j, - basek_ggsw, - &res_big, - j, - scratch_1, - ); - } - } - - fn glwe_external_product_add(&self, res: &mut R, a: &A, key: &D, scratch: &mut Scratch) - where - R: GLWEToMut, - A: GLWEToRef, - D: GGSWPreparedToRef, - Scratch: ScratchTakeCore, - { - let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); - let a: &GLWE<&[u8]> = &a.to_ref(); - let key: &GGSWPrepared<&[u8], BE> = &key.to_ref(); - - assert_eq!(a.base2k(), res.base2k()); - - let res_base2k: usize = res.base2k().into(); - let key_base2k: usize = key.base2k().into(); - - #[cfg(debug_assertions)] - { - use poulpy_hal::api::ScratchAvailable; - - assert_eq!(key.rank(), a.rank()); - assert_eq!(key.rank(), res.rank()); - assert_eq!(key.n(), res.n()); - assert_eq!(a.n(), res.n()); - assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, a, key)); - } - - if res_base2k == key_base2k { - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // Todo optimise - let mut res_big = self.glwe_external_product_internal(res_dft, a, key, scratch_1); - for j in 0..(res.rank() + 1).into() { - self.vec_znx_big_add_small_inplace(&mut res_big, j, res.data(), j); - self.vec_znx_big_normalize( - res_base2k, - &mut res.data, - j, - key_base2k, - &res_big, - j, - scratch_1, - ); - } - } else { - let (mut a_conv, scratch_1) = scratch.take_glwe(&GLWELayout { - n: a.n(), - base2k: key.base2k(), - k: a.k(), - rank: a.rank(), - }); + let res_big: VecZnxBig<&mut [u8], BE> = if base2k_res != base2k_ggsw { let (mut res_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout { n: res.n(), - base2k: key.base2k(), + base2k: ggsw.base2k(), k: res.k(), rank: res.rank(), }); - self.glwe_normalize(&mut a_conv, a, scratch_2); self.glwe_normalize(&mut res_conv, res, scratch_2); - let (res_dft, scratch_2) = scratch_2.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // Todo optimise - let mut res_big = self.glwe_external_product_internal(res_dft, &a_conv, key, scratch_2); - for j in 0..(res.rank() + 1).into() { - self.vec_znx_big_add_small_inplace(&mut res_big, j, res_conv.data(), j); - self.vec_znx_big_normalize( - res_base2k, - &mut res.data, - j, - key_base2k, - &res_big, - j, - scratch_2, - ); - } + self.glwe_external_product_internal(res_dft, &res_conv, ggsw, scratch_2) + } else { + self.glwe_external_product_internal(res_dft, res, ggsw, scratch_1) + }; + + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + for j in 0..(res.rank() + 1).into() { + self.vec_znx_big_normalize( + base2k_res, + res.data_mut(), + j, + base2k_ggsw, + &res_big, + j, + scratch_1, + ); + } + } + + fn glwe_external_product(&self, res: &mut R, a: &A, ggsw: &G, scratch: &mut Scratch) + where + R: GLWEToMut + GLWEInfos, + A: GLWEToRef + GLWEInfos, + G: GGSWPreparedToRef + GGSWInfos, + Scratch: ScratchTakeCore, + { + assert_eq!(ggsw.rank(), a.rank()); + assert_eq!(ggsw.rank(), res.rank()); + assert_eq!(ggsw.n(), res.n()); + assert_eq!(a.n(), res.n()); + assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, a, ggsw)); + + let base2k_a: usize = a.base2k().into(); + let base2k_ggsw: usize = ggsw.base2k().into(); + let base2k_res: usize = res.base2k().into(); + + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), ggsw.size()); // Todo optimise + + let res_big: VecZnxBig<&mut [u8], BE> = if base2k_a != base2k_ggsw { + let (mut a_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout { + n: a.n(), + base2k: ggsw.base2k(), + k: a.k(), + rank: a.rank(), + }); + self.glwe_normalize(&mut a_conv, a, scratch_2); + self.glwe_external_product_internal(res_dft, &a_conv, ggsw, scratch_2) + } else { + self.glwe_external_product_internal(res_dft, a, ggsw, scratch_1) + }; + + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + for j in 0..(res.rank() + 1).into() { + self.vec_znx_big_normalize( + base2k_res, + res.data_mut(), + j, + base2k_ggsw, + &res_big, + j, + scratch_1, + ); } } } @@ -309,12 +248,7 @@ where ); let normalize_big: usize = self.vec_znx_normalize_tmp_bytes(); - if a_infos.base2k() == b_infos.base2k() { - a_dft + (vmp | normalize_big) - } else { - let normalize_conv: usize = VecZnx::bytes_of(self.n(), (b_infos.rank() + 1).into(), in_size); - (a_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big - } + a_dft + vmp.max(normalize_big) } fn glwe_external_product_internal( @@ -333,69 +267,36 @@ where let a: &GLWE<&[u8]> = &a.to_ref(); let ggsw: &GGSWPrepared<&[u8], BE> = &ggsw.to_ref(); - let basek_in: usize = a.base2k().into(); - let basek_ggsw: usize = ggsw.base2k().into(); + assert_eq!(a.base2k(), ggsw.base2k()); let cols: usize = (ggsw.rank() + 1).into(); let dsize: usize = ggsw.dsize().into(); - let a_size: usize = (a.size() * basek_in).div_ceil(basek_ggsw); + let a_size: usize = a.size(); let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize)); a_dft.data_mut().fill(0); - if basek_in == basek_ggsw { - for di in 0..dsize { - // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize) - a_dft.set_size((a.size() + di) / dsize); + for di in 0..dsize { + // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize) + a_dft.set_size((a.size() + di) / dsize); - // Small optimization for dsize > 2 - // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then - // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}. - // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last dsize-1 limbs, but this introduce - // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same - // noise is kept with respect to the ideal functionality. - res_dft.set_size(ggsw.size() - ((dsize - di) as isize - 2).max(0) as usize); - - for j in 0..cols { - self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &a.data, j); - } - - if di == 0 { - self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &ggsw.data, scratch_1); - } else { - self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &ggsw.data, di, scratch_1); - } - } - } else { - let (mut a_conv, scratch_3) = scratch_1.take_vec_znx(self.n(), cols, a_size); + // Small optimization for dsize > 2 + // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then + // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}. + // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. + // It is possible to further ignore the last dsize-1 limbs, but this introduce + // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same + // noise is kept with respect to the ideal functionality. + res_dft.set_size(ggsw.size() - ((dsize - di) as isize - 2).max(0) as usize); for j in 0..cols { - self.vec_znx_normalize(basek_ggsw, &mut a_conv, j, basek_in, &a.data, j, scratch_3); + self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &a.data, j); } - for di in 0..dsize { - // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize) - a_dft.set_size((a.size() + di) / dsize); - - // Small optimization for dsize > 2 - // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then - // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}. - // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last dsize-1 limbs, but this introduce - // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same - // noise is kept with respect to the ideal functionality. - res_dft.set_size(ggsw.size() - ((dsize - di) as isize - 2).max(0) as usize); - - for j in 0..cols { - self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &a.data, j); - } - - if di == 0 { - self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &ggsw.data, scratch_1); - } else { - self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &ggsw.data, di, scratch_1); - } + if di == 0 { + self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &ggsw.data, scratch_1); + } else { + self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &ggsw.data, di, scratch_1); } } diff --git a/poulpy-core/src/keyswitching/ggsw.rs b/poulpy-core/src/keyswitching/ggsw.rs index 3a5efb3..b36644d 100644 --- a/poulpy-core/src/keyswitching/ggsw.rs +++ b/poulpy-core/src/keyswitching/ggsw.rs @@ -29,7 +29,7 @@ impl GGSW { pub fn keyswitch(&mut self, module: &M, a: &A, key: &K, tsk: &T, scratch: &mut Scratch) where A: GGSWToRef, - K: GGLWEPreparedToRef, + K: GGLWEPreparedToRef + GGLWEInfos, T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore, M: GGSWKeyswitch, @@ -39,7 +39,7 @@ impl GGSW { pub fn keyswitch_inplace(&mut self, module: &M, key: &K, tsk: &T, scratch: &mut Scratch) where - K: GGLWEPreparedToRef, + K: GGLWEPreparedToRef + GGLWEInfos, T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore, M: GGSWKeyswitch, @@ -70,7 +70,7 @@ where fn ggsw_keyswitch_inplace(&self, res: &mut R, key: &K, tsk: &T, scratch: &mut Scratch) where R: GGSWToMut, - K: GGLWEPreparedToRef, + K: GGLWEPreparedToRef + GGLWEInfos, T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore, { @@ -89,7 +89,7 @@ where where R: GGSWToMut, A: GGSWToRef, - K: GGLWEPreparedToRef, + K: GGLWEPreparedToRef + GGLWEInfos, T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore, { @@ -125,14 +125,14 @@ where where R: GGSWToMut, A: GGSWToRef, - K: GGLWEPreparedToRef, + K: GGLWEPreparedToRef + GGLWEInfos, T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore; fn ggsw_keyswitch_inplace(&self, res: &mut R, key: &K, tsk: &T, scratch: &mut Scratch) where R: GGSWToMut, - K: GGLWEPreparedToRef, + K: GGLWEPreparedToRef + GGLWEInfos, T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore; } diff --git a/poulpy-core/src/keyswitching/glwe.rs b/poulpy-core/src/keyswitching/glwe.rs index 5bb5875..5fe298e 100644 --- a/poulpy-core/src/keyswitching/glwe.rs +++ b/poulpy-core/src/keyswitching/glwe.rs @@ -27,8 +27,8 @@ impl GLWE> { impl GLWE { pub fn keyswitch(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch) where - A: GLWEToRef, - B: GGLWEPreparedToRef, + A: GLWEToRef + GLWEInfos, + B: GGLWEPreparedToRef + GGLWEInfos, M: GLWEKeyswitch, Scratch: ScratchTakeCore, { @@ -37,7 +37,7 @@ impl GLWE { pub fn keyswitch_inplace(&mut self, module: &M, a: &A, scratch: &mut Scratch) where - A: GGLWEPreparedToRef, + A: GGLWEPreparedToRef + GGLWEInfos, M: GLWEKeyswitch, Scratch: ScratchTakeCore, { @@ -74,14 +74,10 @@ where fn glwe_keyswitch(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) where - R: GLWEToMut, - A: GLWEToRef, - K: GGLWEPreparedToRef, + R: GLWEToMut + GLWEInfos, + A: GLWEToRef + GLWEInfos, + K: GGLWEPreparedToRef + GGLWEInfos, { - let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); - let a: &GLWE<&[u8]> = &a.to_ref(); - let key: &GGLWEPrepared<&[u8], BE> = &key.to_ref(); - assert_eq!( a.rank(), key.rank_in(), @@ -128,10 +124,11 @@ where self.glwe_keyswitch_internal(res_dft, a, key, scratch_1) }; + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); for i in 0..(res.rank() + 1).into() { self.vec_znx_big_normalize( base2k_res, - &mut res.data, + res.data_mut(), i, base2k_key, &res_big, @@ -143,12 +140,9 @@ where fn glwe_keyswitch_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) where - R: GLWEToMut, - K: GGLWEPreparedToRef, + R: GLWEToMut + GLWEInfos, + K: GGLWEPreparedToRef + GGLWEInfos, { - let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); - let key: &GGLWEPrepared<&[u8], BE> = &key.to_ref(); - assert_eq!( res.rank(), key.rank_in(), @@ -194,6 +188,7 @@ where self.glwe_keyswitch_internal(res_dft, res, key, scratch_1) }; + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); for i in 0..(res.rank() + 1).into() { self.vec_znx_big_normalize( base2k_res, @@ -217,14 +212,14 @@ pub trait GLWEKeyswitch { fn glwe_keyswitch(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) where - R: GLWEToMut, - A: GLWEToRef, - K: GGLWEPreparedToRef; + R: GLWEToMut + GLWEInfos, + A: GLWEToRef + GLWEInfos, + K: GGLWEPreparedToRef + GGLWEInfos; fn glwe_keyswitch_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) where - R: GLWEToMut, - K: GGLWEPreparedToRef; + R: GLWEToMut + GLWEInfos, + K: GGLWEPreparedToRef + GGLWEInfos; } impl GLWEKeySwitchInternal for Module where diff --git a/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs b/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs index 5b43f44..0fb2299 100644 --- a/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs @@ -160,7 +160,7 @@ where let k_out: usize = 102; let max_dsize: usize = k_out.div_ceil(base2k_key); - let p = -5; + let p: i64 = -5; for rank in 1_usize..3 { for dsize in 1..max_dsize + 1 { let k_ksk: usize = k_out + base2k_key * dsize; diff --git a/poulpy-core/src/tests/test_suite/external_product/gglwe_ksk.rs b/poulpy-core/src/tests/test_suite/external_product/gglwe_ksk.rs index 07a0926..f8dd1b5 100644 --- a/poulpy-core/src/tests/test_suite/external_product/gglwe_ksk.rs +++ b/poulpy-core/src/tests/test_suite/external_product/gglwe_ksk.rs @@ -27,24 +27,28 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_in: usize = 60; - let dsize: usize = k_in.div_ceil(base2k); + let base2k_in: usize = 17; + let base2k_key: usize = 13; + let base2k_out: usize = base2k_in; // MUST BE SAME + let k_in: usize = 102; + let max_dsize: usize = k_in.div_ceil(base2k_key); + for rank_in in 1_usize..3 { for rank_out in 1_usize..3 { - for di in 1_usize..dsize + 1 { - let k_ggsw: usize = k_in + base2k * di; + for dsize in 1_usize..max_dsize + 1 { + let k_ggsw: usize = k_in + base2k_key * dsize; let k_out: usize = k_in; // Better capture noise. let n: usize = module.n(); - let dnum: usize = k_in.div_ceil(base2k * di); + let dnum_in: usize = k_in / base2k_in; + let dnum: usize = k_in.div_ceil(base2k_key * dsize); let dsize_in: usize = 1; let gglwe_in_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_in.into(), k: k_in.into(), - dnum: dnum.into(), + dnum: dnum_in.into(), dsize: dsize_in.into(), rank_in: rank_in.into(), rank_out: rank_out.into(), @@ -52,9 +56,9 @@ where let gglwe_out_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), - dnum: dnum.into(), + dnum: dnum_in.into(), dsize: dsize_in.into(), rank_in: rank_in.into(), rank_out: rank_out.into(), @@ -62,10 +66,10 @@ where let ggsw_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_ggsw.into(), dnum: dnum.into(), - dsize: di.into(), + dsize: dsize.into(), rank: rank_out.into(), }; @@ -143,7 +147,7 @@ where let max_noise: f64 = noise_ggsw_product( n as f64, - base2k * di, + base2k_key * dsize, var_xs, var_msg, var_a0_err, @@ -176,24 +180,27 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_out: usize = 60; - let dsize: usize = k_out.div_ceil(base2k); + let base2k_out: usize = 17; + let base2k_key: usize = 13; + let k_out: usize = 102; + let max_dsize: usize = k_out.div_ceil(base2k_key); + for rank_in in 1_usize..3 { for rank_out in 1_usize..3 { - for di in 1_usize..dsize + 1 { - let k_ggsw: usize = k_out + base2k * di; + for dsize in 1_usize..max_dsize + 1 { + let k_ggsw: usize = k_out + base2k_key * dsize; let n: usize = module.n(); - let dnum: usize = k_out.div_ceil(base2k * di); + let dnum_in: usize = k_out / base2k_out; + let dnum: usize = k_out.div_ceil(base2k_key * dsize); let dsize_in: usize = 1; let gglwe_out_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), - dnum: dnum.into(), + dnum: dnum_in.into(), dsize: dsize_in.into(), rank_in: rank_in.into(), rank_out: rank_out.into(), @@ -201,10 +208,10 @@ where let ggsw_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_ggsw.into(), dnum: dnum.into(), - dsize: di.into(), + dsize: dsize.into(), rank: rank_out.into(), }; @@ -281,7 +288,7 @@ where let max_noise: f64 = noise_ggsw_product( n as f64, - base2k * di, + base2k_key * dsize, var_xs, var_msg, var_a0_err, diff --git a/poulpy-core/src/tests/test_suite/external_product/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/external_product/ggsw_ct.rs index b455b70..6c9a3a9 100644 --- a/poulpy-core/src/tests/test_suite/external_product/ggsw_ct.rs +++ b/poulpy-core/src/tests/test_suite/external_product/ggsw_ct.rs @@ -26,23 +26,26 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_in: usize = 60; - let dsize: usize = k_in.div_ceil(base2k); + let base2k_in: usize = 17; + let base2k_key: usize = 13; + let base2k_out: usize = base2k_in; // MUST BE SAME + let k_in: usize = 102; + let max_dsize: usize = k_in.div_ceil(base2k_key); + for rank in 1_usize..3 { - for di in 1..dsize + 1 { - let k_apply: usize = k_in + base2k * di; + for dsize in 1..max_dsize + 1 { + let k_apply: usize = k_in + base2k_key * dsize; let k_out: usize = k_in; // Better capture noise. let n: usize = module.n(); - let dnum: usize = k_in.div_ceil(base2k * di); - let dnum_in: usize = k_in.div_euclid(base2k * di); + let dnum: usize = k_in.div_ceil(base2k_key * dsize); + let dnum_in: usize = k_in / base2k_in; let dsize_in: usize = 1; let ggsw_in_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_in.into(), k: k_in.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), @@ -51,7 +54,7 @@ where let ggsw_out_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), @@ -60,10 +63,10 @@ where let ggsw_apply_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_apply.into(), dnum: dnum.into(), - dsize: di.into(), + dsize: dsize.into(), rank: rank.into(), }; @@ -130,7 +133,7 @@ where let max_noise = |_col_j: usize| -> f64 { noise_ggsw_product( n as f64, - base2k * di, + base2k_key * dsize, 0.5, var_msg, var_a0_err, @@ -160,21 +163,23 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_out: usize = 60; - let dsize: usize = k_out.div_ceil(base2k); + let base2k_out: usize = 17; + let base2k_key: usize = 13; + let k_out: usize = 102; + let max_dsize: usize = k_out.div_ceil(base2k_key); + for rank in 1_usize..3 { - for di in 1..dsize + 1 { - let k_apply: usize = k_out + base2k * di; + for dsize in 1..max_dsize + 1 { + let k_apply: usize = k_out + base2k_key * dsize; let n: usize = module.n(); - let dnum: usize = k_out.div_ceil(di * base2k); - let dnum_in: usize = k_out.div_euclid(base2k * di); + let dnum: usize = k_out.div_ceil(dsize * base2k_key); + let dnum_in: usize = k_out / base2k_out; let dsize_in: usize = 1; let ggsw_out_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), @@ -183,10 +188,10 @@ where let ggsw_apply_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_apply.into(), dnum: dnum.into(), - dsize: di.into(), + dsize: dsize.into(), rank: rank.into(), }; @@ -253,7 +258,7 @@ where let max_noise = |_col_j: usize| -> f64 { noise_ggsw_product( n as f64, - base2k * di, + base2k_key * dsize, 0.5, var_msg, var_a0_err, diff --git a/poulpy-core/src/tests/test_suite/external_product/glwe_ct.rs b/poulpy-core/src/tests/test_suite/external_product/glwe_ct.rs index 0425d35..e071be9 100644 --- a/poulpy-core/src/tests/test_suite/external_product/glwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/external_product/glwe_ct.rs @@ -5,7 +5,7 @@ use poulpy_hal::{ }; use crate::{ - GGSWEncryptSk, GLWEEncryptSk, GLWEExternalProduct, GLWENoise, ScratchTakeCore, + GGSWEncryptSk, GLWEEncryptSk, GLWEExternalProduct, GLWENoise, GLWENormalize, ScratchTakeCore, encryption::SIGMA, layouts::{ GGSW, GGSWLayout, GGSWPreparedFactory, GLWE, GLWELayout, GLWEPlaintext, GLWESecret, GLWESecretPreparedFactory, @@ -24,41 +24,44 @@ where + GLWEEncryptSk + GLWENoise + VecZnxRotateInplace - + GLWESecretPreparedFactory, + + GLWESecretPreparedFactory + + GLWENormalize, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_in: usize = 45; - let dsize: usize = k_in.div_ceil(base2k); + let base2k_in: usize = 17; + let base2k_key: usize = 13; + let base2k_out: usize = 15; + let k_in: usize = 102; + let max_dsize: usize = k_in.div_ceil(base2k_key); for rank in 1_usize..3 { - for di in 1..dsize + 1 { - let k_ggsw: usize = k_in + base2k * di; + for dsize in 1..max_dsize + 1 { + let k_ggsw: usize = k_in + base2k_key * dsize; let k_out: usize = k_ggsw; // Better capture noise let n: usize = module.n(); - let dnum: usize = k_in.div_ceil(base2k * dsize); + let dnum: usize = k_in.div_ceil(k_ggsw * dsize); let glwe_in_infos: GLWELayout = GLWELayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_in.into(), k: k_in.into(), rank: rank.into(), }; let glwe_out_infos: GLWELayout = GLWELayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), rank: rank.into(), }; let ggsw_apply_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_ggsw.into(), dnum: dnum.into(), - dsize: di.into(), + dsize: dsize.into(), rank: rank.into(), }; @@ -66,16 +69,17 @@ where let mut glwe_in: GLWE> = GLWE::alloc_from_infos(&glwe_in_infos); let mut glwe_out: GLWE> = GLWE::alloc_from_infos(&glwe_out_infos); let mut pt_ggsw: ScalarZnx> = ScalarZnx::alloc(n, 1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_in_infos); + let mut pt_in: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_in_infos); + let mut pt_out: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); // Random input plaintext - module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); + module.vec_znx_fill_uniform(base2k_in, &mut pt_in.data, 0, &mut source_xa); - pt_want.data.at_mut(0, 0)[1] = 1; + pt_in.data.at_mut(0, 0)[1] = 1; let k: usize = 1; @@ -104,7 +108,7 @@ where glwe_in.encrypt_sk( module, - &pt_want, + &pt_in, &sk_prepared, &mut source_xa, &mut source_xe, @@ -116,7 +120,9 @@ where glwe_out.external_product(module, &glwe_in, &ct_ggsw_prepared, scratch.borrow()); - module.vec_znx_rotate_inplace(k as i64, &mut pt_want.data, 0, scratch.borrow()); + module.vec_znx_rotate_inplace(k as i64, &mut pt_in.data, 0, scratch.borrow()); + + module.glwe_normalize(&mut pt_out, &pt_in, scratch.borrow()); let var_gct_err_lhs: f64 = SIGMA * SIGMA; let var_gct_err_rhs: f64 = 0f64; @@ -127,7 +133,7 @@ where let max_noise: f64 = noise_ggsw_product( n as f64, - base2k * dsize, + base2k_key * max_dsize, 0.5, var_msg, var_a0_err, @@ -139,7 +145,7 @@ where k_ggsw, ); - glwe_out.assert_noise(module, &sk_prepared, &pt_want, max_noise + 0.5); + glwe_out.assert_noise(module, &sk_prepared, &pt_out, max_noise + 0.5); } } } @@ -158,29 +164,31 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_out: usize = 60; - let dsize: usize = k_out.div_ceil(base2k); + let base2k_out: usize = 17; + let base2k_key: usize = 13; + let k_out: usize = 102; + let max_dsize: usize = k_out.div_ceil(base2k_key); + for rank in 1_usize..3 { - for di in 1..dsize + 1 { - let k_ggsw: usize = k_out + base2k * di; + for dsize in 1..max_dsize + 1 { + let k_ggsw: usize = k_out + base2k_key * dsize; let n: usize = module.n(); - let dnum: usize = k_out.div_ceil(base2k * dsize); + let dnum: usize = k_out.div_ceil(base2k_out * max_dsize); let glwe_out_infos: GLWELayout = GLWELayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), rank: rank.into(), }; let ggsw_apply_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_ggsw.into(), dnum: dnum.into(), - dsize: di.into(), + dsize: dsize.into(), rank: rank.into(), }; @@ -194,7 +202,7 @@ where let mut source_xa: Source = Source::new([0u8; 32]); // Random input plaintext - module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); + module.vec_znx_fill_uniform(base2k_out, &mut pt_want.data, 0, &mut source_xa); pt_want.data.at_mut(0, 0)[1] = 1; @@ -248,7 +256,7 @@ where let max_noise: f64 = noise_ggsw_product( n as f64, - base2k * dsize, + base2k_key * max_dsize, 0.5, var_msg, var_a0_err, diff --git a/poulpy-hal/src/reference/vec_znx/normalize.rs b/poulpy-hal/src/reference/vec_znx/normalize.rs index 6392f84..139c8a5 100644 --- a/poulpy-hal/src/reference/vec_znx/normalize.rs +++ b/poulpy-hal/src/reference/vec_znx/normalize.rs @@ -53,7 +53,7 @@ pub fn vec_znx_normalize( let res_size: usize = res.size(); let a_size: usize = a.size(); - let carry = &mut carry[..2 * n]; + let carry: &mut [i64] = &mut carry[..2 * n]; if res_base2k == a_base2k { if a_size > res_size {