diff --git a/rlwe/src/elem.rs b/rlwe/src/elem.rs index c943de1..98e2677 100644 --- a/rlwe/src/elem.rs +++ b/rlwe/src/elem.rs @@ -1,6 +1,13 @@ -use base2k::ZnxInfos; +use base2k::{ + Backend, FFT64, MatZnxDft, MatZnxDftToMut, MatZnxDftToRef, Module, Scratch, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxDftToMut, + VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero, +}; -use crate::utils::derive_size; +use crate::{ + elem_grlwe::GRLWECt, + elem_rlwe::{RLWECt, RLWECtDft}, + utils::derive_size, +}; pub trait Infos { type Inner: ZnxInfos; @@ -45,3 +52,162 @@ pub trait Infos { /// Returns the bit precision of the ciphertext. fn log_k(&self) -> usize; } + +pub trait GetRow { + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut RLWECtDft) + where + VecZnxDft: VecZnxDftToMut; +} + +pub trait SetRow { + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &RLWECtDft) + where + VecZnxDft: VecZnxDftToRef; +} + +pub(crate) trait MatZnxDftProducts: Infos +where + MatZnxDft: MatZnxDftToRef + ZnxInfos, +{ + fn mul_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef, + VecZnx: VecZnxToMut, + VecZnx: VecZnxToRef; + + fn mul_rlwe_inplace(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + VecZnx: VecZnxToMut + VecZnxToRef, + { + unsafe { + let res_ptr: *mut RLWECt = res as *mut RLWECt; // This is ok because [Self::mul_rlwe] only updates res at the end. + self.mul_rlwe(&module, &mut *res_ptr, &*res_ptr, scratch); + } + } + + fn mul_rlwe_dft( + &self, + module: &Module, + res: &mut RLWECtDft, + a: &RLWECtDft, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToRef + ZnxInfos, + { + let log_base2k: usize = self.log_base2k(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.log_base2k(), log_base2k); + assert_eq!(self.n(), module.n()); + assert_eq!(res.n(), module.n()); + } + + let (a_data, scratch_1) = scratch.tmp_vec_znx(module, 2, a.size()); + + let mut a_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { + data: a_data, + log_base2k: a.log_base2k(), + log_k: a.log_k(), + }; + + a.idft(module, &mut a_idft, scratch_1); + + let (res_data, scratch_2) = scratch_1.tmp_vec_znx(module, 2, res.size()); + + let mut res_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { + data: res_data, + log_base2k: res.log_base2k(), + log_k: res.log_k(), + }; + + self.mul_rlwe(module, &mut res_idft, &a_idft, scratch_2); + + module.vec_znx_dft(res, 0, &res_idft, 0); + module.vec_znx_dft(res, 1, &res_idft, 1); + } + + fn mul_rlwe_dft_inplace(&self, module: &Module, res: &mut RLWECtDft, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToRef + VecZnxDftToMut, + { + let log_base2k: usize = self.log_base2k(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.log_base2k(), log_base2k); + assert_eq!(self.n(), module.n()); + assert_eq!(res.n(), module.n()); + } + + let (res_data, scratch_1) = scratch.tmp_vec_znx(module, 2, res.size()); + + let mut res_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { + data: res_data, + log_base2k: res.log_base2k(), + log_k: res.log_k(), + }; + + res.idft(module, &mut res_idft, scratch_1); + + self.mul_rlwe_inplace(module, &mut res_idft, scratch_1); + + module.vec_znx_dft(res, 0, &res_idft, 0); + module.vec_znx_dft(res, 1, &res_idft, 1); + } + + fn mul_grlwe(&self, module: &Module, res: &mut GRLWECt, a: &GRLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, + MatZnxDft: MatZnxDftToRef + ZnxInfos, + { + let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, a.size()); + + let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { + data: tmp_row_data, + log_base2k: a.log_base2k(), + log_k: a.log_k(), + }; + + let min_rows: usize = res.rows().min(a.rows()); + + (0..min_rows).for_each(|row_i| { + a.get_row(module, row_i, &mut tmp_row); + self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); + res.set_row(module, row_i, &tmp_row); + }); + + tmp_row.data.zero(); + + (min_rows..res.rows()).for_each(|row_i| { + res.set_row(module, row_i, &tmp_row); + }) + } + + fn mul_grlwe_inplace(&self, module: &Module, res: &mut R, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + R: GetRow + SetRow + Infos, + { + let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, res.size()); + + let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { + data: tmp_row_data, + log_base2k: res.log_base2k(), + log_k: res.log_k(), + }; + + (0..self.cols()).for_each(|col_j| { + (0..res.rows()).for_each(|row_i| { + res.get_row(module, row_i, col_j, &mut tmp_row); + self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); + res.set_row(module, row_i, col_j, &tmp_row); + }); + }) + } +} diff --git a/rlwe/src/elem_grlwe.rs b/rlwe/src/elem_grlwe.rs index b865c1e..0567c07 100644 --- a/rlwe/src/elem_grlwe.rs +++ b/rlwe/src/elem_grlwe.rs @@ -7,7 +7,7 @@ use base2k::{ use sampling::source::Source; use crate::{ - elem::Infos, + elem::{GetRow, Infos, MatZnxDftProducts, SetRow}, elem_rlwe::{RLWECt, RLWECtDft, RLWEPt}, keys::SecretKeyDft, utils::derive_size, @@ -211,8 +211,106 @@ impl GRLWECt { } pub fn mul_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef, + VecZnx: VecZnxToMut, + VecZnx: VecZnxToRef, + { + MatZnxDftProducts::mul_rlwe(self, module, res, a, scratch); + } + + pub fn mul_rlwe_inplace(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) where MatZnxDft: MatZnxDftToRef + ZnxInfos, + VecZnx: VecZnxToMut + VecZnxToRef, + { + MatZnxDftProducts::mul_rlwe_inplace(self, module, res, scratch); + } + + pub fn mul_rlwe_dft( + &self, + module: &Module, + res: &mut RLWECtDft, + a: &RLWECtDft, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToRef + ZnxInfos, + { + MatZnxDftProducts::mul_rlwe_dft(self, module, res, a, scratch); + } + + pub fn mul_rlwe_dft_inplace(&self, module: &Module, res: &mut RLWECtDft, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToRef + VecZnxDftToMut, + { + MatZnxDftProducts::mul_rlwe_dft_inplace(self, module, res, scratch); + } + + pub fn mul_grlwe( + &self, + module: &Module, + res: &mut GRLWECt, + a: &GRLWECt, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, + MatZnxDft: MatZnxDftToRef + ZnxInfos, + { + MatZnxDftProducts::mul_grlwe(self, module, res, a, scratch); + } + + pub fn mul_grlwe_inplace(&self, module: &Module, res: &mut R, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + R: GetRow + SetRow + Infos, + { + MatZnxDftProducts::mul_grlwe_inplace(self, module, res, scratch); + } +} + +impl GetRow for GRLWECt +where + MatZnxDft: MatZnxDftToRef, +{ + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut RLWECtDft) + where + VecZnxDft: VecZnxDftToMut, + { + #[cfg(debug_assertions)] + { + assert_eq!(col_j, 0); + } + module.vmp_extract_row(res, self, row_i, col_j); + } +} + +impl SetRow for GRLWECt +where + MatZnxDft: MatZnxDftToMut, +{ + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &RLWECtDft) + where + VecZnxDft: VecZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!(col_j, 0); + } + module.vmp_prepare_row(self, row_i, col_j, a); + } +} + +impl MatZnxDftProducts, C> for GRLWECt +where + MatZnxDft: MatZnxDftToRef + ZnxInfos, +{ + fn mul_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef, VecZnx: VecZnxToMut, VecZnx: VecZnxToRef, { @@ -242,143 +340,4 @@ impl GRLWECt { module.vec_znx_big_normalize(log_base2k, res, 0, &res_big, 0, scratch1); module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1); } - - pub fn mul_rlwe_inplace(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - VecZnx: VecZnxToMut + VecZnxToRef, - { - unsafe { - let res_ptr: *mut RLWECt = res as *mut RLWECt; // This is ok because [Self::mul_rlwe] only updates res at the end. - self.mul_rlwe(&module, &mut *res_ptr, &*res_ptr, scratch); - } - } - - pub fn mul_rlwe_dft( - &self, - module: &Module, - res: &mut RLWECtDft, - a: &RLWECtDft, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, - VecZnxDft: VecZnxDftToRef + ZnxInfos, - { - let log_base2k: usize = self.log_base2k(); - - #[cfg(debug_assertions)] - { - assert_eq!(res.log_base2k(), log_base2k); - assert_eq!(self.n(), module.n()); - assert_eq!(res.n(), module.n()); - } - - let (a_data, scratch_1) = scratch.tmp_vec_znx(module, 2, a.size()); - - let mut a_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { - data: a_data, - log_base2k: a.log_base2k(), - log_k: a.log_k(), - }; - - a.idft(module, &mut a_idft, scratch_1); - - let (res_data, scratch_2) = scratch_1.tmp_vec_znx(module, 2, res.size()); - - let mut res_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { - data: res_data, - log_base2k: res.log_base2k(), - log_k: res.log_k(), - }; - - self.mul_rlwe(module, &mut res_idft, &a_idft, scratch_2); - - module.vec_znx_dft(res, 0, &res_idft, 0); - module.vec_znx_dft(res, 1, &res_idft, 1); - } - - pub fn mul_rlwe_dft_inplace(&self, module: &Module, res: &mut RLWECtDft, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, - { - let log_base2k: usize = self.log_base2k(); - - #[cfg(debug_assertions)] - { - assert_eq!(res.log_base2k(), log_base2k); - assert_eq!(self.n(), module.n()); - assert_eq!(res.n(), module.n()); - } - - let (res_data, scratch_1) = scratch.tmp_vec_znx(module, 2, res.size()); - - let mut res_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { - data: res_data, - log_base2k: res.log_base2k(), - log_k: res.log_k(), - }; - - res.idft(module, &mut res_idft, scratch_1); - - self.mul_rlwe_inplace(module, &mut res_idft, scratch_1); - - module.vec_znx_dft(res, 0, &res_idft, 0); - module.vec_znx_dft(res, 1, &res_idft, 1); - } - - pub fn mul_grlwe( - &self, - module: &Module, - res: &mut GRLWECt, - a: &GRLWECt, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, - MatZnxDft: MatZnxDftToRef + ZnxInfos, - { - let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, a.size()); - - let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { - data: tmp_row_data, - log_base2k: a.log_base2k(), - log_k: a.log_k(), - }; - - let min_rows: usize = res.rows().min(a.rows()); - - (0..min_rows).for_each(|row_i| { - a.get_row(module, row_i, &mut tmp_row); - self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); - res.set_row(module, row_i, &tmp_row); - }); - - tmp_row.data.zero(); - - (min_rows..res.rows()).for_each(|row_i| { - res.set_row(module, row_i, &tmp_row); - }) - } - - pub fn mul_grlwe_inplace(&self, module: &Module, res: &mut GRLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, - { - let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, res.size()); - - let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { - data: tmp_row_data, - log_base2k: res.log_base2k(), - log_k: res.log_k(), - }; - - (0..res.rows()).for_each(|row_i| { - res.get_row(module, row_i, &mut tmp_row); - self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); - res.set_row(module, row_i, &tmp_row); - }); - } } diff --git a/rlwe/src/elem_rgsw.rs b/rlwe/src/elem_rgsw.rs index 539aca0..beeeef9 100644 --- a/rlwe/src/elem_rgsw.rs +++ b/rlwe/src/elem_rgsw.rs @@ -7,7 +7,7 @@ use base2k::{ use sampling::source::Source; use crate::{ - elem::Infos, + elem::{GetRow, Infos, MatZnxDftProducts, SetRow}, elem_grlwe::GRLWECt, elem_rlwe::{RLWECt, RLWECtDft, RLWEPt, encrypt_rlwe_sk}, keys::SecretKeyDft, @@ -110,30 +110,6 @@ impl RGSWCt, FFT64> { } } -impl RGSWCt -where - MatZnxDft: MatZnxDftToRef, -{ - pub fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut RLWECtDft) - where - VecZnxDft: VecZnxDftToMut, - { - module.vmp_extract_row(res, self, row_i, col_j); - } -} - -impl RGSWCt -where - MatZnxDft: MatZnxDftToMut, -{ - pub fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &RLWECtDft) - where - VecZnxDft: VecZnxDftToRef, - { - module.vmp_prepare_row(self, row_i, col_j, a); - } -} - pub fn encrypt_rgsw_sk( module: &Module, ct: &mut RGSWCt, @@ -221,6 +197,96 @@ impl RGSWCt { } pub fn mul_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef, + VecZnx: VecZnxToMut, + VecZnx: VecZnxToRef, + { + MatZnxDftProducts::mul_rlwe(self, module, res, a, scratch); + } + + pub fn mul_rlwe_inplace(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + VecZnx: VecZnxToMut + VecZnxToRef, + { + MatZnxDftProducts::mul_rlwe_inplace(self, module, res, scratch); + } + + pub fn mul_rlwe_dft( + &self, + module: &Module, + res: &mut RLWECtDft, + a: &RLWECtDft, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToRef + ZnxInfos, + { + MatZnxDftProducts::mul_rlwe_dft(self, module, res, a, scratch); + } + + pub fn mul_rlwe_dft_inplace(&self, module: &Module, res: &mut RLWECtDft, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToRef + VecZnxDftToMut, + { + MatZnxDftProducts::mul_rlwe_dft_inplace(self, module, res, scratch); + } + + pub fn mul_grlwe( + &self, + module: &Module, + res: &mut GRLWECt, + a: &GRLWECt, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, + MatZnxDft: MatZnxDftToRef + ZnxInfos, + { + MatZnxDftProducts::mul_grlwe(self, module, res, a, scratch); + } + + pub fn mul_grlwe_inplace(&self, module: &Module, res: &mut R, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + R: GetRow + SetRow + Infos, + { + MatZnxDftProducts::mul_grlwe_inplace(self, module, res, scratch); + } +} + +impl GetRow for RGSWCt +where + MatZnxDft: MatZnxDftToRef, +{ + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut RLWECtDft) + where + VecZnxDft: VecZnxDftToMut, + { + module.vmp_extract_row(res, self, row_i, col_j); + } +} + +impl SetRow for RGSWCt +where + MatZnxDft: MatZnxDftToMut, +{ + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &RLWECtDft) + where + VecZnxDft: VecZnxDftToRef, + { + module.vmp_prepare_row(self, row_i, col_j, a); + } +} + +impl MatZnxDftProducts, C> for RGSWCt +where + MatZnxDft: MatZnxDftToRef + ZnxInfos, +{ + fn mul_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) where MatZnxDft: MatZnxDftToRef, VecZnx: VecZnxToMut, @@ -251,205 +317,4 @@ impl RGSWCt { module.vec_znx_big_normalize(log_base2k, res, 0, &res_big, 0, scratch1); module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1); } - - pub fn mul_rlwe_inplace(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - VecZnx: VecZnxToMut + VecZnxToRef, - { - unsafe { - let res_ptr: *mut RLWECt = res as *mut RLWECt; // This is ok because [Self::mul_rlwe] only updates res at the end. - self.mul_rlwe(&module, &mut *res_ptr, &*res_ptr, scratch); - } - } - - pub fn mul_rlwe_dft( - &self, - module: &Module, - res: &mut RLWECtDft, - a: &RLWECtDft, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, - VecZnxDft: VecZnxDftToRef + ZnxInfos, - { - let log_base2k: usize = self.log_base2k(); - - #[cfg(debug_assertions)] - { - assert_eq!(res.log_base2k(), log_base2k); - assert_eq!(self.n(), module.n()); - assert_eq!(res.n(), module.n()); - } - - let (a_data, scratch_1) = scratch.tmp_vec_znx(module, 2, a.size()); - - let mut a_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { - data: a_data, - log_base2k: a.log_base2k(), - log_k: a.log_k(), - }; - - a.idft(module, &mut a_idft, scratch_1); - - let (res_data, scratch_2) = scratch_1.tmp_vec_znx(module, 2, res.size()); - - let mut res_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { - data: res_data, - log_base2k: res.log_base2k(), - log_k: res.log_k(), - }; - - self.mul_rlwe(module, &mut res_idft, &a_idft, scratch_2); - - module.vec_znx_dft(res, 0, &res_idft, 0); - module.vec_znx_dft(res, 1, &res_idft, 1); - } - - pub fn mul_rlwe_dft_inplace(&self, module: &Module, res: &mut RLWECtDft, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, - { - let log_base2k: usize = self.log_base2k(); - - #[cfg(debug_assertions)] - { - assert_eq!(res.log_base2k(), log_base2k); - assert_eq!(self.n(), module.n()); - assert_eq!(res.n(), module.n()); - } - - let (res_data, scratch_1) = scratch.tmp_vec_znx(module, 2, res.size()); - - let mut res_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { - data: res_data, - log_base2k: res.log_base2k(), - log_k: res.log_k(), - }; - - res.idft(module, &mut res_idft, scratch_1); - - self.mul_rlwe_inplace(module, &mut res_idft, scratch_1); - - module.vec_znx_dft(res, 0, &res_idft, 0); - module.vec_znx_dft(res, 1, &res_idft, 1); - } - - pub fn mul_grlwe( - &self, - module: &Module, - res: &mut GRLWECt, - a: &GRLWECt, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, - MatZnxDft: MatZnxDftToRef + ZnxInfos, - { - let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, a.size()); - - let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { - data: tmp_row_data, - log_base2k: a.log_base2k(), - log_k: a.log_k(), - }; - - let min_rows: usize = res.rows().min(a.rows()); - - (0..min_rows).for_each(|row_i| { - a.get_row(module, row_i, &mut tmp_row); - self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); - res.set_row(module, row_i, &tmp_row); - }); - - tmp_row.data.zero(); - - (min_rows..res.rows()).for_each(|row_i| { - res.set_row(module, row_i, &tmp_row); - }) - } - - pub fn mul_grlwe_inplace(&self, module: &Module, res: &mut GRLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, - { - let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, res.size()); - - let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { - data: tmp_row_data, - log_base2k: res.log_base2k(), - log_k: res.log_k(), - }; - - (0..res.rows()).for_each(|row_i| { - res.get_row(module, row_i, &mut tmp_row); - self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); - res.set_row(module, row_i, &tmp_row); - }); - } - - pub fn mul_rgsw(&self, module: &Module, res: &mut RGSWCt, a: &RGSWCt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, - MatZnxDft: MatZnxDftToRef + ZnxInfos, - { - let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, a.size()); - - let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { - data: tmp_row_data, - log_base2k: a.log_base2k(), - log_k: a.log_k(), - }; - - let min_rows: usize = res.rows().min(a.rows()); - - (0..min_rows).for_each(|row_i| { - a.get_row(module, row_i, 0, &mut tmp_row); - self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); - res.set_row(module, row_i, 0, &tmp_row); - }); - - (0..min_rows).for_each(|row_i| { - a.get_row(module, row_i, 1, &mut tmp_row); - self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); - res.set_row(module, row_i, 1, &tmp_row); - }); - - tmp_row.data.zero(); - - (min_rows..res.rows()).for_each(|row_i| { - res.set_row(module, row_i, 0, &tmp_row); - res.set_row(module, row_i, 1, &tmp_row); - }) - } - - pub fn mul_rgsw_inplace(&self, module: &Module, res: &mut RGSWCt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, - { - let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, res.size()); - - let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { - data: tmp_row_data, - log_base2k: res.log_base2k(), - log_k: res.log_k(), - }; - - (0..res.rows()).for_each(|row_i| { - res.get_row(module, row_i, 0, &mut tmp_row); - self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); - res.set_row(module, row_i, 0, &tmp_row); - }); - - (0..res.rows()).for_each(|row_i| { - res.get_row(module, row_i, 1, &mut tmp_row); - self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); - res.set_row(module, row_i, 1, &tmp_row); - }); - } } diff --git a/rlwe/src/test_fft64/elem_rgsw.rs b/rlwe/src/test_fft64/elem_rgsw.rs index 9ab790f..e076237 100644 --- a/rlwe/src/test_fft64/elem_rgsw.rs +++ b/rlwe/src/test_fft64/elem_rgsw.rs @@ -7,7 +7,7 @@ mod tests { use sampling::source::Source; use crate::{ - elem::Infos, + elem::{GetRow, Infos}, elem_rgsw::RGSWCt, elem_rlwe::{RLWECt, RLWECtDft, RLWEPt}, keys::{SecretKey, SecretKeyDft}, @@ -117,9 +117,9 @@ mod tests { pt_want.to_mut().at_mut(0, 0)[1] = 1; - let r: usize = 1; + let k: usize = 1; - pt_rgsw.raw_mut()[r] = 1; // X^{r} + pt_rgsw.raw_mut()[k] = 1; // X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) @@ -165,7 +165,7 @@ mod tests { ct_rlwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - module.vec_znx_rotate_inplace(r as i64, &mut pt_want, 0); + module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);