diff --git a/poulpy-core/src/automorphism/gglwe_atk.rs b/poulpy-core/src/automorphism/gglwe_atk.rs index 9650aa2..db57e60 100644 --- a/poulpy-core/src/automorphism/gglwe_atk.rs +++ b/poulpy-core/src/automorphism/gglwe_atk.rs @@ -1,198 +1,169 @@ use poulpy_hal::{ - api::{ - ScratchAvailable, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, - VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, - }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero}, + api::VecZnxAutomorphism, + layouts::{Backend, DataMut, GaloisElement, Module, Scratch}, }; -use crate::layouts::{AutomorphismKey, GGLWEInfos, GLWE, prepared::AutomorphismKeyPrepared}; +use crate::{ + ScratchTakeCore, + automorphism::glwe_ct::GLWEAutomorphism, + layouts::{ + AutomorphismKey, AutomorphismKeyToMut, AutomorphismKeyToRef, GGLWEInfos, GLWE, GLWEInfos, + prepared::{ + AutomorphismKeyPrepared, AutomorphismKeyPreparedToRef, GetAutomorphismGaloisElement, SetAutomorphismGaloisElement, + }, + }, +}; impl AutomorphismKey> { - pub fn automorphism_tmp_bytes( - module: &Module, - out_infos: &OUT, - in_infos: &IN, - key_infos: &KEY, - ) -> usize + pub fn automorphism_tmp_bytes(module: &M, res_infos: &R, a_infos: &A, key_infos: &K) -> usize where - OUT: GGLWEInfos, - IN: GGLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + R: GGLWEInfos, + A: GGLWEInfos, + K: GGLWEInfos, + M: AutomorphismKeyAutomorphism, { - GLWE::keyswitch_tmp_bytes( - module, - &out_infos.glwe_layout(), - &in_infos.glwe_layout(), - key_infos, - ) - } - - pub fn automorphism_inplace_tmp_bytes(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize - where - OUT: GGLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, - { - AutomorphismKey::automorphism_tmp_bytes(module, out_infos, out_infos, key_infos) + module.automorphism_key_automorphism_tmp_bytes(res_infos, a_infos, key_infos) } } impl AutomorphismKey { - pub fn automorphism( - &mut self, - module: &Module, - lhs: &AutomorphismKey, - rhs: &AutomorphismKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxAutomorphism - + VecZnxAutomorphismInplace - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable, + pub fn automorphism(&mut self, module: &M, a: &A, key: &K, scratch: &mut Scratch) + where + A: AutomorphismKeyToRef + GetAutomorphismGaloisElement, + K: AutomorphismKeyPreparedToRef + GetAutomorphismGaloisElement, + Scratch: ScratchTakeCore, + M: AutomorphismKeyAutomorphism, { - #[cfg(debug_assertions)] - { - use crate::layouts::LWEInfos; - - assert_eq!( - self.rank_in(), - lhs.rank_in(), - "ksk_out input rank: {} != ksk_in input rank: {}", - self.rank_in(), - lhs.rank_in() - ); - assert_eq!( - self.rank_out(), - rhs.rank_in(), - "ksk_in output rank: {} != ksk_apply input rank: {}", - self.rank_out(), - rhs.rank_in() - ); - assert_eq!( - self.rank_out(), - rhs.rank_out(), - "ksk_out output rank: {} != ksk_apply output rank: {}", - self.rank_out(), - rhs.rank_out() - ); - assert!( - self.k() <= lhs.k(), - "output k={} cannot be greater than input k={}", - self.k(), - lhs.k() - ) - } - - let cols_out: usize = (rhs.rank_out() + 1).into(); - - let p: i64 = lhs.p(); - let p_inv: i64 = module.galois_element_inv(p); - - (0..self.rank_in().into()).for_each(|col_i| { - (0..self.dnum().into()).for_each(|row_j| { - let mut res_ct: GLWE<&mut [u8]> = self.at_mut(row_j, col_i); - let lhs_ct: GLWE<&[u8]> = lhs.at(row_j, col_i); - - // Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) - (0..cols_out).for_each(|i| { - module.vec_znx_automorphism(lhs.p(), &mut res_ct.data, i, &lhs_ct.data, i); - }); - - // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) - res_ct.keyswitch_inplace(module, &rhs.key, scratch); - - // Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a) - (0..cols_out).for_each(|i| { - module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i, scratch); - }); - }); - }); - - (self.dnum().min(lhs.dnum()).into()..self.dnum().into()).for_each(|row_i| { - (0..self.rank_in().into()).for_each(|col_j| { - self.at_mut(row_i, col_j).data.zero(); - }); - }); - - self.p = (lhs.p * rhs.p) % (module.cyclotomic_order() as i64); + module.automorphism_key_automorphism(self, a, key, scratch); } - pub fn automorphism_inplace( - &mut self, - module: &Module, - rhs: &AutomorphismKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxAutomorphism - + VecZnxAutomorphismInplace - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable, + pub fn automorphism_inplace(&mut self, module: &M, key: &K, scratch: &mut Scratch) + where + K: AutomorphismKeyPreparedToRef + GetAutomorphismGaloisElement, + Scratch: ScratchTakeCore, + M: AutomorphismKeyAutomorphism, { - #[cfg(debug_assertions)] - { - assert_eq!( - self.rank_out(), - rhs.rank_in(), - "ksk_in output rank: {} != ksk_apply input rank: {}", - self.rank_out(), - rhs.rank_in() - ); - assert_eq!( - self.rank_out(), - rhs.rank_out(), - "ksk_out output rank: {} != ksk_apply output rank: {}", - self.rank_out(), - rhs.rank_out() - ); - } - - let cols_out: usize = (rhs.rank_out() + 1).into(); - - let p: i64 = self.p(); - let p_inv = module.galois_element_inv(p); - - (0..self.rank_in().into()).for_each(|col_i| { - (0..self.dnum().into()).for_each(|row_j| { - let mut res_ct: GLWE<&mut [u8]> = self.at_mut(row_j, col_i); - - // Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) - (0..cols_out).for_each(|i| { - module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i, scratch); - }); - - // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) - res_ct.keyswitch_inplace(module, &rhs.key, scratch); - - // Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a) - (0..cols_out).for_each(|i| { - module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i, scratch); - }); - }); - }); - - self.p = (self.p * rhs.p) % (module.cyclotomic_order() as i64); + module.automorphism_key_automorphism_inplace(self, key, scratch); + } +} + +impl AutomorphismKeyAutomorphism for Module where + Self: GaloisElement + GLWEAutomorphism + VecZnxAutomorphism +{ +} + +pub trait AutomorphismKeyAutomorphism +where + Self: GaloisElement + GLWEAutomorphism + VecZnxAutomorphism, +{ + fn automorphism_key_automorphism_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize + where + R: GGLWEInfos, + A: GGLWEInfos, + K: GGLWEInfos, + { + self.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) + } + + fn automorphism_key_automorphism(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) + where + R: AutomorphismKeyToMut + SetAutomorphismGaloisElement, + A: AutomorphismKeyToRef + GetAutomorphismGaloisElement, + K: AutomorphismKeyPreparedToRef + GetAutomorphismGaloisElement, + Scratch: ScratchTakeCore, + { + { + let res: &mut AutomorphismKey<&mut [u8]> = &mut res.to_mut(); + let a: &AutomorphismKey<&[u8]> = &a.to_ref(); + let key: &AutomorphismKeyPrepared<&[u8], _> = &key.to_ref(); + + assert!( + res.dnum().as_u32() <= a.dnum().as_u32(), + "res dnum: {} > a dnum: {}", + res.dnum(), + a.dnum() + ); + + assert_eq!( + res.dsize(), + a.dsize(), + "res dnum: {} != a dnum: {}", + res.dsize(), + a.dsize() + ); + + let cols_out: usize = (key.rank_out() + 1).into(); + + let p: i64 = a.p(); + let p_inv: i64 = self.galois_element_inv(p); + + for row in 0..res.dnum().as_usize() { + for col in 0..cols_out { + let mut res_tmp: GLWE<&mut [u8]> = res.at_mut(row, col); + let a_ct: GLWE<&[u8]> = a.at(row, col); + + // Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) + for i in 0..cols_out { + self.vec_znx_automorphism(a.p(), res_tmp.data_mut(), i, &a_ct.data, i); + } + + // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) + self.glwe_keyswitch_inplace(&mut res_tmp, &key.key, scratch); + + // Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a) + (0..cols_out).for_each(|i| { + self.vec_znx_automorphism_inplace(p_inv, res_tmp.data_mut(), i, scratch); + }); + } + } + } + + res.set_p((a.p() * key.p()) % (self.cyclotomic_order() as i64)); + } + + fn automorphism_key_automorphism_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) + where + R: AutomorphismKeyToMut + SetAutomorphismGaloisElement + GetAutomorphismGaloisElement, + K: AutomorphismKeyPreparedToRef + GetAutomorphismGaloisElement, + Scratch: ScratchTakeCore, + { + { + let res: &mut AutomorphismKey<&mut [u8]> = &mut res.to_mut(); + let key: &AutomorphismKeyPrepared<&[u8], _> = &key.to_ref(); + + assert_eq!( + res.rank(), + key.rank(), + "key rank: {} != key rank: {}", + res.rank(), + key.rank() + ); + + let cols_out: usize = (key.rank_out() + 1).into(); + + let p: i64 = res.p(); + let p_inv: i64 = self.galois_element_inv(p); + + for row in 0..res.dnum().as_usize() { + for col in 0..cols_out { + let mut res_tmp: GLWE<&mut [u8]> = res.at_mut(row, col); + + // Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) + for i in 0..cols_out { + self.vec_znx_automorphism_inplace(p_inv, res_tmp.data_mut(), i, scratch); + } + + // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) + self.glwe_keyswitch_inplace(&mut res_tmp, &key.key, scratch); + + // Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a) + for i in 0..cols_out { + self.vec_znx_automorphism_inplace(p_inv, res_tmp.data_mut(), i, scratch); + } + } + } + } + + res.set_p((res.p() * key.p()) % (self.cyclotomic_order() as i64)); } } diff --git a/poulpy-core/src/automorphism/ggsw_ct.rs b/poulpy-core/src/automorphism/ggsw_ct.rs index a3cef86..8035063 100644 --- a/poulpy-core/src/automorphism/ggsw_ct.rs +++ b/poulpy-core/src/automorphism/ggsw_ct.rs @@ -1,165 +1,127 @@ use poulpy_hal::{ - api::{ - ScratchAvailable, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxDftAddInplace, VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftCopy, VecZnxIdftApplyConsume, - VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, - }, - layouts::{Backend, DataMut, DataRef, Module, Scratch}, + api::ScratchAvailable, + layouts::{Backend, DataMut, Module, Scratch}, }; -use crate::layouts::{ - GGLWEInfos, GGSW, GGSWInfos, GLWE, - prepared::{AutomorphismKeyPrepared, TensorKeyPrepared}, +use crate::{ + GGSWExpandRows, ScratchTakeCore, + automorphism::glwe_ct::GLWEAutomorphism, + layouts::{ + GGLWEInfos, GGSW, GGSWInfos, GGSWToMut, GGSWToRef, GLWEInfos, LWEInfos, + prepared::{AutomorphismKeyPrepared, AutomorphismKeyPreparedToRef, TensorKeyPrepared, TensorKeyPreparedToRef}, + }, }; impl GGSW> { - pub fn automorphism_tmp_bytes( - module: &Module, - out_infos: &OUT, - in_infos: &IN, - key_infos: &KEY, - tsk_infos: &TSK, + pub fn automorphism_tmp_bytes( + module: &M, + res_infos: &R, + a_infos: &A, + key_infos: &K, + tsk_infos: &T, ) -> usize where - OUT: GGSWInfos, - IN: GGSWInfos, - KEY: GGLWEInfos, - TSK: GGLWEInfos, - Module: - VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigBytesOf + VecZnxNormalizeTmpBytes + VecZnxBigNormalizeTmpBytes, + R: GGSWInfos, + A: GGSWInfos, + K: GGLWEInfos, + T: GGLWEInfos, + M: GGSWAutomorphism, { - let out_size: usize = out_infos.size(); - let ci_dft: usize = module.bytes_of_vec_znx_dft((key_infos.rank_out() + 1).into(), out_size); - let ks_internal: usize = GLWE::keyswitch_tmp_bytes( - module, - &out_infos.glwe_layout(), - &in_infos.glwe_layout(), - key_infos, - ); - let expand: usize = GGSW::expand_row_tmp_bytes(module, out_infos, tsk_infos); - ci_dft + (ks_internal | expand) - } - - pub fn automorphism_inplace_tmp_bytes( - module: &Module, - out_infos: &OUT, - key_infos: &KEY, - tsk_infos: &TSK, - ) -> usize - where - OUT: GGSWInfos, - KEY: GGLWEInfos, - TSK: GGLWEInfos, - Module: - VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigBytesOf + VecZnxNormalizeTmpBytes + VecZnxBigNormalizeTmpBytes, - { - GGSW::automorphism_tmp_bytes(module, out_infos, out_infos, key_infos, tsk_infos) + module.ggsw_automorphism_tmp_bytes(res_infos, a_infos, key_infos, tsk_infos) } } -impl GGSW { - pub fn automorphism( - &mut self, - module: &Module, - lhs: &GGSW, - auto_key: &AutomorphismKeyPrepared, - tensor_key: &TensorKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxAutomorphismInplace - + VecZnxBigBytesOf - + VecZnxNormalizeTmpBytes - + VecZnxDftCopy - + VecZnxDftAddInplace - + VecZnxIdftApplyTmpA - + VecZnxNormalize, - Scratch: ScratchAvailable, +impl GGSW { + pub fn automorphism(&mut self, module: &M, a: &A, key: &K, tsk: &T, scratch: &mut Scratch) + where + A: GGSWToRef, + K: AutomorphismKeyPreparedToRef, + T: TensorKeyPreparedToRef, + Scratch: ScratchTakeCore, + M: GGSWAutomorphism, { - #[cfg(debug_assertions)] - { - use crate::layouts::{GLWEInfos, LWEInfos}; - - assert_eq!(self.n(), module.n() as u32); - assert_eq!(lhs.n(), module.n() as u32); - assert_eq!(auto_key.n(), module.n() as u32); - assert_eq!(tensor_key.n(), module.n() as u32); - - assert_eq!( - self.rank(), - lhs.rank(), - "ggsw_out rank: {} != ggsw_in rank: {}", - self.rank(), - lhs.rank() - ); - assert_eq!( - self.rank(), - auto_key.rank_out(), - "ggsw_in rank: {} != auto_key rank: {}", - self.rank(), - auto_key.rank_out() - ); - assert_eq!( - self.rank(), - tensor_key.rank_out(), - "ggsw_in rank: {} != tensor_key rank: {}", - self.rank(), - tensor_key.rank_out() - ); - assert!(scratch.available() >= GGSW::automorphism_tmp_bytes(module, self, lhs, auto_key, tensor_key)) - }; - - // Keyswitch the j-th row of the col 0 - (0..lhs.dnum().into()).for_each(|row_i| { - // Key-switch column 0, i.e. - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) - self.at_mut(row_i, 0) - .automorphism(module, &lhs.at(row_i, 0), auto_key, scratch); - }); - self.expand_row(module, tensor_key, scratch); + module.ggsw_automorphism(self, a, key, tsk, scratch); } - pub fn automorphism_inplace( - &mut self, - module: &Module, - auto_key: &AutomorphismKeyPrepared, - tensor_key: &TensorKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxAutomorphismInplace - + VecZnxBigBytesOf - + VecZnxNormalizeTmpBytes - + VecZnxDftCopy - + VecZnxDftAddInplace - + VecZnxIdftApplyTmpA - + VecZnxNormalize, - Scratch: ScratchAvailable, + pub fn automorphism_inplace(&mut self, module: &M, key: &K, tsk: &T, scratch: &mut Scratch) + where + K: AutomorphismKeyPreparedToRef, + T: TensorKeyPreparedToRef, + Scratch: ScratchTakeCore, + M: GGSWAutomorphism, { - // Keyswitch the j-th row of the col 0 - (0..self.dnum().into()).for_each(|row_i| { - // Key-switch column 0, i.e. - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) - self.at_mut(row_i, 0) - .automorphism_inplace(module, auto_key, scratch); - }); - self.expand_row(module, tensor_key, scratch); + module.ggsw_automorphism_inplace(self, key, tsk, scratch); } } + +impl GGSWAutomorphism for Module where Self: GLWEAutomorphism + GGSWExpandRows {} + +pub trait GGSWAutomorphism +where + Self: GLWEAutomorphism + GGSWExpandRows, +{ + fn ggsw_automorphism_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K, tsk_infos: &T) -> usize + where + R: GGSWInfos, + A: GGSWInfos, + K: GGLWEInfos, + T: GGLWEInfos, + { + let out_size: usize = res_infos.size(); + let ci_dft: usize = self.bytes_of_vec_znx_dft((key_infos.rank_out() + 1).into(), out_size); + let ks_internal: usize = self.glwe_automorphism_tmp_bytes(res_infos, a_infos, key_infos); + let expand: usize = self.ggsw_expand_rows_tmp_bytes(res_infos, tsk_infos); + ci_dft + (ks_internal.max(expand)) + } + + fn ggsw_automorphism(&self, res: &mut R, a: &A, key: &K, tsk: &T, scratch: &mut Scratch) + where + R: GGSWToMut, + A: GGSWToRef, + K: AutomorphismKeyPreparedToRef, + T: TensorKeyPreparedToRef, + Scratch: ScratchTakeCore, + { + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + let a: &GGSW<&[u8]> = &a.to_ref(); + let key: &AutomorphismKeyPrepared<&[u8], BE> = &key.to_ref(); + let tsk: &TensorKeyPrepared<&[u8], BE> = &tsk.to_ref(); + + assert_eq!(res.ggsw_layout(), a.ggsw_layout()); + assert_eq!(res.glwe_layout(), a.glwe_layout()); + assert_eq!(res.lwe_layout(), a.lwe_layout()); + assert!(scratch.available() >= self.ggsw_automorphism_tmp_bytes(res, a, key, tsk)); + + // Keyswitch the j-th row of the col 0 + for row in 0..res.dnum().as_usize() { + // Key-switch column 0, i.e. + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) + self.glwe_automorphism(&mut res.at_mut(row, 0), &a.at(row, 0), key, scratch); + } + + self.ggsw_expand_row(res, tsk, scratch); + } + + fn ggsw_automorphism_inplace(&self, res: &mut R, key: &K, tsk: &T, scratch: &mut Scratch) + where + R: GGSWToMut, + K: AutomorphismKeyPreparedToRef, + T: TensorKeyPreparedToRef, + Scratch: ScratchTakeCore, + { + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + let key: &AutomorphismKeyPrepared<&[u8], BE> = &key.to_ref(); + let tsk: &TensorKeyPrepared<&[u8], BE> = &tsk.to_ref(); + + // Keyswitch the j-th row of the col 0 + for row in 0..res.dnum().as_usize() { + // Key-switch column 0, i.e. + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) + self.glwe_automorphism_inplace(&mut res.at_mut(row, 0), key, scratch); + } + + self.ggsw_expand_row(res, tsk, scratch); + } +} + +impl GGSW {} diff --git a/poulpy-core/src/automorphism/glwe_ct.rs b/poulpy-core/src/automorphism/glwe_ct.rs index 0c8b581..d989891 100644 --- a/poulpy-core/src/automorphism/glwe_ct.rs +++ b/poulpy-core/src/automorphism/glwe_ct.rs @@ -1,345 +1,331 @@ use poulpy_hal::{ api::{ - ScratchAvailable, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallInplace, VecZnxBigSubSmallNegateInplace, VecZnxDftApply, VecZnxDftBytesOf, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, + ScratchTakeBasic, VecZnxAutomorphismInplace, VecZnxBigAutomorphismInplace, VecZnxBigSubSmallInplace, + VecZnxBigSubSmallNegateInplace, }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnxBig}, + layouts::{Backend, DataMut, Module, Scratch, VecZnxBig}, }; -use crate::layouts::{GGLWEInfos, GLWE, GLWEInfos, LWEInfos, prepared::AutomorphismKeyPrepared}; +use crate::{ + GLWEKeyswitch, ScratchTakeCore, keyswitch_internal, + layouts::{ + GGLWEInfos, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, + prepared::{AutomorphismKeyPrepared, AutomorphismKeyPreparedToRef, GetAutomorphismGaloisElement}, + }, +}; impl GLWE> { - pub fn automorphism_tmp_bytes( - module: &Module, - out_infos: &OUT, - in_infos: &IN, - key_infos: &KEY, - ) -> usize + pub fn automorphism_tmp_bytes(module: &M, res_infos: &R, a_infos: &A, key_infos: &K) -> usize where - OUT: GLWEInfos, - IN: GLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + R: GLWEInfos, + A: GLWEInfos, + K: GGLWEInfos, + M: GLWEAutomorphism, { - Self::keyswitch_tmp_bytes(module, out_infos, in_infos, key_infos) - } - - pub fn automorphism_inplace_tmp_bytes(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize - where - OUT: GLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, - { - Self::keyswitch_inplace_tmp_bytes(module, out_infos, key_infos) + module.glwe_automorphism_tmp_bytes(res_infos, a_infos, key_infos) } } impl GLWE { - pub fn automorphism( - &mut self, - module: &Module, - lhs: &GLWE, - rhs: &AutomorphismKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxAutomorphismInplace - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable, + pub fn automorphism(&mut self, module: &M, a: &A, key: &K, scratch: &mut Scratch) + where + M: GLWEAutomorphism, + A: GLWEToRef, + K: AutomorphismKeyPreparedToRef + GetAutomorphismGaloisElement, + Scratch: ScratchTakeCore, { - self.keyswitch(module, lhs, &rhs.key, scratch); - (0..(self.rank() + 1).into()).for_each(|i| { - module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i, scratch); - }) + module.glwe_automorphism(self, a, key, scratch); } - pub fn automorphism_inplace( - &mut self, - module: &Module, - rhs: &AutomorphismKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxAutomorphismInplace - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable, + pub fn automorphism_add(&mut self, module: &M, a: &A, key: &K, scratch: &mut Scratch) + where + M: GLWEAutomorphism, + A: GLWEToRef, + K: AutomorphismKeyPreparedToRef + GetAutomorphismGaloisElement, + Scratch: ScratchTakeCore, { - self.keyswitch_inplace(module, &rhs.key, scratch); - (0..(self.rank() + 1).into()).for_each(|i| { - module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i, scratch); - }) + module.glwe_automorphism_add(self, a, key, scratch); } - pub fn automorphism_add( - &mut self, - module: &Module, - lhs: &GLWE, - rhs: &AutomorphismKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxBigAutomorphismInplace - + VecZnxNormalizeTmpBytes - + VecZnxNormalize, - Scratch: ScratchAvailable, + pub fn automorphism_sub(&mut self, module: &M, a: &A, key: &K, scratch: &mut Scratch) + where + M: GLWEAutomorphism, + A: GLWEToRef, + K: AutomorphismKeyPreparedToRef + GetAutomorphismGaloisElement, + Scratch: ScratchTakeCore, { - #[cfg(debug_assertions)] - { - self.assert_keyswitch(module, lhs, &rhs.key, scratch); - } - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch_1); - (0..(self.rank() + 1).into()).for_each(|i| { - module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1); - module.vec_znx_big_add_small_inplace(&mut res_big, i, &lhs.data, i); - module.vec_znx_big_normalize( - self.base2k().into(), - &mut self.data, - i, - rhs.base2k().into(), - &res_big, - i, - scratch_1, - ); - }) + module.glwe_automorphism_sub(self, a, key, scratch); } - pub fn automorphism_add_inplace( - &mut self, - module: &Module, - rhs: &AutomorphismKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxBigAutomorphismInplace - + VecZnxNormalizeTmpBytes - + VecZnxNormalize, - Scratch: ScratchAvailable, + pub fn glwe_automorphism_sub_negate(&mut self, module: &M, a: &A, key: &K, scratch: &mut Scratch) + where + M: GLWEAutomorphism, + A: GLWEToRef, + K: AutomorphismKeyPreparedToRef + GetAutomorphismGaloisElement, + Scratch: ScratchTakeCore, { - #[cfg(debug_assertions)] - { - self.assert_keyswitch_inplace(module, &rhs.key, scratch); - } - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, &rhs.key, scratch_1); - (0..(self.rank() + 1).into()).for_each(|i| { - module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1); - module.vec_znx_big_add_small_inplace(&mut res_big, i, &self.data, i); - module.vec_znx_big_normalize( - self.base2k().into(), - &mut self.data, - i, - rhs.base2k().into(), - &res_big, - i, - scratch_1, - ); - }) + module.glwe_automorphism_sub_negate(self, a, key, scratch); } - pub fn automorphism_sub_ab( - &mut self, - module: &Module, - lhs: &GLWE, - rhs: &AutomorphismKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxBigAutomorphismInplace - + VecZnxBigSubSmallInplace - + VecZnxNormalizeTmpBytes - + VecZnxNormalize, - Scratch: ScratchAvailable, + pub fn automorphism_inplace(&mut self, module: &M, key: &K, scratch: &mut Scratch) + where + M: GLWEAutomorphism, + K: AutomorphismKeyPreparedToRef + GetAutomorphismGaloisElement, + Scratch: ScratchTakeCore, { - #[cfg(debug_assertions)] - { - self.assert_keyswitch(module, lhs, &rhs.key, scratch); - } - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch_1); - (0..(self.rank() + 1).into()).for_each(|i| { - module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1); - module.vec_znx_big_sub_small_inplace(&mut res_big, i, &lhs.data, i); - module.vec_znx_big_normalize( - self.base2k().into(), - &mut self.data, - i, - rhs.base2k().into(), - &res_big, - i, - scratch_1, - ); - }) + module.glwe_automorphism_inplace(self, key, scratch); } - pub fn automorphism_sub_inplace( - &mut self, - module: &Module, - rhs: &AutomorphismKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxBigAutomorphismInplace - + VecZnxBigSubSmallInplace - + VecZnxNormalizeTmpBytes - + VecZnxNormalize, - Scratch: ScratchAvailable, + pub fn automorphism_add_inplace(&mut self, module: &M, key: &K, scratch: &mut Scratch) + where + M: GLWEAutomorphism, + K: AutomorphismKeyPreparedToRef + GetAutomorphismGaloisElement, + Scratch: ScratchTakeCore, { - #[cfg(debug_assertions)] - { - self.assert_keyswitch_inplace(module, &rhs.key, scratch); - } - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, &rhs.key, scratch_1); - (0..(self.rank() + 1).into()).for_each(|i| { - module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1); - module.vec_znx_big_sub_small_inplace(&mut res_big, i, &self.data, i); - module.vec_znx_big_normalize( - self.base2k().into(), - &mut self.data, - i, - rhs.base2k().into(), - &res_big, - i, - scratch_1, - ); - }) + module.glwe_automorphism_add_inplace(self, key, scratch); } - pub fn automorphism_sub_negate( - &mut self, - module: &Module, - lhs: &GLWE, - rhs: &AutomorphismKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxBigAutomorphismInplace - + VecZnxBigSubSmallNegateInplace - + VecZnxNormalizeTmpBytes - + VecZnxNormalize, - Scratch: ScratchAvailable, + pub fn automorphism_sub_inplace(&mut self, module: &M, key: &K, scratch: &mut Scratch) + where + M: GLWEAutomorphism, + K: AutomorphismKeyPreparedToRef + GetAutomorphismGaloisElement, + Scratch: ScratchTakeCore, { - #[cfg(debug_assertions)] - { - self.assert_keyswitch(module, lhs, &rhs.key, scratch); - } - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch_1); - (0..(self.rank() + 1).into()).for_each(|i| { - module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1); - module.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, &lhs.data, i); - module.vec_znx_big_normalize( - self.base2k().into(), - &mut self.data, - i, - rhs.base2k().into(), - &res_big, - i, - scratch_1, - ); - }) + module.glwe_automorphism_sub_inplace(self, key, scratch); } - pub fn automorphism_sub_negate_inplace( - &mut self, - module: &Module, - rhs: &AutomorphismKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxBigAutomorphismInplace - + VecZnxBigSubSmallNegateInplace - + VecZnxNormalizeTmpBytes - + VecZnxNormalize, - Scratch: ScratchAvailable, + pub fn automorphism_sub_negate_inplace(&mut self, module: &M, key: &K, scratch: &mut Scratch) + where + M: GLWEAutomorphism, + K: AutomorphismKeyPreparedToRef + GetAutomorphismGaloisElement, + Scratch: ScratchTakeCore, { - #[cfg(debug_assertions)] - { - self.assert_keyswitch_inplace(module, &rhs.key, scratch); - } - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, &rhs.key, scratch_1); - (0..(self.rank() + 1).into()).for_each(|i| { - module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1); - module.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, &self.data, i); - module.vec_znx_big_normalize( - self.base2k().into(), - &mut self.data, - i, - rhs.base2k().into(), - &res_big, - i, - scratch_1, - ); - }) + module.glwe_automorphism_sub_negate_inplace(self, key, scratch); } } + +pub trait GLWEAutomorphism +where + Self: GLWEKeyswitch + + VecZnxAutomorphismInplace + + VecZnxBigAutomorphismInplace + + VecZnxBigSubSmallInplace + + VecZnxBigSubSmallNegateInplace, +{ + fn glwe_automorphism_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize + where + R: GLWEInfos, + A: GLWEInfos, + K: GGLWEInfos, + { + self.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) + } + + fn glwe_automorphism(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + A: GLWEToRef, + K: AutomorphismKeyPreparedToRef + GetAutomorphismGaloisElement, + Scratch: ScratchTakeCore, + { + self.glwe_keyswitch(res, a, &key.to_ref().key, scratch); + + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_automorphism_inplace(key.p(), res.data_mut(), i, scratch); + } + } + + fn glwe_automorphism_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + K: AutomorphismKeyPreparedToRef + GetAutomorphismGaloisElement, + Scratch: ScratchTakeCore, + { + self.glwe_keyswitch_inplace(res, &key.to_ref().key, scratch); + + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_automorphism_inplace(key.p(), res.data_mut(), i, scratch); + } + } + + fn glwe_automorphism_add(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + A: GLWEToRef, + K: AutomorphismKeyPreparedToRef + GetAutomorphismGaloisElement, + Scratch: ScratchTakeCore, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); + let key: &AutomorphismKeyPrepared<&[u8], BE> = &key.to_ref(); + + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size + let mut res_big: VecZnxBig<_, BE> = keyswitch_internal(self, res_dft, a, &key.key, scratch_1); + + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); + self.vec_znx_big_add_small_inplace(&mut res_big, i, a.data(), i); + self.vec_znx_big_normalize( + res.base2k().into(), + res.data_mut(), + i, + key.base2k().into(), + &res_big, + i, + scratch_1, + ); + } + } + + fn glwe_automorphism_add_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + K: AutomorphismKeyPreparedToRef + GetAutomorphismGaloisElement, + Scratch: ScratchTakeCore, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let key: &AutomorphismKeyPrepared<&[u8], BE> = &key.to_ref(); + + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size + let mut res_big: VecZnxBig<_, BE> = keyswitch_internal(self, res_dft, res, &key.key, scratch_1); + + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); + self.vec_znx_big_add_small_inplace(&mut res_big, i, res.data(), i); + self.vec_znx_big_normalize( + res.base2k().into(), + res.data_mut(), + i, + key.base2k().into(), + &res_big, + i, + scratch_1, + ); + } + } + + fn glwe_automorphism_sub(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + A: GLWEToRef, + K: AutomorphismKeyPreparedToRef + GetAutomorphismGaloisElement, + Scratch: ScratchTakeCore, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); + let key: &AutomorphismKeyPrepared<&[u8], BE> = &key.to_ref(); + + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size + let mut res_big: VecZnxBig<_, BE> = keyswitch_internal(self, res_dft, a, &key.key, scratch_1); + + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); + self.vec_znx_big_sub_small_inplace(&mut res_big, i, a.data(), i); + self.vec_znx_big_normalize( + res.base2k().into(), + res.data_mut(), + i, + key.base2k().into(), + &res_big, + i, + scratch_1, + ); + } + } + + fn glwe_automorphism_sub_negate(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + A: GLWEToRef, + K: AutomorphismKeyPreparedToRef + GetAutomorphismGaloisElement, + Scratch: ScratchTakeCore, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); + let key: &AutomorphismKeyPrepared<&[u8], BE> = &key.to_ref(); + + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size + let mut res_big: VecZnxBig<_, BE> = keyswitch_internal(self, res_dft, a, &key.key, scratch_1); + + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); + self.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, a.data(), i); + self.vec_znx_big_normalize( + res.base2k().into(), + res.data_mut(), + i, + key.base2k().into(), + &res_big, + i, + scratch_1, + ); + } + } + + fn glwe_automorphism_sub_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + K: AutomorphismKeyPreparedToRef + GetAutomorphismGaloisElement, + Scratch: ScratchTakeCore, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let key: &AutomorphismKeyPrepared<&[u8], BE> = &key.to_ref(); + + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size + let mut res_big: VecZnxBig<_, BE> = keyswitch_internal(self, res_dft, res, &key.key, scratch_1); + + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); + self.vec_znx_big_sub_small_inplace(&mut res_big, i, res.data(), i); + self.vec_znx_big_normalize( + res.base2k().into(), + res.data_mut(), + i, + key.base2k().into(), + &res_big, + i, + scratch_1, + ); + } + } + + fn glwe_automorphism_sub_negate_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + K: AutomorphismKeyPreparedToRef + GetAutomorphismGaloisElement, + Scratch: ScratchTakeCore, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let key: &AutomorphismKeyPrepared<&[u8], BE> = &key.to_ref(); + + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size + let mut res_big: VecZnxBig<_, BE> = keyswitch_internal(self, res_dft, res, &key.key, scratch_1); + + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); + self.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, res.data(), i); + self.vec_znx_big_normalize( + res.base2k().into(), + res.data_mut(), + i, + key.base2k().into(), + &res_big, + i, + scratch_1, + ); + } + } +} + +impl GLWEAutomorphism for Module where + Self: GLWEKeyswitch + + VecZnxAutomorphismInplace + + VecZnxBigAutomorphismInplace + + VecZnxBigSubSmallInplace + + VecZnxBigSubSmallNegateInplace +{ +} diff --git a/poulpy-core/src/automorphism/mod.rs b/poulpy-core/src/automorphism/mod.rs index f985c5e..fd10f33 100644 --- a/poulpy-core/src/automorphism/mod.rs +++ b/poulpy-core/src/automorphism/mod.rs @@ -1,3 +1,7 @@ mod gglwe_atk; mod ggsw_ct; mod glwe_ct; + +pub use gglwe_atk::*; +pub use ggsw_ct::*; +pub use glwe_ct::*; diff --git a/poulpy-core/src/conversion/gglwe_to_ggsw.rs b/poulpy-core/src/conversion/gglwe_to_ggsw.rs index a7b86fa..24d02bd 100644 --- a/poulpy-core/src/conversion/gglwe_to_ggsw.rs +++ b/poulpy-core/src/conversion/gglwe_to_ggsw.rs @@ -1,19 +1,18 @@ use poulpy_hal::{ api::{ - ModuleN, ScratchAvailable, ScratchTakeBasic, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxCopy, VecZnxDftAddInplace, - VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftCopy, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeTmpBytes, - VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + ModuleN, ScratchAvailable, ScratchTakeBasic, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxDftAddInplace, VecZnxDftApply, + VecZnxDftBytesOf, VecZnxDftCopy, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, + VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, Module, Scratch, VmpPMat, ZnxInfos}, }; use crate::{ - ScratchTakeCore, + GLWECopy, ScratchTakeCore, layouts::{ GGLWE, GGLWEInfos, GGLWEToRef, GGSW, GGSWInfos, GGSWToMut, GLWEInfos, LWEInfos, prepared::{TensorKeyPrepared, TensorKeyPreparedToRef}, }, - operations::GLWEOperations, }; impl GGLWE> { @@ -39,11 +38,11 @@ impl GGSW { } } -impl GGSWFromGGLWE for Module where Self: GGSWExpandRows + VecZnxCopy {} +impl GGSWFromGGLWE for Module where Self: GGSWExpandRows + GLWECopy {} pub trait GGSWFromGGLWE where - Self: GGSWExpandRows + VecZnxCopy, + Self: GGSWExpandRows + GLWECopy, { fn ggsw_from_gglwe_tmp_bytes(&self, res_infos: &R, tsk_infos: &A) -> usize where @@ -71,7 +70,7 @@ where assert_eq!(tsk.n(), self.n() as u32); for row in 0..res.dnum().into() { - res.at_mut(row, 0).copy(self, &a.at(row, 0)); + self.glwe_copy(&mut res.at_mut(row, 0), &a.at(row, 0)); } self.ggsw_expand_row(res, tsk, scratch); diff --git a/poulpy-core/src/encryption/compressed/gglwe_tsk.rs b/poulpy-core/src/encryption/compressed/gglwe_tsk.rs index f3086df..02393f1 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_tsk.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_tsk.rs @@ -141,3 +141,18 @@ impl TensorKeyCompressed { module.gglwe_tensor_key_encrypt_sk(self, sk, seed_xa, source_xe, scratch); } } + +impl TensorKeyCompressed { + pub fn encrypt_sk( + &mut self, + module: &Module, + sk: &GLWESecret, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + Module: GGLWETensorKeyCompressedEncryptSk, + { + module.gglwe_tensor_key_encrypt_sk(self, sk, seed_xa, source_xe, scratch); + } +} diff --git a/poulpy-core/src/encryption/gglwe_atk.rs b/poulpy-core/src/encryption/gglwe_atk.rs index 972e06c..2e87707 100644 --- a/poulpy-core/src/encryption/gglwe_atk.rs +++ b/poulpy-core/src/encryption/gglwe_atk.rs @@ -6,6 +6,7 @@ use poulpy_hal::{ VecZnxSwitchRing, }, layouts::{Backend, DataMut, Module, Scratch}, + layouts::{Backend, DataMut, Module, Scratch}, source::Source, }; @@ -16,6 +17,8 @@ use crate::{ }, }; +impl AutomorphismKey> { + pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize impl AutomorphismKey> { pub fn encrypt_sk_tmp_bytes(module: &Module, infos: &A) -> usize where @@ -28,8 +31,10 @@ impl AutomorphismKey> { "rank_in != rank_out is not supported for GGLWEAutomorphismKey" ); GLWESwitchingKey::encrypt_sk_tmp_bytes(module, infos) + GLWESecret::bytes_of_from_infos(module, &infos.glwe_layout()) + GLWESwitchingKey::encrypt_sk_tmp_bytes(module, infos) + GLWESecret::bytes_of_from_infos(module, &infos.glwe_layout()) } + pub fn encrypt_pk_tmp_bytes(module: &Module, _infos: &A) -> usize pub fn encrypt_pk_tmp_bytes(module: &Module, _infos: &A) -> usize where A: GGLWEInfos, @@ -40,6 +45,7 @@ impl AutomorphismKey> { "rank_in != rank_out is not supported for GGLWEAutomorphismKey" ); GLWESwitchingKey::encrypt_pk_tmp_bytes(module, _infos) + GLWESwitchingKey::encrypt_pk_tmp_bytes(module, _infos) } } @@ -57,6 +63,25 @@ pub trait GGLWEAutomorphismKeyEncryptSk { B: GLWESecretToRef; } +impl AutomorphismKey +where + Self: AutomorphismKeyToMut, +{ + pub fn encrypt_sk( +pub trait GGLWEAutomorphismKeyEncryptSk { + fn gglwe_automorphism_key_encrypt_sk( + &self, + res: &mut A, + p: i64, + sk: &B, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + A: AutomorphismKeyToMut, + B: GLWESecretToRef; +} + impl AutomorphismKey where Self: AutomorphismKeyToMut, @@ -64,11 +89,14 @@ where pub fn encrypt_sk( &mut self, module: &Module, + module: &Module, p: i64, sk: &S, + sk: &S, source_xa: &mut Source, source_xe: &mut Source, scratch: &mut Scratch, + scratch: &mut Scratch, ) where S: GLWESecretToRef, Module: GGLWEAutomorphismKeyEncryptSk, @@ -121,20 +149,29 @@ where { use crate::layouts::{GLWEInfos, LWEInfos}; + assert_eq!(res.n(), sk.n()); + assert_eq!(res.rank_out(), res.rank_in()); + assert_eq!(sk.rank(), res.rank_out()); assert_eq!(res.n(), sk.n()); assert_eq!(res.rank_out(), res.rank_in()); assert_eq!(sk.rank(), res.rank_out()); assert!( + scratch.available() >= AutomorphismKey::encrypt_sk_tmp_bytes(self, res), + "scratch.available(): {} < AutomorphismKey::encrypt_sk_tmp_bytes: {:?}", scratch.available() >= AutomorphismKey::encrypt_sk_tmp_bytes(self, res), "scratch.available(): {} < AutomorphismKey::encrypt_sk_tmp_bytes: {:?}", scratch.available(), AutomorphismKey::encrypt_sk_tmp_bytes(self, res) + AutomorphismKey::encrypt_sk_tmp_bytes(self, res) ) } let (mut sk_out, scratch_1) = scratch.take_glwe_secret(self, sk.rank()); { + (0..res.rank_out().into()).for_each(|i| { + self.vec_znx_automorphism( + self.galois_element_inv(p), (0..res.rank_out().into()).for_each(|i| { self.vec_znx_automorphism( self.galois_element_inv(p), @@ -146,9 +183,12 @@ where }); } + res.key + .encrypt_sk(self, sk, &sk_out, source_xa, source_xe, scratch_1); res.key .encrypt_sk(self, sk, &sk_out, source_xa, source_xe, scratch_1); res.p = p; + res.p = p; } } diff --git a/poulpy-core/src/glwe_packing.rs b/poulpy-core/src/glwe_packing.rs index 7dacb97..5304df8 100644 --- a/poulpy-core/src/glwe_packing.rs +++ b/poulpy-core/src/glwe_packing.rs @@ -1,19 +1,16 @@ use std::collections::HashMap; use poulpy_hal::{ - api::{ - ScratchAvailable, VecZnxAddInplace, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, - VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy, VecZnxDftApply, - VecZnxDftBytesOf, VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNegateInplace, VecZnxNormalize, - VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, - VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, - }, - layouts::{Backend, DataMut, DataRef, Module, Scratch}, + api::{ModuleLogN, VecZnxCopy, VecZnxRotateInplace}, + layouts::{Backend, DataMut, DataRef, GaloisElement, Module, Scratch}, }; use crate::{ - GLWEOperations, - layouts::{GGLWEInfos, GLWE, GLWEInfos, LWEInfos, prepared::AutomorphismKeyPrepared}, + GLWEAdd, GLWEAutomorphism, GLWENormalize, GLWERotate, GLWEShift, GLWESub, ScratchTakeCore, + layouts::{ + GGLWEInfos, GLWE, GLWEAlloc, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, + prepared::{AutomorphismKeyPreparedToRef, GetAutomorphismGaloisElement}, + }, }; /// [GLWEPacker] enables only the fly GLWE packing @@ -43,9 +40,10 @@ impl Accumulator { /// * `base2k`: base 2 logarithm of the GLWE ciphertext in memory digit representation. /// * `k`: base 2 precision of the GLWE ciphertext precision over the Torus. /// * `rank`: rank of the GLWE ciphertext. - pub fn alloc(module: &Module, infos: &A) -> Self + pub fn alloc(module: &M, infos: &A) -> Self where A: GLWEInfos, + M: GLWEAlloc, { Self { data: GLWE::alloc_from_infos(module, infos), @@ -66,9 +64,10 @@ impl GLWEPacker { /// and N GLWE ciphertext can be packed. With `log_batch=2` all coefficients /// which are multiples of X^{N/4} are packed. Meaning that N/4 ciphertexts /// can be packed. - pub fn new(module: Module, infos: &A, log_batch: usize) -> Self + pub fn new(module: &M, infos: &A, log_batch: usize) -> Self where A: GLWEInfos, + M: GLWEAlloc, { let mut accumulators: Vec = Vec::::new(); let log_n: usize = infos.n().log2(); @@ -90,13 +89,13 @@ impl GLWEPacker { } /// Number of scratch space bytes required to call [Self::add]. - pub fn tmp_bytes(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize + pub fn tmp_bytes(module: &M, res_infos: &R, key_infos: &K) -> usize where - OUT: GLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + R: GLWEInfos, + K: GGLWEInfos, + M: GLWEAlloc + GLWEAutomorphism, { - pack_core_tmp_bytes(module, out_infos, key_infos) + pack_core_tmp_bytes(module, res_infos, key_infos) } pub fn galois_elements(module: &Module) -> Vec { @@ -112,37 +111,12 @@ impl GLWEPacker { /// * `a`: ciphertext to pack. Can optionally give None to pack a 0 ciphertext. /// * `auto_keys`: a [HashMap] containing the [AutomorphismKeyExec]s. /// * `scratch`: scratch space of size at least [Self::tmp_bytes]. - pub fn add( - &mut self, - module: &Module, - a: Option<&GLWE>, - auto_keys: &HashMap>, - scratch: &mut Scratch, - ) where - Module: VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxCopy - + VecZnxRotateInplace - + VecZnxSub - + VecZnxNegateInplace - + VecZnxRshInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxSubInplace - + VecZnxRotate - + VecZnxAutomorphismInplace - + VecZnxBigSubSmallNegateInplace - + VecZnxBigAutomorphismInplace - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable, + pub fn add(&mut self, module: &M, a: Option<&A>, auto_keys: &HashMap, scratch: &mut Scratch) + where + A: GLWEToRef, + K: AutomorphismKeyPreparedToRef, + M: GLWEAutomorphism, + Scratch: ScratchTakeCore, { assert!( (self.counter as u32) < self.accumulators[0].data.n(), @@ -177,47 +151,27 @@ impl GLWEPacker { } } -fn pack_core_tmp_bytes(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize +fn pack_core_tmp_bytes(module: &M, res_infos: &R, key_infos: &K) -> usize where - OUT: GLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + R: GLWEInfos, + K: GGLWEInfos, + M: GLWEAlloc + GLWEAutomorphism, { - combine_tmp_bytes(module, out_infos, key_infos) + combine_tmp_bytes(module, res_infos, key_infos) } -fn pack_core( - module: &Module, - a: Option<&GLWE>, +fn pack_core( + module: &M, + a: Option<&A>, accumulators: &mut [Accumulator], i: usize, - auto_keys: &HashMap>, - scratch: &mut Scratch, + auto_keys: &HashMap, + scratch: &mut Scratch, ) where - Module: VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxCopy - + VecZnxRotateInplace - + VecZnxSub - + VecZnxNegateInplace - + VecZnxRshInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxSubInplace - + VecZnxRotate - + VecZnxAutomorphismInplace - + VecZnxBigSubSmallNegateInplace - + VecZnxBigAutomorphismInplace - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable, + A: GLWEToRef + GLWEInfos, + K: AutomorphismKeyPreparedToRef, + M: GLWEAutomorphism + ModuleLogN + VecZnxCopy, + Scratch: ScratchTakeCore, { let log_n: usize = module.log_n(); @@ -268,49 +222,29 @@ fn pack_core( } } -fn combine_tmp_bytes(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize +fn combine_tmp_bytes(module: &M, res_infos: &R, key_infos: &K) -> usize where - OUT: GLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + R: GLWEInfos, + K: GGLWEInfos, + M: GLWEAlloc + GLWEAutomorphism, { - GLWE::bytes_of_from_infos(module, out_infos) - + (GLWE::rsh_tmp_bytes(module.n()) | GLWE::automorphism_inplace_tmp_bytes(module, out_infos, key_infos)) + GLWE::bytes_of_from_infos(module, res_infos) + + (GLWE::rsh_tmp_bytes(module.n()) | module.glwe_automorphism_tmp_bytes(res_infos, res_infos, key_infos)) } /// [combine] merges two ciphertexts together. -fn combine( - module: &Module, +fn combine( + module: &M, acc: &mut Accumulator, - b: Option<&GLWE>, + b: Option<&B>, i: usize, - auto_keys: &HashMap>, - scratch: &mut Scratch, + auto_keys: &HashMap, + scratch: &mut Scratch, ) where - Module: VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxCopy - + VecZnxRotateInplace - + VecZnxSub - + VecZnxNegateInplace - + VecZnxRshInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxSubInplace - + VecZnxRotate - + VecZnxAutomorphismInplace - + VecZnxBigSubSmallNegateInplace - + VecZnxBigAutomorphismInplace - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable, + B: GLWEToRef + GLWEInfos, + K: AutomorphismKeyPreparedToRef, + M: GLWEAutomorphism + GaloisElement + VecZnxRotateInplace, + Scratch: ScratchTakeCore, { let log_n: usize = acc.data.n().log2(); let a: &mut GLWE> = &mut acc.data; @@ -335,7 +269,7 @@ fn combine( // since 2*(I(X) * Q/2) = I(X) * Q = 0 mod Q. if acc.value { if let Some(b) = b { - let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(a); + let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(module, a); // a = a * X^-t a.rotate_inplace(module, -t, scratch_1); @@ -390,110 +324,76 @@ fn combine( } } -/// Packs [x_0: GLWE(m_0), x_1: GLWE(m_1), ..., x_i: GLWE(m_i)] -/// to [0: GLWE(m_0 * X^x_0 + m_1 * X^x_1 + ... + m_i * X^x_i)] -pub fn glwe_packing( - module: &Module, - cts: &mut HashMap>, - log_gap_out: usize, - auto_keys: &HashMap>, - scratch: &mut Scratch, -) where - ATK: DataRef, - Module: VecZnxRotateInplace - + VecZnxNormalizeInplace - + VecZnxNormalizeTmpBytes - + VecZnxSwitchRing - + VecZnxBigAutomorphismInplace - + VecZnxRshInplace - + VecZnxDftCopy - + VecZnxIdftApplyTmpA - + VecZnxSub - + VecZnxAddInplace - + VecZnxNegateInplace - + VecZnxCopy - + VecZnxSubInplace - + VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxAutomorphismInplace - + VecZnxBigSubSmallNegateInplace - + VecZnxRotate - + VecZnxNormalize, - Scratch: ScratchAvailable, +pub trait GLWEPacking +where + Self: GLWEAutomorphism + + GaloisElement + + ModuleLogN + + GLWERotate + + GLWESub + + GLWEShift + + GLWEAdd + + GLWENormalize, { - #[cfg(debug_assertions)] + /// Packs [x_0: GLWE(m_0), x_1: GLWE(m_1), ..., x_i: GLWE(m_i)] + /// to [0: GLWE(m_0 * X^x_0 + m_1 * X^x_1 + ... + m_i * X^x_i)] + fn glwe_pack( + &self, + cts: &mut HashMap, + log_gap_out: usize, + keys: &HashMap, + scratch: &mut Scratch, + ) where + R: GLWEToMut + GLWEToRef + GLWEInfos, + K: AutomorphismKeyPreparedToRef + GetAutomorphismGaloisElement, + Scratch: ScratchTakeCore, { - assert!(*cts.keys().max().unwrap() < module.n()) - } + #[cfg(debug_assertions)] + { + assert!(*cts.keys().max().unwrap() < self.n()) + } - let log_n: usize = module.log_n(); + let log_n: usize = self.log_n(); - (0..log_n - log_gap_out).for_each(|i| { - let t: usize = (1 << log_n).min(1 << (log_n - 1 - i)); + for i in 0..(log_n - log_gap_out){ + let t: usize = (1 << log_n).min(1 << (log_n - 1 - i)); - let auto_key: &AutomorphismKeyPrepared = if i == 0 { - auto_keys.get(&-1).unwrap() - } else { - auto_keys.get(&module.galois_element(1 << (i - 1))).unwrap() + let key: &K = if i == 0 { + keys.get(&-1).unwrap() + } else { + keys.get(&self.galois_element(1 << (i - 1))).unwrap() + }; + + for j in 0..t{ + let mut a: Option<&mut R> = cts.remove(&j); + let mut b: Option<&mut R> = cts.remove(&(j + t)); + + pack_internal(self, &mut a, &mut b, i, key, scratch); + + if let Some(a) = a { + cts.insert(j, a); + } else if let Some(b) = b { + cts.insert(j, b); + } + }; }; - - (0..t).for_each(|j| { - let mut a: Option<&mut GLWE> = cts.remove(&j); - let mut b: Option<&mut GLWE> = cts.remove(&(j + t)); - - pack_internal(module, &mut a, &mut b, i, auto_key, scratch); - - if let Some(a) = a { - cts.insert(j, a); - } else if let Some(b) = b { - cts.insert(j, b); - } - }); - }); + } } #[allow(clippy::too_many_arguments)] -fn pack_internal( - module: &Module, - a: &mut Option<&mut GLWE>, - b: &mut Option<&mut GLWE>, +fn pack_internal( + module: &M, + a: &mut Option<&mut A>, + b: &mut Option<&mut B>, i: usize, - auto_key: &AutomorphismKeyPrepared, - scratch: &mut Scratch, + auto_key: &K, + scratch: &mut Scratch, ) where - Module: VecZnxRotateInplace - + VecZnxNormalizeInplace - + VecZnxNormalizeTmpBytes - + VecZnxBigAutomorphismInplace - + VecZnxRshInplace - + VecZnxDftCopy - + VecZnxIdftApplyTmpA - + VecZnxSub - + VecZnxAddInplace - + VecZnxNegateInplace - + VecZnxCopy - + VecZnxSubInplace - + VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxAutomorphismInplace - + VecZnxBigSubSmallNegateInplace - + VecZnxRotate - + VecZnxNormalize, - Scratch: ScratchAvailable, + M: GLWEAutomorphism + GLWERotate + GLWESub + GLWEShift + GLWEAdd + GLWENormalize, + A: GLWEToMut + GLWEToRef + GLWEInfos, + B: GLWEToMut + GLWEToRef + GLWEInfos, + K: AutomorphismKeyPreparedToRef + GetAutomorphismGaloisElement, + Scratch: ScratchTakeCore, { // Goal is to evaluate: a = a + b*X^t + phi(a - b*X^t)) // We also use the identity: AUTO(a * X^t, g) = -X^t * AUTO(a, g) @@ -509,45 +409,45 @@ fn pack_internal( let t: i64 = 1 << (a.n().log2() - i - 1); if let Some(b) = b.as_deref_mut() { - let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(a); + let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(module, a); // a = a * X^-t - a.rotate_inplace(module, -t, scratch_1); + module.glwe_rotate_inplace(-t, a, scratch_1); // tmp_b = a * X^-t - b - tmp_b.sub(module, a, b); - tmp_b.rsh(module, 1, scratch_1); + module.glwe_sub(&mut tmp_b, a, b); + module.glwe_rsh(1, &mut tmp_b, scratch_1); // a = a * X^-t + b - a.add_inplace(module, b); - a.rsh(module, 1, scratch_1); + module.glwe_add_inplace(a, b); + module.glwe_rsh(1, a, scratch_1); - tmp_b.normalize_inplace(module, scratch_1); + module.glwe_normalize_inplace(&mut tmp_b, scratch_1); // tmp_b = phi(a * X^-t - b) - tmp_b.automorphism_inplace(module, auto_key, scratch_1); + module.glwe_automorphism_inplace(&mut tmp_b, auto_key, scratch_1); // a = a * X^-t + b - phi(a * X^-t - b) - a.sub_inplace_ab(module, &tmp_b); - a.normalize_inplace(module, scratch_1); + module.glwe_sub_inplace(a, &tmp_b); + module.glwe_normalize_inplace(a, scratch_1); // a = a + b * X^t - phi(a * X^-t - b) * X^t // = a + b * X^t - phi(a * X^-t - b) * - phi(X^t) // = a + b * X^t + phi(a - b * X^t) - a.rotate_inplace(module, t, scratch_1); + module.glwe_rotate_inplace(t, a, scratch_1); } else { - a.rsh(module, 1, scratch); + module.glwe_rsh(1, a, scratch); // a = a + phi(a) - a.automorphism_add_inplace(module, auto_key, scratch); + module.glwe_automorphism_add_inplace(a, auto_key, scratch); } } else if let Some(b) = b.as_deref_mut() { let t: i64 = 1 << (b.n().log2() - i - 1); - let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(b); - tmp_b.rotate(module, t, b); - tmp_b.rsh(module, 1, scratch_1); + let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(module, b); + module.glwe_rotate(t, &mut tmp_b, b); + module.glwe_rsh(1, &mut tmp_b, scratch_1); // a = (b* X^t - phi(b* X^t)) - b.automorphism_sub_negate(module, &tmp_b, auto_key, scratch_1); + module.glwe_automorphism_sub_negate(b, &tmp_b, auto_key, scratch_1); } } diff --git a/poulpy-core/src/glwe_trace.rs b/poulpy-core/src/glwe_trace.rs index 36aabb9..48608d1 100644 --- a/poulpy-core/src/glwe_trace.rs +++ b/poulpy-core/src/glwe_trace.rs @@ -1,173 +1,188 @@ use std::collections::HashMap; use poulpy_hal::{ - api::{ - ScratchAvailable, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxCopy, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, - VecZnxRshInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, - }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx}, + api::ModuleLogN, + layouts::{Backend, DataMut, GaloisElement, Module, Scratch, VecZnx}, }; use crate::{ - layouts::{Base2K, GGLWEInfos, GLWE, GLWEInfos, GLWELayout, LWEInfos, prepared::AutomorphismKeyPrepared}, - operations::GLWEOperations, + GLWEAutomorphism, GLWECopy, GLWEShift, ScratchTakeCore, + layouts::{ + Base2K, GGLWEInfos, GLWE, GLWEInfos, GLWELayout, GLWEToMut, GLWEToRef, LWEInfos, + prepared::{AutomorphismKeyPreparedToRef, GetAutomorphismGaloisElement}, + }, }; impl GLWE> { - pub fn trace_galois_elements(module: &Module) -> Vec { - let mut gal_els: Vec = Vec::new(); - (0..module.log_n()).for_each(|i| { - if i == 0 { - gal_els.push(-1); - } else { - gal_els.push(module.galois_element(1 << (i - 1))); - } - }); - gal_els + pub fn trace_galois_elements(module: &M) -> Vec + where + M: GLWETrace, + { + module.glwe_trace_galois_elements() } - pub fn trace_tmp_bytes(module: &Module, out_infos: &OUT, in_infos: &IN, key_infos: &KEY) -> usize + pub fn trace_tmp_bytes(module: &M, res_infos: &R, a_infos: &A, key_infos: &K) -> usize where - OUT: GLWEInfos, - IN: GLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + R: GLWEInfos, + A: GLWEInfos, + K: GGLWEInfos, + M: GLWETrace, { - let trace: usize = Self::automorphism_inplace_tmp_bytes(module, out_infos, key_infos); - if in_infos.base2k() != key_infos.base2k() { + module.glwe_automorphism_tmp_bytes(res_infos, a_infos, key_infos) + } +} + +impl GLWE { + pub fn trace( + &mut self, + module: &M, + start: usize, + end: usize, + a: &A, + keys: &HashMap, + scratch: &mut Scratch, + ) where + A: GLWEToRef, + K: AutomorphismKeyPreparedToRef + GGLWEInfos + GetAutomorphismGaloisElement, + Scratch: ScratchTakeCore, + M: GLWETrace, + { + module.glwe_trace(self, start, end, a, keys, scratch); + } + + pub fn trace_inplace( + &mut self, + module: &M, + start: usize, + end: usize, + keys: &HashMap, + scratch: &mut Scratch, + ) where + K: AutomorphismKeyPreparedToRef + GGLWEInfos + GetAutomorphismGaloisElement, + Scratch: ScratchTakeCore, + M: GLWETrace, + { + module.glwe_trace_inplace(self, start, end, keys, scratch); + } +} + +impl GLWETrace for Module where + Self: ModuleLogN + GaloisElement + GLWEAutomorphism + GLWEShift + GLWECopy +{ +} + +pub trait GLWETrace +where + Self: ModuleLogN + GaloisElement + GLWEAutomorphism + GLWEShift + GLWECopy, +{ + fn glwe_trace_galois_elements(&self) -> Vec { + (0..self.log_n()) + .map(|i| { + if i == 0 { + -1 + } else { + self.galois_element(1 << (i - 1)) + } + }) + .collect() + } + + fn glwe_trace_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize + where + R: GLWEInfos, + A: GLWEInfos, + K: GGLWEInfos, + { + let trace: usize = self.glwe_automorphism_tmp_bytes(res_infos, a_infos, key_infos); + if a_infos.base2k() != key_infos.base2k() { let glwe_conv: usize = VecZnx::bytes_of( - module.n(), + self.n(), (key_infos.rank_out() + 1).into(), - out_infos.k().min(in_infos.k()).div_ceil(key_infos.base2k()) as usize, - ) + module.vec_znx_normalize_tmp_bytes(); + res_infos.k().min(a_infos.k()).div_ceil(key_infos.base2k()) as usize, + ) + self.vec_znx_normalize_tmp_bytes(); return glwe_conv + trace; } trace } - pub fn trace_inplace_tmp_bytes(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize + fn glwe_trace(&self, res: &mut R, start: usize, end: usize, a: &A, keys: &HashMap, scratch: &mut Scratch) where - OUT: GLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + R: GLWEToMut, + A: GLWEToRef, + K: AutomorphismKeyPreparedToRef + GGLWEInfos + GetAutomorphismGaloisElement, + Scratch: ScratchTakeCore, { - Self::trace_tmp_bytes(module, out_infos, out_infos, key_infos) - } -} - -impl GLWE { - pub fn trace( - &mut self, - module: &Module, - start: usize, - end: usize, - lhs: &GLWE, - auto_keys: &HashMap>, - scratch: &mut Scratch, - ) where - Module: VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxBigAutomorphismInplace - + VecZnxRshInplace - + VecZnxCopy - + VecZnxNormalizeTmpBytes - + VecZnxNormalize, - Scratch: ScratchAvailable, - { - self.copy(module, lhs); - self.trace_inplace(module, start, end, auto_keys, scratch); + self.glwe_copy(res, a); + self.glwe_trace_inplace(res, start, end, keys, scratch); } - pub fn trace_inplace( - &mut self, - module: &Module, - start: usize, - end: usize, - auto_keys: &HashMap>, - scratch: &mut Scratch, - ) where - Module: VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxBigAutomorphismInplace - + VecZnxRshInplace - + VecZnxNormalizeTmpBytes - + VecZnxNormalize, - Scratch: ScratchAvailable, + fn glwe_trace_inplace(&self, res: &mut R, start: usize, end: usize, keys: &HashMap, scratch: &mut Scratch) + where + R: GLWEToMut, + K: AutomorphismKeyPreparedToRef + GGLWEInfos + GetAutomorphismGaloisElement, + Scratch: ScratchTakeCore, { - let basek_ksk: Base2K = auto_keys - .get(auto_keys.keys().next().unwrap()) - .unwrap() - .base2k(); + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + + let basek_ksk: Base2K = keys.get(keys.keys().next().unwrap()).unwrap().base2k(); #[cfg(debug_assertions)] { - assert_eq!(self.n(), module.n() as u32); + assert_eq!(res.n(), self.n() as u32); assert!(start < end); - assert!(end <= module.log_n()); - for key in auto_keys.values() { - assert_eq!(key.n(), module.n() as u32); + assert!(end <= self.log_n()); + for key in keys.values() { + assert_eq!(key.n(), self.n() as u32); assert_eq!(key.base2k(), basek_ksk); - assert_eq!(key.rank_in(), self.rank()); - assert_eq!(key.rank_out(), self.rank()); + assert_eq!(key.rank_in(), res.rank()); + assert_eq!(key.rank_out(), res.rank()); } } - if self.base2k() != basek_ksk { - let (mut self_conv, scratch_1) = scratch.take_glwe_ct(&GLWELayout { - n: module.n().into(), - base2k: basek_ksk, - k: self.k(), - rank: self.rank(), - }); + if res.base2k() != basek_ksk { + let (mut self_conv, scratch_1) = scratch.take_glwe_ct( + self, + &GLWELayout { + n: self.n().into(), + base2k: basek_ksk, + k: res.k(), + rank: res.rank(), + }, + ); - for j in 0..(self.rank() + 1).into() { - module.vec_znx_normalize( + for j in 0..(res.rank() + 1).into() { + self.vec_znx_normalize( basek_ksk.into(), &mut self_conv.data, j, basek_ksk.into(), - &self.data, + res.data(), j, scratch_1, ); } for i in start..end { - self_conv.rsh(module, 1, scratch_1); + self.glwe_rsh(1, &mut self_conv, scratch_1); let p: i64 = if i == 0 { -1 } else { - module.galois_element(1 << (i - 1)) + self.galois_element(1 << (i - 1)) }; - if let Some(key) = auto_keys.get(&p) { - self_conv.automorphism_add_inplace(module, key, scratch_1); + if let Some(key) = keys.get(&p) { + self.glwe_automorphism_add_inplace(&mut self_conv, key, scratch_1); } else { - panic!("auto_keys[{p}] is empty") + panic!("keys[{p}] is empty") } } - for j in 0..(self.rank() + 1).into() { - module.vec_znx_normalize( - self.base2k().into(), - &mut self.data, + for j in 0..(res.rank() + 1).into() { + self.vec_znx_normalize( + res.base2k().into(), + res.data_mut(), j, basek_ksk.into(), &self_conv.data, @@ -177,18 +192,18 @@ impl GLWE { } } else { for i in start..end { - self.rsh(module, 1, scratch); + self.glwe_rsh(1, res, scratch); let p: i64 = if i == 0 { -1 } else { - module.galois_element(1 << (i - 1)) + self.galois_element(1 << (i - 1)) }; - if let Some(key) = auto_keys.get(&p) { - self.automorphism_add_inplace(module, key, scratch); + if let Some(key) = keys.get(&p) { + self.glwe_automorphism_add_inplace(res, key, scratch); } else { - panic!("auto_keys[{p}] is empty") + panic!("keys[{p}] is empty") } } } diff --git a/poulpy-core/src/keyswitching/gglwe_ct.rs b/poulpy-core/src/keyswitching/gglwe_ct.rs index edda267..4b92ed1 100644 --- a/poulpy-core/src/keyswitching/gglwe_ct.rs +++ b/poulpy-core/src/keyswitching/gglwe_ct.rs @@ -2,7 +2,7 @@ use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch}; use crate::{ ScratchTakeCore, - keyswitching::glwe_ct::GLWEKeySwitch, + keyswitching::glwe_ct::GLWEKeyswitch, layouts::{ AutomorphismKey, AutomorphismKeyToRef, GGLWE, GGLWEInfos, GGLWEToMut, GGLWEToRef, GLWESwitchingKey, GLWESwitchingKeyToRef, @@ -16,7 +16,7 @@ impl AutomorphismKey> { R: GGLWEInfos, A: GGLWEInfos, K: GGLWEInfos, - M: GGLWEKeySwitch, + M: GGLWEKeyswitch, { module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) } @@ -28,7 +28,7 @@ impl AutomorphismKey { A: AutomorphismKeyToRef, B: GLWESwitchingKeyPreparedToRef, Scratch: ScratchTakeCore, - M: GGLWEKeySwitch, + M: GGLWEKeyswitch, { module.gglwe_keyswitch(&mut self.key.key, &a.to_ref().key.key, b, scratch); } @@ -37,7 +37,7 @@ impl AutomorphismKey { where A: GLWESwitchingKeyPreparedToRef, Scratch: ScratchTakeCore, - M: GGLWEKeySwitch, + M: GGLWEKeyswitch, { module.gglwe_keyswitch_inplace(&mut self.key.key, a, scratch); } @@ -49,7 +49,7 @@ impl GLWESwitchingKey> { R: GGLWEInfos, A: GGLWEInfos, K: GGLWEInfos, - M: GGLWEKeySwitch, + M: GGLWEKeyswitch, { module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) } @@ -61,7 +61,7 @@ impl GLWESwitchingKey { A: GLWESwitchingKeyToRef, B: GLWESwitchingKeyPreparedToRef, Scratch: ScratchTakeCore, - M: GGLWEKeySwitch, + M: GGLWEKeyswitch, { module.gglwe_keyswitch(&mut self.key, &a.to_ref().key, b, scratch); } @@ -70,7 +70,7 @@ impl GLWESwitchingKey { where A: GLWESwitchingKeyPreparedToRef, Scratch: ScratchTakeCore, - M: GGLWEKeySwitch, + M: GGLWEKeyswitch, { module.gglwe_keyswitch_inplace(&mut self.key, a, scratch); } @@ -82,7 +82,7 @@ impl GGLWE> { R: GGLWEInfos, A: GGLWEInfos, K: GGLWEInfos, - M: GGLWEKeySwitch, + M: GGLWEKeyswitch, { module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) } @@ -94,7 +94,7 @@ impl GGLWE { A: GGLWEToRef, B: GLWESwitchingKeyPreparedToRef, Scratch: ScratchTakeCore, - M: GGLWEKeySwitch, + M: GGLWEKeyswitch, { module.gglwe_keyswitch(self, a, b, scratch); } @@ -103,17 +103,17 @@ impl GGLWE { where A: GLWESwitchingKeyPreparedToRef, Scratch: ScratchTakeCore, - M: GGLWEKeySwitch, + M: GGLWEKeyswitch, { module.gglwe_keyswitch_inplace(self, a, scratch); } } -impl GGLWEKeySwitch for Module where Self: GLWEKeySwitch {} +impl GGLWEKeyswitch for Module where Self: GLWEKeyswitch {} -pub trait GGLWEKeySwitch +pub trait GGLWEKeyswitch where - Self: GLWEKeySwitch, + Self: GLWEKeyswitch, { fn gglwe_keyswitch_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize where diff --git a/poulpy-core/src/keyswitching/ggsw_ct.rs b/poulpy-core/src/keyswitching/ggsw_ct.rs index cfb4d8e..67f4278 100644 --- a/poulpy-core/src/keyswitching/ggsw_ct.rs +++ b/poulpy-core/src/keyswitching/ggsw_ct.rs @@ -2,7 +2,7 @@ use poulpy_hal::layouts::{Backend, DataMut, Scratch, VecZnx}; use crate::{ GGSWExpandRows, ScratchTakeCore, - keyswitching::glwe_ct::GLWEKeySwitch, + keyswitching::glwe_ct::GLWEKeyswitch, layouts::{ GGLWEInfos, GGSW, GGSWInfos, GGSWToMut, GGSWToRef, prepared::{GLWESwitchingKeyPreparedToRef, TensorKeyPreparedToRef}, @@ -22,7 +22,7 @@ impl GGSW> { A: GGSWInfos, K: GGLWEInfos, T: GGLWEInfos, - M: GGSWKeySwitch, + M: GGSWKeyswitch, { module.ggsw_keyswitch_tmp_bytes(res_infos, a_infos, key_infos, tsk_infos) } @@ -35,7 +35,7 @@ impl GGSW { K: GLWESwitchingKeyPreparedToRef, T: TensorKeyPreparedToRef, Scratch: ScratchTakeCore, - M: GGSWKeySwitch, + M: GGSWKeyswitch, { module.ggsw_keyswitch(self, a, key, tsk, scratch); } @@ -45,15 +45,15 @@ impl GGSW { K: GLWESwitchingKeyPreparedToRef, T: TensorKeyPreparedToRef, Scratch: ScratchTakeCore, - M: GGSWKeySwitch, + M: GGSWKeyswitch, { module.ggsw_keyswitch_inplace(self, key, tsk, scratch); } } -pub trait GGSWKeySwitch +pub trait GGSWKeyswitch where - Self: GLWEKeySwitch + GGSWExpandRows, + Self: GLWEKeyswitch + GGSWExpandRows, { fn ggsw_keyswitch_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K, tsk_infos: &T) -> usize where @@ -127,5 +127,3 @@ where self.ggsw_expand_row(res, tsk, scratch); } } - -impl GGSW {} diff --git a/poulpy-core/src/keyswitching/glwe_ct.rs b/poulpy-core/src/keyswitching/glwe_ct.rs index 6d7bff9..f82a4d1 100644 --- a/poulpy-core/src/keyswitching/glwe_ct.rs +++ b/poulpy-core/src/keyswitching/glwe_ct.rs @@ -16,14 +16,14 @@ use crate::{ }; impl GLWE> { - pub fn keyswitch_tmp_bytes(module: &M, res_infos: &R, a_infos: &A, b_infos: &B) -> usize + pub fn keyswitch_tmp_bytes(module: &M, res_infos: &R, a_infos: &A, key_infos: &B) -> usize where R: GLWEInfos, A: GLWEInfos, B: GGLWEInfos, - M: GLWEKeySwitch, + M: GLWEKeyswitch, { - module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, b_infos) + module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) } } @@ -32,7 +32,7 @@ impl GLWE { where A: GLWEToRef, B: GLWESwitchingKeyPreparedToRef, - M: GLWEKeySwitch, + M: GLWEKeyswitch, Scratch: ScratchTakeCore, { module.glwe_keyswitch(self, a, b, scratch); @@ -41,14 +41,14 @@ impl GLWE { pub fn keyswitch_inplace(&mut self, module: &M, a: &A, scratch: &mut Scratch) where A: GLWESwitchingKeyPreparedToRef, - M: GLWEKeySwitch, + M: GLWEKeyswitch, Scratch: ScratchTakeCore, { module.glwe_keyswitch_inplace(self, a, scratch); } } -impl GLWEKeySwitch for Module where +impl GLWEKeyswitch for Module where Self: Sized + ModuleN + VecZnxDftBytesOf @@ -69,7 +69,7 @@ impl GLWEKeySwitch for Module where { } -pub trait GLWEKeySwitch +pub trait GLWEKeyswitch where Self: Sized + ModuleN @@ -89,7 +89,7 @@ where + VecZnxNormalize + VecZnxNormalizeTmpBytes, { - fn glwe_keyswitch_tmp_bytes(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize + fn glwe_keyswitch_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &B) -> usize where R: GLWEInfos, A: GLWEInfos, @@ -97,44 +97,44 @@ where { let in_size: usize = a_infos .k() - .div_ceil(b_infos.base2k()) - .div_ceil(b_infos.dsize().into()) as usize; + .div_ceil(key_infos.base2k()) + .div_ceil(key_infos.dsize().into()) as usize; let out_size: usize = res_infos.size(); - let ksk_size: usize = b_infos.size(); - let res_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank_out() + 1).into(), ksk_size); // TODO OPTIMIZE - let ai_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank_in()).into(), in_size); + let ksk_size: usize = key_infos.size(); + let res_dft: usize = self.bytes_of_vec_znx_dft((key_infos.rank_out() + 1).into(), ksk_size); // TODO OPTIMIZE + let ai_dft: usize = self.bytes_of_vec_znx_dft((key_infos.rank_in()).into(), in_size); let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes( out_size, in_size, in_size, - (b_infos.rank_in()).into(), - (b_infos.rank_out() + 1).into(), + (key_infos.rank_in()).into(), + (key_infos.rank_out() + 1).into(), ksk_size, - ) + self.bytes_of_vec_znx_dft((b_infos.rank_in()).into(), in_size); + ) + self.bytes_of_vec_znx_dft((key_infos.rank_in()).into(), in_size); let normalize_big: usize = self.vec_znx_big_normalize_tmp_bytes(); - if a_infos.base2k() == b_infos.base2k() { + if a_infos.base2k() == key_infos.base2k() { res_dft + ((ai_dft + vmp) | normalize_big) - } else if b_infos.dsize() == 1 { + } else if key_infos.dsize() == 1 { // In this case, we only need one column, temporary, that we can drop once a_dft is computed. let normalize_conv: usize = VecZnx::bytes_of(self.n(), 1, in_size) + self.vec_znx_normalize_tmp_bytes(); res_dft + (((ai_dft + normalize_conv) | vmp) | normalize_big) } else { // Since we stride over a to get a_dft when dsize > 1, we need to store the full columns of a with in the base conversion. - let normalize_conv: usize = VecZnx::bytes_of(self.n(), (b_infos.rank_in()).into(), in_size); + let normalize_conv: usize = VecZnx::bytes_of(self.n(), (key_infos.rank_in()).into(), in_size); res_dft + ((ai_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big) } } - fn glwe_keyswitch(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) + fn glwe_keyswitch(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) where R: GLWEToMut, A: GLWEToRef, - B: GLWESwitchingKeyPreparedToRef, + K: GLWESwitchingKeyPreparedToRef, Scratch: ScratchTakeCore, { let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let a: &GLWE<&[u8]> = &a.to_ref(); - let b: &GLWESwitchingKeyPrepared<&[u8], BE> = &b.to_ref(); + let b: &GLWESwitchingKeyPrepared<&[u8], BE> = &key.to_ref(); assert_eq!( a.rank(), @@ -181,14 +181,14 @@ where }) } - fn glwe_keyswitch_inplace(&self, res: &mut R, a: &A, scratch: &mut Scratch) + fn glwe_keyswitch_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) where R: GLWEToMut, - A: GLWESwitchingKeyPreparedToRef, + K: GLWESwitchingKeyPreparedToRef, Scratch: ScratchTakeCore, { let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); - let a: &GLWESwitchingKeyPrepared<&[u8], BE> = &a.to_ref(); + let a: &GLWESwitchingKeyPrepared<&[u8], BE> = &key.to_ref(); assert_eq!( res.rank(), @@ -239,11 +239,11 @@ impl GLWE> {} impl GLWE {} -fn keyswitch_internal( +pub(crate) fn keyswitch_internal( module: &M, mut res: VecZnxDft, a: &GLWE, - b: &GLWESwitchingKeyPrepared, + key: &GLWESwitchingKeyPrepared, scratch: &mut Scratch, ) -> VecZnxBig where @@ -265,12 +265,12 @@ where Scratch: ScratchTakeCore, { let base2k_in: usize = a.base2k().into(); - let base2k_out: usize = b.base2k().into(); + let base2k_out: usize = key.base2k().into(); let cols: usize = (a.rank() + 1).into(); let a_size: usize = (a.size() * base2k_in).div_ceil(base2k_out); - let pmat: &VmpPMat = &b.key.data; + let pmat: &VmpPMat = &key.key.data; - if b.dsize() == 1 { + if key.dsize() == 1 { let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols - 1, a.size()); if base2k_in == base2k_out { @@ -295,7 +295,7 @@ where module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_1); } else { - let dsize: usize = b.dsize().into(); + let dsize: usize = key.dsize().into(); let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols - 1, a_size.div_ceil(dsize)); ai_dft.data_mut().fill(0); diff --git a/poulpy-core/src/keyswitching/lwe_ct.rs b/poulpy-core/src/keyswitching/lwe_ct.rs index 8546ccb..ff9fb5f 100644 --- a/poulpy-core/src/keyswitching/lwe_ct.rs +++ b/poulpy-core/src/keyswitching/lwe_ct.rs @@ -5,7 +5,7 @@ use poulpy_hal::{ use crate::{ ScratchTakeCore, - keyswitching::glwe_ct::GLWEKeySwitch, + keyswitching::glwe_ct::GLWEKeyswitch, layouts::{ GGLWEInfos, GLWE, GLWEAlloc, GLWELayout, LWE, LWEInfos, LWEToMut, LWEToRef, Rank, TorusPrecision, prepared::{LWESwitchingKeyPrepared, LWESwitchingKeyPreparedToRef}, @@ -40,7 +40,7 @@ impl LWEKeySwitch for Module where Self: LWEKeySwitch { pub trait LWEKeySwitch where - Self: GLWEKeySwitch + GLWEAlloc, + Self: GLWEKeyswitch + GLWEAlloc, { fn lwe_keyswitch_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize where diff --git a/poulpy-core/src/keyswitching/mod.rs b/poulpy-core/src/keyswitching/mod.rs index c6a3610..7071680 100644 --- a/poulpy-core/src/keyswitching/mod.rs +++ b/poulpy-core/src/keyswitching/mod.rs @@ -2,3 +2,8 @@ mod gglwe_ct; mod ggsw_ct; mod glwe_ct; mod lwe_ct; + +pub use gglwe_ct::*; +// pub use gglwe_ct::*; +pub use glwe_ct::*; +pub use lwe_ct::*; diff --git a/poulpy-core/src/layouts/gglwe_atk.rs b/poulpy-core/src/layouts/gglwe_atk.rs index eb93bf4..3c0afc1 100644 --- a/poulpy-core/src/layouts/gglwe_atk.rs +++ b/poulpy-core/src/layouts/gglwe_atk.rs @@ -6,6 +6,7 @@ use poulpy_hal::{ use crate::layouts::{ Base2K, Dnum, Dsize, GGLWEInfos, GLWE, GLWEInfos, GLWESwitchingKey, GLWESwitchingKeyAlloc, GLWESwitchingKeyToMut, GLWESwitchingKeyToRef, LWEInfos, Rank, RingDegree, TorusPrecision, + prepared::{GetAutomorphismGaloisElement, SetAutomorphismGaloisElement}, }; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; @@ -27,6 +28,18 @@ pub struct AutomorphismKey { pub(crate) p: i64, } +impl SetAutomorphismGaloisElement for AutomorphismKey { + fn set_p(&mut self, p: i64) { + self.p = p + } +} + +impl GetAutomorphismGaloisElement for AutomorphismKey { + fn p(&self) -> i64 { + self.p + } +} + impl AutomorphismKey { pub fn p(&self) -> i64 { self.p diff --git a/poulpy-core/src/lib.rs b/poulpy-core/src/lib.rs index 15e6c76..2dcc77a 100644 --- a/poulpy-core/src/lib.rs +++ b/poulpy-core/src/lib.rs @@ -14,10 +14,12 @@ mod utils; pub use operations::*; pub mod layouts; +pub use automorphism::*; pub use conversion::*; pub use dist::*; pub use external_product::*; pub use glwe_packing::*; +pub use keyswitching::*; pub use encryption::SIGMA; diff --git a/poulpy-core/src/operations/glwe.rs b/poulpy-core/src/operations/glwe.rs index b8b32ce..3e507e2 100644 --- a/poulpy-core/src/operations/glwe.rs +++ b/poulpy-core/src/operations/glwe.rs @@ -1,320 +1,292 @@ use poulpy_hal::{ api::{ - VecZnxAdd, VecZnxAddInplace, VecZnxCopy, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegateInplace, + ModuleN, VecZnxAdd, VecZnxAddInplace, VecZnxCopy, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace, VecZnxSubNegateInplace, }, - layouts::{Backend, DataMut, Module, Scratch, VecZnx, ZnxZero}, + layouts::{Backend, Module, Scratch, VecZnx, ZnxZero}, }; -use crate::layouts::{GLWE, GLWEInfos, GLWEPlaintext, GLWEToMut, GLWEToRef, LWEInfos, SetGLWEInfos, TorusPrecision}; +use crate::{ + ScratchTakeCore, + layouts::{GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, SetGLWEInfos, TorusPrecision}, +}; -impl GLWEOperations for GLWEPlaintext +pub trait GLWEAdd where - D: DataMut, - GLWEPlaintext: GLWEToMut + GLWEInfos, + Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace, { + fn glwe_add(&self, res: &mut R, a: &A, b: &B) + where + R: GLWEToMut, + A: GLWEToRef, + B: GLWEToRef, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &mut GLWE<&[u8]> = &mut a.to_ref(); + let b: &GLWE<&[u8]> = &b.to_ref(); + + assert_eq!(a.n(), self.n() as u32); + assert_eq!(b.n(), self.n() as u32); + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.base2k(), b.base2k()); + assert!(res.rank() >= a.rank().max(b.rank())); + + let min_col: usize = (a.rank().min(b.rank()) + 1).into(); + let max_col: usize = (a.rank().max(b.rank() + 1)).into(); + let self_col: usize = (res.rank() + 1).into(); + + (0..min_col).for_each(|i| { + self.vec_znx_add(res.data_mut(), i, a.data(), i, b.data(), i); + }); + + if a.rank() > b.rank() { + (min_col..max_col).for_each(|i| { + self.vec_znx_copy(res.data_mut(), i, a.data(), i); + }); + } else { + (min_col..max_col).for_each(|i| { + self.vec_znx_copy(res.data_mut(), i, b.data(), i); + }); + } + + let size: usize = res.size(); + (max_col..self_col).for_each(|i| { + (0..size).for_each(|j| { + res.data.zero_at(i, j); + }); + }); + + res.set_base2k(a.base2k()); + res.set_k(set_k_binary(res, a, b)); + } + + fn glwe_add_inplace(&self, res: &mut R, a: &A) + where + R: GLWEToMut, + A: GLWEToRef, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); + + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.n(), self.n() as u32); + assert_eq!(res.base2k(), a.base2k()); + assert!(res.rank() >= a.rank()); + + (0..(a.rank() + 1).into()).for_each(|i| { + self.vec_znx_add_inplace(res.data_mut(), i, a.data(), i); + }); + + res.set_k(set_k_unary(res, a)) + } } -impl GLWEOperations for GLWE where GLWE: GLWEToMut + GLWEInfos {} +impl GLWEAdd for Module where Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace {} -pub trait GLWEOperations: GLWEToMut + GLWEInfos + SetGLWEInfos + Sized { - fn add(&mut self, module: &Module, a: &A, b: &B) +pub trait GLWESub +where + Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegateInplace + VecZnxSubInplace + VecZnxSubNegateInplace, +{ + fn glwe_sub(&self, res: &mut R, a: &A, b: &B) where - A: GLWEToRef + GLWEInfos, - B: GLWEToRef + GLWEInfos, - Module: VecZnxAdd + VecZnxCopy, + R: GLWEToMut, + A: GLWEToRef, + B: GLWEToRef, { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(b.n(), self.n()); - assert_eq!(a.base2k(), b.base2k()); - assert!(self.rank() >= a.rank().max(b.rank())); - } + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); + let b: &GLWE<&[u8]> = &b.to_ref(); + + assert_eq!(a.n(), self.n() as u32); + assert_eq!(b.n(), self.n() as u32); + assert_eq!(a.base2k(), b.base2k()); + assert!(res.rank() >= a.rank().max(b.rank())); let min_col: usize = (a.rank().min(b.rank()) + 1).into(); let max_col: usize = (a.rank().max(b.rank() + 1)).into(); - let self_col: usize = (self.rank() + 1).into(); - - let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWE<&[u8]> = &a.to_ref(); - let b_ref: &GLWE<&[u8]> = &b.to_ref(); + let self_col: usize = (res.rank() + 1).into(); (0..min_col).for_each(|i| { - module.vec_znx_add(&mut self_mut.data, i, &a_ref.data, i, &b_ref.data, i); + self.vec_znx_sub(res.data_mut(), i, a.data(), i, b.data(), i); }); if a.rank() > b.rank() { (min_col..max_col).for_each(|i| { - module.vec_znx_copy(&mut self_mut.data, i, &a_ref.data, i); + self.vec_znx_copy(res.data_mut(), i, a.data(), i); }); } else { (min_col..max_col).for_each(|i| { - module.vec_znx_copy(&mut self_mut.data, i, &b_ref.data, i); + self.vec_znx_copy(res.data_mut(), i, b.data(), i); + self.vec_znx_negate_inplace(res.data_mut(), i); }); } - let size: usize = self_mut.size(); + let size: usize = res.size(); (max_col..self_col).for_each(|i| { (0..size).for_each(|j| { - self_mut.data.zero_at(i, j); + res.data.zero_at(i, j); }); }); - self.set_base2k(a.base2k()); - self.set_k(set_k_binary(self, a, b)); + res.set_base2k(a.base2k()); + res.set_k(set_k_binary(res, a, b)); } - fn add_inplace(&mut self, module: &Module, a: &A) + fn glwe_sub_inplace(&self, res: &mut R, a: &A) where - A: GLWEToRef + GLWEInfos, - Module: VecZnxAddInplace, + R: GLWEToMut, + A: GLWEToRef, { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(self.base2k(), a.base2k()); - assert!(self.rank() >= a.rank()) - } + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); - let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWE<&[u8]> = &a.to_ref(); + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.n(), self.n() as u32); + assert_eq!(res.base2k(), a.base2k()); + assert!(res.rank() >= a.rank()); (0..(a.rank() + 1).into()).for_each(|i| { - module.vec_znx_add_inplace(&mut self_mut.data, i, &a_ref.data, i); + self.vec_znx_sub_inplace(res.data_mut(), i, a.data(), i); }); - self.set_k(set_k_unary(self, a)) + res.set_k(set_k_unary(res, a)) } - fn sub(&mut self, module: &Module, a: &A, b: &B) + fn glwe_sub_negate_inplace(&self, res: &mut R, a: &A) where - A: GLWEToRef + GLWEInfos, - B: GLWEToRef + GLWEInfos, - Module: VecZnxSub + VecZnxCopy + VecZnxNegateInplace, + R: GLWEToMut, + A: GLWEToRef, { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(b.n(), self.n()); - assert_eq!(a.base2k(), b.base2k()); - assert!(self.rank() >= a.rank().max(b.rank())); - } + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); - let min_col: usize = (a.rank().min(b.rank()) + 1).into(); - let max_col: usize = (a.rank().max(b.rank() + 1)).into(); - let self_col: usize = (self.rank() + 1).into(); - - let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWE<&[u8]> = &a.to_ref(); - let b_ref: &GLWE<&[u8]> = &b.to_ref(); - - (0..min_col).for_each(|i| { - module.vec_znx_sub(&mut self_mut.data, i, &a_ref.data, i, &b_ref.data, i); - }); - - if a.rank() > b.rank() { - (min_col..max_col).for_each(|i| { - module.vec_znx_copy(&mut self_mut.data, i, &a_ref.data, i); - }); - } else { - (min_col..max_col).for_each(|i| { - module.vec_znx_copy(&mut self_mut.data, i, &b_ref.data, i); - module.vec_znx_negate_inplace(&mut self_mut.data, i); - }); - } - - let size: usize = self_mut.size(); - (max_col..self_col).for_each(|i| { - (0..size).for_each(|j| { - self_mut.data.zero_at(i, j); - }); - }); - - self.set_base2k(a.base2k()); - self.set_k(set_k_binary(self, a, b)); - } - - fn sub_inplace_ab(&mut self, module: &Module, a: &A) - where - A: GLWEToRef + GLWEInfos, - Module: VecZnxSubInplace, - { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(self.base2k(), a.base2k()); - assert!(self.rank() >= a.rank()) - } - - let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWE<&[u8]> = &a.to_ref(); + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.n(), self.n() as u32); + assert_eq!(res.base2k(), a.base2k()); + assert!(res.rank() >= a.rank()); (0..(a.rank() + 1).into()).for_each(|i| { - module.vec_znx_sub_inplace(&mut self_mut.data, i, &a_ref.data, i); + self.vec_znx_sub_negate_inplace(res.data_mut(), i, a.data(), i); }); - self.set_k(set_k_unary(self, a)) + res.set_k(set_k_unary(res, a)) } +} - fn sub_inplace_ba(&mut self, module: &Module, a: &A) +pub trait GLWERotate +where + Self: ModuleN + VecZnxRotate + VecZnxRotateInplace, +{ + fn glwe_rotate(&self, k: i64, res: &mut R, a: &A) where - A: GLWEToRef + GLWEInfos, - Module: VecZnxSubNegateInplace, + R: GLWEToMut, + A: GLWEToRef, { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(self.base2k(), a.base2k()); - assert!(self.rank() >= a.rank()) - } + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); - let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWE<&[u8]> = &a.to_ref(); + assert_eq!(a.n(), self.n() as u32); + assert_eq!(res.rank(), a.rank()); (0..(a.rank() + 1).into()).for_each(|i| { - module.vec_znx_sub_negate_inplace(&mut self_mut.data, i, &a_ref.data, i); + self.vec_znx_rotate(k, res.data_mut(), i, a.data(), i); }); - self.set_k(set_k_unary(self, a)) + res.set_base2k(a.base2k()); + res.set_k(set_k_unary(res, a)) } - fn rotate(&mut self, module: &Module, k: i64, a: &A) + fn glwe_rotate_inplace(&self, k: i64, res: &mut R, scratch: &mut Scratch) where - A: GLWEToRef + GLWEInfos, - Module: VecZnxRotate, + R: GLWEToMut, + Scratch: ScratchTakeCore, { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(self.rank(), a.rank()) + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + + (0..(res.rank() + 1).into()).for_each(|i| { + self.vec_znx_rotate_inplace(k, res.data_mut(), i, scratch); + }); + } +} + +pub trait GLWEMulXpMinusOne +where + Self: ModuleN + VecZnxMulXpMinusOne + VecZnxMulXpMinusOneInplace, +{ + fn glwe_mul_xp_minus_one(&self, k: i64, res: &mut R, a: &A) + where + R: GLWEToMut, + A: GLWEToRef, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); + + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.n(), self.n() as u32); + assert_eq!(res.rank(), a.rank()); + + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_mul_xp_minus_one(k, res.data_mut(), i, a.data(), i); } - let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWE<&[u8]> = &a.to_ref(); - - (0..(a.rank() + 1).into()).for_each(|i| { - module.vec_znx_rotate(k, &mut self_mut.data, i, &a_ref.data, i); - }); - - self.set_base2k(a.base2k()); - self.set_k(set_k_unary(self, a)) + res.set_base2k(a.base2k()); + res.set_k(set_k_unary(res, a)) } - fn rotate_inplace(&mut self, module: &Module, k: i64, scratch: &mut Scratch) + fn glwe_mul_xp_minus_one_inplace(&self, k: i64, res: &mut R, scratch: &mut Scratch) where - Module: VecZnxRotateInplace, + R: GLWEToMut, { - let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); - (0..(self_mut.rank() + 1).into()).for_each(|i| { - module.vec_znx_rotate_inplace(k, &mut self_mut.data, i, scratch); - }); + assert_eq!(res.n(), self.n() as u32); + + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_mul_xp_minus_one_inplace(k, res.data_mut(), i, scratch); + } } +} - fn mul_xp_minus_one(&mut self, module: &Module, k: i64, a: &A) +pub trait GLWECopy +where + Self: ModuleN + VecZnxCopy, +{ + fn glwe_copy(&self, res: &mut R, a: &A) where - A: GLWEToRef + GLWEInfos, - Module: VecZnxMulXpMinusOne, + R: GLWEToMut, + A: GLWEToRef, { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(self.rank(), a.rank()) + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); + + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.n(), self.n() as u32); + assert_eq!(res.rank(), a.rank()); + + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_copy(res.data_mut(), i, a.data(), i); } - let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWE<&[u8]> = &a.to_ref(); - - (0..(a.rank() + 1).into()).for_each(|i| { - module.vec_znx_mul_xp_minus_one(k, &mut self_mut.data, i, &a_ref.data, i); - }); - - self.set_base2k(a.base2k()); - self.set_k(set_k_unary(self, a)) + res.set_k(a.k().min(res.max_k())); + res.set_base2k(a.base2k()); } +} - fn mul_xp_minus_one_inplace(&mut self, module: &Module, k: i64, scratch: &mut Scratch) +pub trait GLWEShift +where + Self: ModuleN + VecZnxRshInplace, +{ + fn glwe_rsh(&self, k: usize, res: &mut R, scratch: &mut Scratch) where - Module: VecZnxMulXpMinusOneInplace, + R: GLWEToMut, + Scratch: ScratchTakeCore, { - let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); - - (0..(self_mut.rank() + 1).into()).for_each(|i| { - module.vec_znx_mul_xp_minus_one_inplace(k, &mut self_mut.data, i, scratch); - }); - } - - fn copy(&mut self, module: &M, a: &A) - where - A: GLWEToRef + GLWEInfos, - M: VecZnxCopy, - { - #[cfg(debug_assertions)] - { - assert_eq!(self.n(), a.n()); - assert_eq!(self.rank(), a.rank()); + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let base2k: usize = res.base2k().into(); + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_rsh_inplace(base2k, k, res.data_mut(), i, scratch); } - - let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWE<&[u8]> = &a.to_ref(); - - (0..(self_mut.rank() + 1).into()).for_each(|i| { - module.vec_znx_copy(&mut self_mut.data, i, &a_ref.data, i); - }); - - self.set_k(a.k().min(self.max_k())); - self.set_base2k(a.base2k()); - } - - fn rsh(&mut self, module: &Module, k: usize, scratch: &mut Scratch) - where - Module: VecZnxRshInplace, - { - let base2k: usize = self.base2k().into(); - (0..(self.rank() + 1).into()).for_each(|i| { - module.vec_znx_rsh_inplace(base2k, k, &mut self.to_mut().data, i, scratch); - }) - } - - fn normalize(&mut self, module: &Module, a: &A, scratch: &mut Scratch) - where - A: GLWEToRef + GLWEInfos, - Module: VecZnxNormalize, - { - #[cfg(debug_assertions)] - { - assert_eq!(self.n(), a.n()); - assert_eq!(self.rank(), a.rank()); - } - - let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWE<&[u8]> = &a.to_ref(); - - (0..(self_mut.rank() + 1).into()).for_each(|i| { - module.vec_znx_normalize( - a.base2k().into(), - &mut self_mut.data, - i, - a.base2k().into(), - &a_ref.data, - i, - scratch, - ); - }); - self.set_base2k(a.base2k()); - self.set_k(a.k().min(self.k())); - } - - fn normalize_inplace(&mut self, module: &Module, scratch: &mut Scratch) - where - Module: VecZnxNormalizeInplace, - { - let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); - (0..(self_mut.rank() + 1).into()).for_each(|i| { - module.vec_znx_normalize_inplace(self_mut.base2k().into(), &mut self_mut.data, i, scratch); - }); } } @@ -324,6 +296,50 @@ impl GLWE> { } } +pub trait GLWENormalize +where + Self: ModuleN + VecZnxNormalize + VecZnxNormalizeInplace, +{ + fn glwe_normalize(&self, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: GLWEToMut, + A: GLWEToRef, + Scratch: ScratchTakeCore, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); + + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.n(), self.n() as u32); + assert_eq!(res.rank(), a.rank()); + + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_normalize( + res.base2k().into(), + res.data_mut(), + i, + a.base2k().into(), + a.data(), + i, + scratch, + ); + } + + res.set_k(a.k().min(res.k())); + } + + fn glwe_normalize_inplace(&self, res: &mut R, scratch: &mut Scratch) + where + R: GLWEToMut, + Scratch: ScratchTakeCore, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_normalize_inplace(res.base2k().into(), res.data_mut(), i, scratch); + } + } +} + // c = op(a, b) fn set_k_binary(c: &impl GLWEInfos, a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision { // If either operands is a ciphertext diff --git a/poulpy-hal/src/api/module.rs b/poulpy-hal/src/api/module.rs index 3dd6176..a18af44 100644 --- a/poulpy-hal/src/api/module.rs +++ b/poulpy-hal/src/api/module.rs @@ -8,3 +8,12 @@ pub trait ModuleNew { pub trait ModuleN { fn n(&self) -> usize; } + +pub trait ModuleLogN +where + Self: ModuleN, +{ + fn log_n(&self) -> usize { + (u64::BITS - (self.n() as u64 - 1).leading_zeros()) as usize + } +} diff --git a/poulpy-hal/src/layouts/module.rs b/poulpy-hal/src/layouts/module.rs index 61e312c..3382774 100644 --- a/poulpy-hal/src/layouts/module.rs +++ b/poulpy-hal/src/layouts/module.rs @@ -2,7 +2,10 @@ use std::{fmt::Display, marker::PhantomData, ptr::NonNull}; use rand_distr::num_traits::Zero; -use crate::GALOISGENERATOR; +use crate::{ + GALOISGENERATOR, + api::{ModuleLogN, ModuleN}, +}; #[allow(clippy::missing_safety_doc)] pub trait Backend: Sized { @@ -75,36 +78,49 @@ impl Module { pub fn log_n(&self) -> usize { (usize::BITS - (self.n() - 1).leading_zeros()) as _ } +} - #[inline] - pub fn cyclotomic_order(&self) -> u64 { +pub trait CyclotomicOrder +where + Self: ModuleN, +{ + fn cyclotomic_order(&self) -> i64 { (self.n() << 1) as _ } +} +impl ModuleLogN for Module where Self: ModuleN {} + +impl CyclotomicOrder for Module where Self: ModuleN {} + +pub trait GaloisElement +where + Self: CyclotomicOrder, +{ // Returns GALOISGENERATOR^|generator| * sign(generator) - #[inline] - pub fn galois_element(&self, generator: i64) -> i64 { + fn galois_element(&self, generator: i64) -> i64 { if generator == 0 { return 1; } - ((mod_exp_u64(GALOISGENERATOR, generator.unsigned_abs() as usize) & (self.cyclotomic_order() - 1)) as i64) - * generator.signum() + + let g_exp: u64 = mod_exp_u64(GALOISGENERATOR, generator.unsigned_abs() as usize) & (self.cyclotomic_order() - 1) as u64; + g_exp as i64 * generator.signum() } // Returns gen^-1 - #[inline] - pub fn galois_element_inv(&self, gal_el: i64) -> i64 { + fn galois_element_inv(&self, gal_el: i64) -> i64 { if gal_el == 0 { panic!("cannot invert 0") } - ((mod_exp_u64( - gal_el.unsigned_abs(), - (self.cyclotomic_order() - 1) as usize, - ) & (self.cyclotomic_order() - 1)) as i64) - * gal_el.signum() + + let g_exp: u64 = + mod_exp_u64(GALOISGENERATOR, (self.cyclotomic_order() - 1) as usize) & (self.cyclotomic_order() - 1) as u64; + g_exp as i64 * gal_el.signum() } } +impl GaloisElement for Module where Self: CyclotomicOrder {} + impl Drop for Module { fn drop(&mut self) { unsafe { B::destroy(self.ptr) }