diff --git a/poulpy-core/src/automorphism/gglwe_atk.rs b/poulpy-core/src/automorphism/gglwe_atk.rs index 9650aa2..db57e60 100644 --- a/poulpy-core/src/automorphism/gglwe_atk.rs +++ b/poulpy-core/src/automorphism/gglwe_atk.rs @@ -1,198 +1,169 @@ use poulpy_hal::{ - api::{ - ScratchAvailable, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, - VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, - }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero}, + api::VecZnxAutomorphism, + layouts::{Backend, DataMut, GaloisElement, Module, Scratch}, }; -use crate::layouts::{AutomorphismKey, GGLWEInfos, GLWE, prepared::AutomorphismKeyPrepared}; +use crate::{ + ScratchTakeCore, + automorphism::glwe_ct::GLWEAutomorphism, + layouts::{ + AutomorphismKey, AutomorphismKeyToMut, AutomorphismKeyToRef, GGLWEInfos, GLWE, GLWEInfos, + prepared::{ + AutomorphismKeyPrepared, AutomorphismKeyPreparedToRef, GetAutomorphismGaloisElement, SetAutomorphismGaloisElement, + }, + }, +}; impl AutomorphismKey> { - pub fn automorphism_tmp_bytes( - module: &Module, - out_infos: &OUT, - in_infos: &IN, - key_infos: &KEY, - ) -> usize + pub fn automorphism_tmp_bytes(module: &M, res_infos: &R, a_infos: &A, key_infos: &K) -> usize where - OUT: GGLWEInfos, - IN: GGLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + R: GGLWEInfos, + A: GGLWEInfos, + K: GGLWEInfos, + M: AutomorphismKeyAutomorphism, { - GLWE::keyswitch_tmp_bytes( - module, - &out_infos.glwe_layout(), - &in_infos.glwe_layout(), - key_infos, - ) - } - - pub fn automorphism_inplace_tmp_bytes(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize - where - OUT: GGLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, - { - AutomorphismKey::automorphism_tmp_bytes(module, out_infos, out_infos, key_infos) + module.automorphism_key_automorphism_tmp_bytes(res_infos, a_infos, key_infos) } } impl AutomorphismKey { - pub fn automorphism( - &mut self, - module: &Module, - lhs: &AutomorphismKey, - rhs: &AutomorphismKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxAutomorphism - + VecZnxAutomorphismInplace - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable, + pub fn automorphism(&mut self, module: &M, a: &A, key: &K, scratch: &mut Scratch) + where + A: AutomorphismKeyToRef + GetAutomorphismGaloisElement, + K: AutomorphismKeyPreparedToRef + GetAutomorphismGaloisElement, + Scratch: ScratchTakeCore, + M: AutomorphismKeyAutomorphism, { - #[cfg(debug_assertions)] - { - use crate::layouts::LWEInfos; - - assert_eq!( - self.rank_in(), - lhs.rank_in(), - "ksk_out input rank: {} != ksk_in input rank: {}", - self.rank_in(), - lhs.rank_in() - ); - assert_eq!( - self.rank_out(), - rhs.rank_in(), - "ksk_in output rank: {} != ksk_apply input rank: {}", - self.rank_out(), - rhs.rank_in() - ); - assert_eq!( - self.rank_out(), - rhs.rank_out(), - "ksk_out output rank: {} != ksk_apply output rank: {}", - self.rank_out(), - rhs.rank_out() - ); - assert!( - self.k() <= lhs.k(), - "output k={} cannot be greater than input k={}", - self.k(), - lhs.k() - ) - } - - let cols_out: usize = (rhs.rank_out() + 1).into(); - - let p: i64 = lhs.p(); - let p_inv: i64 = module.galois_element_inv(p); - - (0..self.rank_in().into()).for_each(|col_i| { - (0..self.dnum().into()).for_each(|row_j| { - let mut res_ct: GLWE<&mut [u8]> = self.at_mut(row_j, col_i); - let lhs_ct: GLWE<&[u8]> = lhs.at(row_j, col_i); - - // Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) - (0..cols_out).for_each(|i| { - module.vec_znx_automorphism(lhs.p(), &mut res_ct.data, i, &lhs_ct.data, i); - }); - - // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) - res_ct.keyswitch_inplace(module, &rhs.key, scratch); - - // Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a) - (0..cols_out).for_each(|i| { - module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i, scratch); - }); - }); - }); - - (self.dnum().min(lhs.dnum()).into()..self.dnum().into()).for_each(|row_i| { - (0..self.rank_in().into()).for_each(|col_j| { - self.at_mut(row_i, col_j).data.zero(); - }); - }); - - self.p = (lhs.p * rhs.p) % (module.cyclotomic_order() as i64); + module.automorphism_key_automorphism(self, a, key, scratch); } - pub fn automorphism_inplace( - &mut self, - module: &Module, - rhs: &AutomorphismKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxAutomorphism - + VecZnxAutomorphismInplace - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable, + pub fn automorphism_inplace(&mut self, module: &M, key: &K, scratch: &mut Scratch) + where + K: AutomorphismKeyPreparedToRef + GetAutomorphismGaloisElement, + Scratch: ScratchTakeCore, + M: AutomorphismKeyAutomorphism, { - #[cfg(debug_assertions)] - { - assert_eq!( - self.rank_out(), - rhs.rank_in(), - "ksk_in output rank: {} != ksk_apply input rank: {}", - self.rank_out(), - rhs.rank_in() - ); - assert_eq!( - self.rank_out(), - rhs.rank_out(), - "ksk_out output rank: {} != ksk_apply output rank: {}", - self.rank_out(), - rhs.rank_out() - ); - } - - let cols_out: usize = (rhs.rank_out() + 1).into(); - - let p: i64 = self.p(); - let p_inv = module.galois_element_inv(p); - - (0..self.rank_in().into()).for_each(|col_i| { - (0..self.dnum().into()).for_each(|row_j| { - let mut res_ct: GLWE<&mut [u8]> = self.at_mut(row_j, col_i); - - // Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) - (0..cols_out).for_each(|i| { - module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i, scratch); - }); - - // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) - res_ct.keyswitch_inplace(module, &rhs.key, scratch); - - // Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a) - (0..cols_out).for_each(|i| { - module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i, scratch); - }); - }); - }); - - self.p = (self.p * rhs.p) % (module.cyclotomic_order() as i64); + module.automorphism_key_automorphism_inplace(self, key, scratch); + } +} + +impl AutomorphismKeyAutomorphism for Module where + Self: GaloisElement + GLWEAutomorphism + VecZnxAutomorphism +{ +} + +pub trait AutomorphismKeyAutomorphism +where + Self: GaloisElement + GLWEAutomorphism + VecZnxAutomorphism, +{ + fn automorphism_key_automorphism_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize + where + R: GGLWEInfos, + A: GGLWEInfos, + K: GGLWEInfos, + { + self.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) + } + + fn automorphism_key_automorphism(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) + where + R: AutomorphismKeyToMut + SetAutomorphismGaloisElement, + A: AutomorphismKeyToRef + GetAutomorphismGaloisElement, + K: AutomorphismKeyPreparedToRef + GetAutomorphismGaloisElement, + Scratch: ScratchTakeCore, + { + { + let res: &mut AutomorphismKey<&mut [u8]> = &mut res.to_mut(); + let a: &AutomorphismKey<&[u8]> = &a.to_ref(); + let key: &AutomorphismKeyPrepared<&[u8], _> = &key.to_ref(); + + assert!( + res.dnum().as_u32() <= a.dnum().as_u32(), + "res dnum: {} > a dnum: {}", + res.dnum(), + a.dnum() + ); + + assert_eq!( + res.dsize(), + a.dsize(), + "res dnum: {} != a dnum: {}", + res.dsize(), + a.dsize() + ); + + let cols_out: usize = (key.rank_out() + 1).into(); + + let p: i64 = a.p(); + let p_inv: i64 = self.galois_element_inv(p); + + for row in 0..res.dnum().as_usize() { + for col in 0..cols_out { + let mut res_tmp: GLWE<&mut [u8]> = res.at_mut(row, col); + let a_ct: GLWE<&[u8]> = a.at(row, col); + + // Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) + for i in 0..cols_out { + self.vec_znx_automorphism(a.p(), res_tmp.data_mut(), i, &a_ct.data, i); + } + + // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) + self.glwe_keyswitch_inplace(&mut res_tmp, &key.key, scratch); + + // Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a) + (0..cols_out).for_each(|i| { + self.vec_znx_automorphism_inplace(p_inv, res_tmp.data_mut(), i, scratch); + }); + } + } + } + + res.set_p((a.p() * key.p()) % (self.cyclotomic_order() as i64)); + } + + fn automorphism_key_automorphism_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) + where + R: AutomorphismKeyToMut + SetAutomorphismGaloisElement + GetAutomorphismGaloisElement, + K: AutomorphismKeyPreparedToRef + GetAutomorphismGaloisElement, + Scratch: ScratchTakeCore, + { + { + let res: &mut AutomorphismKey<&mut [u8]> = &mut res.to_mut(); + let key: &AutomorphismKeyPrepared<&[u8], _> = &key.to_ref(); + + assert_eq!( + res.rank(), + key.rank(), + "key rank: {} != key rank: {}", + res.rank(), + key.rank() + ); + + let cols_out: usize = (key.rank_out() + 1).into(); + + let p: i64 = res.p(); + let p_inv: i64 = self.galois_element_inv(p); + + for row in 0..res.dnum().as_usize() { + for col in 0..cols_out { + let mut res_tmp: GLWE<&mut [u8]> = res.at_mut(row, col); + + // Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) + for i in 0..cols_out { + self.vec_znx_automorphism_inplace(p_inv, res_tmp.data_mut(), i, scratch); + } + + // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) + self.glwe_keyswitch_inplace(&mut res_tmp, &key.key, scratch); + + // Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a) + for i in 0..cols_out { + self.vec_znx_automorphism_inplace(p_inv, res_tmp.data_mut(), i, scratch); + } + } + } + } + + res.set_p((res.p() * key.p()) % (self.cyclotomic_order() as i64)); } } diff --git a/poulpy-core/src/keyswitching/gglwe_ct.rs b/poulpy-core/src/keyswitching/gglwe_ct.rs index 363f394..4b92ed1 100644 --- a/poulpy-core/src/keyswitching/gglwe_ct.rs +++ b/poulpy-core/src/keyswitching/gglwe_ct.rs @@ -16,7 +16,7 @@ impl AutomorphismKey> { R: GGLWEInfos, A: GGLWEInfos, K: GGLWEInfos, - M: GGLWEKeySwitch, + M: GGLWEKeyswitch, { module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) } @@ -28,7 +28,7 @@ impl AutomorphismKey { A: AutomorphismKeyToRef, B: GLWESwitchingKeyPreparedToRef, Scratch: ScratchTakeCore, - M: GGLWEKeySwitch, + M: GGLWEKeyswitch, { module.gglwe_keyswitch(&mut self.key.key, &a.to_ref().key.key, b, scratch); } @@ -37,7 +37,7 @@ impl AutomorphismKey { where A: GLWESwitchingKeyPreparedToRef, Scratch: ScratchTakeCore, - M: GGLWEKeySwitch, + M: GGLWEKeyswitch, { module.gglwe_keyswitch_inplace(&mut self.key.key, a, scratch); } @@ -49,7 +49,7 @@ impl GLWESwitchingKey> { R: GGLWEInfos, A: GGLWEInfos, K: GGLWEInfos, - M: GGLWEKeySwitch, + M: GGLWEKeyswitch, { module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) } @@ -61,7 +61,7 @@ impl GLWESwitchingKey { A: GLWESwitchingKeyToRef, B: GLWESwitchingKeyPreparedToRef, Scratch: ScratchTakeCore, - M: GGLWEKeySwitch, + M: GGLWEKeyswitch, { module.gglwe_keyswitch(&mut self.key, &a.to_ref().key, b, scratch); } @@ -70,7 +70,7 @@ impl GLWESwitchingKey { where A: GLWESwitchingKeyPreparedToRef, Scratch: ScratchTakeCore, - M: GGLWEKeySwitch, + M: GGLWEKeyswitch, { module.gglwe_keyswitch_inplace(&mut self.key, a, scratch); } @@ -82,7 +82,7 @@ impl GGLWE> { R: GGLWEInfos, A: GGLWEInfos, K: GGLWEInfos, - M: GGLWEKeySwitch, + M: GGLWEKeyswitch, { module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) } @@ -94,7 +94,7 @@ impl GGLWE { A: GGLWEToRef, B: GLWESwitchingKeyPreparedToRef, Scratch: ScratchTakeCore, - M: GGLWEKeySwitch, + M: GGLWEKeyswitch, { module.gglwe_keyswitch(self, a, b, scratch); } @@ -103,15 +103,15 @@ impl GGLWE { where A: GLWESwitchingKeyPreparedToRef, Scratch: ScratchTakeCore, - M: GGLWEKeySwitch, + M: GGLWEKeyswitch, { module.gglwe_keyswitch_inplace(self, a, scratch); } } -impl GGLWEKeySwitch for Module where Self: GLWEKeyswitch {} +impl GGLWEKeyswitch for Module where Self: GLWEKeyswitch {} -pub trait GGLWEKeySwitch +pub trait GGLWEKeyswitch where Self: GLWEKeyswitch, { diff --git a/poulpy-core/src/keyswitching/ggsw_ct.rs b/poulpy-core/src/keyswitching/ggsw_ct.rs index c864975..67f4278 100644 --- a/poulpy-core/src/keyswitching/ggsw_ct.rs +++ b/poulpy-core/src/keyswitching/ggsw_ct.rs @@ -22,7 +22,7 @@ impl GGSW> { A: GGSWInfos, K: GGLWEInfos, T: GGLWEInfos, - M: GGSWKeySwitch, + M: GGSWKeyswitch, { module.ggsw_keyswitch_tmp_bytes(res_infos, a_infos, key_infos, tsk_infos) } @@ -35,7 +35,7 @@ impl GGSW { K: GLWESwitchingKeyPreparedToRef, T: TensorKeyPreparedToRef, Scratch: ScratchTakeCore, - M: GGSWKeySwitch, + M: GGSWKeyswitch, { module.ggsw_keyswitch(self, a, key, tsk, scratch); } @@ -45,13 +45,13 @@ impl GGSW { K: GLWESwitchingKeyPreparedToRef, T: TensorKeyPreparedToRef, Scratch: ScratchTakeCore, - M: GGSWKeySwitch, + M: GGSWKeyswitch, { module.ggsw_keyswitch_inplace(self, key, tsk, scratch); } } -pub trait GGSWKeySwitch +pub trait GGSWKeyswitch where Self: GLWEKeyswitch + GGSWExpandRows, { diff --git a/poulpy-core/src/keyswitching/glwe_ct.rs b/poulpy-core/src/keyswitching/glwe_ct.rs index de48cbf..f82a4d1 100644 --- a/poulpy-core/src/keyswitching/glwe_ct.rs +++ b/poulpy-core/src/keyswitching/glwe_ct.rs @@ -16,14 +16,14 @@ use crate::{ }; impl GLWE> { - pub fn keyswitch_tmp_bytes(module: &M, res_infos: &R, a_infos: &A, b_infos: &B) -> usize + pub fn keyswitch_tmp_bytes(module: &M, res_infos: &R, a_infos: &A, key_infos: &B) -> usize where R: GLWEInfos, A: GLWEInfos, B: GGLWEInfos, M: GLWEKeyswitch, { - module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, b_infos) + module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) } } @@ -89,7 +89,7 @@ where + VecZnxNormalize + VecZnxNormalizeTmpBytes, { - fn glwe_keyswitch_tmp_bytes(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize + fn glwe_keyswitch_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &B) -> usize where R: GLWEInfos, A: GLWEInfos, @@ -97,44 +97,44 @@ where { let in_size: usize = a_infos .k() - .div_ceil(b_infos.base2k()) - .div_ceil(b_infos.dsize().into()) as usize; + .div_ceil(key_infos.base2k()) + .div_ceil(key_infos.dsize().into()) as usize; let out_size: usize = res_infos.size(); - let ksk_size: usize = b_infos.size(); - let res_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank_out() + 1).into(), ksk_size); // TODO OPTIMIZE - let ai_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank_in()).into(), in_size); + let ksk_size: usize = key_infos.size(); + let res_dft: usize = self.bytes_of_vec_znx_dft((key_infos.rank_out() + 1).into(), ksk_size); // TODO OPTIMIZE + let ai_dft: usize = self.bytes_of_vec_znx_dft((key_infos.rank_in()).into(), in_size); let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes( out_size, in_size, in_size, - (b_infos.rank_in()).into(), - (b_infos.rank_out() + 1).into(), + (key_infos.rank_in()).into(), + (key_infos.rank_out() + 1).into(), ksk_size, - ) + self.bytes_of_vec_znx_dft((b_infos.rank_in()).into(), in_size); + ) + self.bytes_of_vec_znx_dft((key_infos.rank_in()).into(), in_size); let normalize_big: usize = self.vec_znx_big_normalize_tmp_bytes(); - if a_infos.base2k() == b_infos.base2k() { + if a_infos.base2k() == key_infos.base2k() { res_dft + ((ai_dft + vmp) | normalize_big) - } else if b_infos.dsize() == 1 { + } else if key_infos.dsize() == 1 { // In this case, we only need one column, temporary, that we can drop once a_dft is computed. let normalize_conv: usize = VecZnx::bytes_of(self.n(), 1, in_size) + self.vec_znx_normalize_tmp_bytes(); res_dft + (((ai_dft + normalize_conv) | vmp) | normalize_big) } else { // Since we stride over a to get a_dft when dsize > 1, we need to store the full columns of a with in the base conversion. - let normalize_conv: usize = VecZnx::bytes_of(self.n(), (b_infos.rank_in()).into(), in_size); + let normalize_conv: usize = VecZnx::bytes_of(self.n(), (key_infos.rank_in()).into(), in_size); res_dft + ((ai_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big) } } - fn glwe_keyswitch(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) + fn glwe_keyswitch(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) where R: GLWEToMut, A: GLWEToRef, - B: GLWESwitchingKeyPreparedToRef, + K: GLWESwitchingKeyPreparedToRef, Scratch: ScratchTakeCore, { let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let a: &GLWE<&[u8]> = &a.to_ref(); - let b: &GLWESwitchingKeyPrepared<&[u8], BE> = &b.to_ref(); + let b: &GLWESwitchingKeyPrepared<&[u8], BE> = &key.to_ref(); assert_eq!( a.rank(), @@ -181,14 +181,14 @@ where }) } - fn glwe_keyswitch_inplace(&self, res: &mut R, a: &A, scratch: &mut Scratch) + fn glwe_keyswitch_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) where R: GLWEToMut, - A: GLWESwitchingKeyPreparedToRef, + K: GLWESwitchingKeyPreparedToRef, Scratch: ScratchTakeCore, { let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); - let a: &GLWESwitchingKeyPrepared<&[u8], BE> = &a.to_ref(); + let a: &GLWESwitchingKeyPrepared<&[u8], BE> = &key.to_ref(); assert_eq!( res.rank(), @@ -243,7 +243,7 @@ pub(crate) fn keyswitch_internal( module: &M, mut res: VecZnxDft, a: &GLWE, - b: &GLWESwitchingKeyPrepared, + key: &GLWESwitchingKeyPrepared, scratch: &mut Scratch, ) -> VecZnxBig where @@ -265,12 +265,12 @@ where Scratch: ScratchTakeCore, { let base2k_in: usize = a.base2k().into(); - let base2k_out: usize = b.base2k().into(); + let base2k_out: usize = key.base2k().into(); let cols: usize = (a.rank() + 1).into(); let a_size: usize = (a.size() * base2k_in).div_ceil(base2k_out); - let pmat: &VmpPMat = &b.key.data; + let pmat: &VmpPMat = &key.key.data; - if b.dsize() == 1 { + if key.dsize() == 1 { let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols - 1, a.size()); if base2k_in == base2k_out { @@ -295,7 +295,7 @@ where module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_1); } else { - let dsize: usize = b.dsize().into(); + let dsize: usize = key.dsize().into(); let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols - 1, a_size.div_ceil(dsize)); ai_dft.data_mut().fill(0); diff --git a/poulpy-core/src/layouts/gglwe_atk.rs b/poulpy-core/src/layouts/gglwe_atk.rs index eb93bf4..3c0afc1 100644 --- a/poulpy-core/src/layouts/gglwe_atk.rs +++ b/poulpy-core/src/layouts/gglwe_atk.rs @@ -6,6 +6,7 @@ use poulpy_hal::{ use crate::layouts::{ Base2K, Dnum, Dsize, GGLWEInfos, GLWE, GLWEInfos, GLWESwitchingKey, GLWESwitchingKeyAlloc, GLWESwitchingKeyToMut, GLWESwitchingKeyToRef, LWEInfos, Rank, RingDegree, TorusPrecision, + prepared::{GetAutomorphismGaloisElement, SetAutomorphismGaloisElement}, }; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; @@ -27,6 +28,18 @@ pub struct AutomorphismKey { pub(crate) p: i64, } +impl SetAutomorphismGaloisElement for AutomorphismKey { + fn set_p(&mut self, p: i64) { + self.p = p + } +} + +impl GetAutomorphismGaloisElement for AutomorphismKey { + fn p(&self) -> i64 { + self.p + } +} + impl AutomorphismKey { pub fn p(&self) -> i64 { self.p diff --git a/poulpy-hal/src/layouts/module.rs b/poulpy-hal/src/layouts/module.rs index 61e312c..44fccbc 100644 --- a/poulpy-hal/src/layouts/module.rs +++ b/poulpy-hal/src/layouts/module.rs @@ -2,7 +2,7 @@ use std::{fmt::Display, marker::PhantomData, ptr::NonNull}; use rand_distr::num_traits::Zero; -use crate::GALOISGENERATOR; +use crate::{GALOISGENERATOR, api::ModuleN}; #[allow(clippy::missing_safety_doc)] pub trait Backend: Sized { @@ -75,36 +75,47 @@ impl Module { pub fn log_n(&self) -> usize { (usize::BITS - (self.n() - 1).leading_zeros()) as _ } +} - #[inline] - pub fn cyclotomic_order(&self) -> u64 { +pub trait CyclotomicOrder +where + Self: ModuleN, +{ + fn cyclotomic_order(&self) -> i64 { (self.n() << 1) as _ } +} +impl CyclotomicOrder for Module where Self: ModuleN {} + +pub trait GaloisElement +where + Self: CyclotomicOrder, +{ // Returns GALOISGENERATOR^|generator| * sign(generator) - #[inline] - pub fn galois_element(&self, generator: i64) -> i64 { + fn galois_element(&self, generator: i64) -> i64 { if generator == 0 { return 1; } - ((mod_exp_u64(GALOISGENERATOR, generator.unsigned_abs() as usize) & (self.cyclotomic_order() - 1)) as i64) - * generator.signum() + + let g_exp: u64 = mod_exp_u64(GALOISGENERATOR, generator.unsigned_abs() as usize) & (self.cyclotomic_order() - 1) as u64; + g_exp as i64 * generator.signum() } // Returns gen^-1 - #[inline] - pub fn galois_element_inv(&self, gal_el: i64) -> i64 { + fn galois_element_inv(&self, gal_el: i64) -> i64 { if gal_el == 0 { panic!("cannot invert 0") } - ((mod_exp_u64( - gal_el.unsigned_abs(), - (self.cyclotomic_order() - 1) as usize, - ) & (self.cyclotomic_order() - 1)) as i64) - * gal_el.signum() + + let g_exp: u64 = + mod_exp_u64(GALOISGENERATOR, (self.cyclotomic_order() - 1) as usize) & (self.cyclotomic_order() - 1) as u64; + g_exp as i64 * gal_el.signum() } } +impl GaloisElement for Module where Self: CyclotomicOrder {} + impl Drop for Module { fn drop(&mut self) { unsafe { B::destroy(self.ptr) }