diff --git a/rlwe/src/elem_rgsw.rs b/rlwe/src/elem_rgsw.rs index eab378a..75f468f 100644 --- a/rlwe/src/elem_rgsw.rs +++ b/rlwe/src/elem_rgsw.rs @@ -1,12 +1,14 @@ use base2k::{ - Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, - ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxOps, + Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, + ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigOps, VecZnxBigScratch, + VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero, }; use sampling::source::Source; use crate::{ elem::Infos, + elem_grlwe::GRLWECt, elem_rlwe::{RLWECt, RLWECtDft, RLWEPt, encrypt_rlwe_sk}, keys::SecretKeyDft, utils::derive_size, @@ -69,20 +71,42 @@ impl RGSWCt, FFT64> { + module.bytes_of_vec_znx(1, size) + module.bytes_of_vec_znx_dft(2, size) } + + pub fn mul_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, rgsw_size: usize) -> usize { + module.bytes_of_vec_znx_dft(2, rgsw_size) + + ((module.bytes_of_vec_znx_dft(2, a_size) + module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 2, 2, rgsw_size)) + | module.vec_znx_big_normalize_tmp_bytes()) + } + + pub fn mul_rlwe_inplace_scratch_space(module: &Module, res_size: usize, rgsw_size: usize) -> usize { + Self::mul_rlwe_scratch_space(module, res_size, res_size, rgsw_size) + } } impl RGSWCt where MatZnxDft: MatZnxDftToRef, { - pub fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut RLWECtDft) + pub fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut RLWECtDft) where - VecZnxDft: VecZnxDftToMut, + VecZnxDft: VecZnxDftToMut, { module.vmp_extract_row(res, self, row_i, col_j); } } +impl RGSWCt +where + MatZnxDft: MatZnxDftToMut, +{ + pub fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &RLWECtDft) + where + VecZnxDft: VecZnxDftToRef, + { + module.vmp_prepare_row(self, row_i, col_j, a); + } +} + pub fn encrypt_rgsw_sk( module: &Module, ct: &mut RGSWCt, @@ -168,4 +192,237 @@ impl RGSWCt { module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch, ) } + + pub fn mul_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef, + VecZnx: VecZnxToMut, + VecZnx: VecZnxToRef, + { + let log_base2k: usize = self.log_base2k(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.log_base2k(), log_base2k); + assert_eq!(a.log_base2k(), log_base2k); + assert_eq!(self.n(), module.n()); + assert_eq!(res.n(), module.n()); + assert_eq!(a.n(), module.n()); + } + + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, 2, self.size()); // Todo optimise + + { + let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, 2, a.size()); + module.vec_znx_dft(&mut a_dft, 0, a, 0); + module.vec_znx_dft(&mut a_dft, 1, a, 1); + module.vmp_apply(&mut res_dft, &a_dft, self, scratch2); + } + + let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); + + module.vec_znx_big_normalize(log_base2k, res, 0, &res_big, 0, scratch1); + module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1); + } + + pub fn mul_rlwe_inplace(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + VecZnx: VecZnxToMut + VecZnxToRef, + { + unsafe { + let res_ptr: *mut RLWECt = res as *mut RLWECt; // This is ok because [Self::mul_rlwe] only updates res at the end. + self.mul_rlwe(&module, &mut *res_ptr, &*res_ptr, scratch); + } + } + + pub fn mul_rlwe_dft( + &self, + module: &Module, + res: &mut RLWECtDft, + a: &RLWECtDft, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToRef + ZnxInfos, + { + let log_base2k: usize = self.log_base2k(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.log_base2k(), log_base2k); + assert_eq!(self.n(), module.n()); + assert_eq!(res.n(), module.n()); + } + + let (a_data, scratch_1) = scratch.tmp_vec_znx(module, 2, a.size()); + + let mut a_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { + data: a_data, + log_base2k: a.log_base2k(), + log_k: a.log_k(), + }; + + a.idft(module, &mut a_idft, scratch_1); + + let (res_data, scratch_2) = scratch_1.tmp_vec_znx(module, 2, res.size()); + + let mut res_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { + data: res_data, + log_base2k: res.log_base2k(), + log_k: res.log_k(), + }; + + self.mul_rlwe(module, &mut res_idft, &a_idft, scratch_2); + + module.vec_znx_dft(res, 0, &res_idft, 0); + module.vec_znx_dft(res, 1, &res_idft, 1); + } + + pub fn mul_rlwe_dft_inplace(&self, module: &Module, res: &mut RLWECtDft, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, + { + let log_base2k: usize = self.log_base2k(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.log_base2k(), log_base2k); + assert_eq!(self.n(), module.n()); + assert_eq!(res.n(), module.n()); + } + + let (res_data, scratch_1) = scratch.tmp_vec_znx(module, 2, res.size()); + + let mut res_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { + data: res_data, + log_base2k: res.log_base2k(), + log_k: res.log_k(), + }; + + res.idft(module, &mut res_idft, scratch_1); + + self.mul_rlwe_inplace(module, &mut res_idft, scratch_1); + + module.vec_znx_dft(res, 0, &res_idft, 0); + module.vec_znx_dft(res, 1, &res_idft, 1); + } + + pub fn mul_grlwe( + &self, + module: &Module, + res: &mut GRLWECt, + a: &GRLWECt, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, + MatZnxDft: MatZnxDftToRef + ZnxInfos, + { + let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, a.size()); + + let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { + data: tmp_row_data, + log_base2k: a.log_base2k(), + log_k: a.log_k(), + }; + + let min_rows: usize = res.rows().min(a.rows()); + + (0..min_rows).for_each(|row_i| { + a.get_row(module, row_i, &mut tmp_row); + self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); + res.set_row(module, row_i, &tmp_row); + }); + + tmp_row.data.zero(); + + (min_rows..res.rows()).for_each(|row_i| { + res.set_row(module, row_i, &tmp_row); + }) + } + + pub fn mul_grlwe_inplace(&self, module: &Module, res: &mut GRLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, + { + let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, res.size()); + + let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { + data: tmp_row_data, + log_base2k: res.log_base2k(), + log_k: res.log_k(), + }; + + (0..res.rows()).for_each(|row_i| { + res.get_row(module, row_i, &mut tmp_row); + self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); + res.set_row(module, row_i, &tmp_row); + }); + } + + pub fn mul_rgsw(&self, module: &Module, res: &mut RGSWCt, a: &RGSWCt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, + MatZnxDft: MatZnxDftToRef + ZnxInfos, + { + let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, a.size()); + + let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { + data: tmp_row_data, + log_base2k: a.log_base2k(), + log_k: a.log_k(), + }; + + let min_rows: usize = res.rows().min(a.rows()); + + (0..min_rows).for_each(|row_i| { + a.get_row(module, row_i, 0, &mut tmp_row); + self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); + res.set_row(module, row_i, 0, &tmp_row); + }); + + (0..min_rows).for_each(|row_i| { + a.get_row(module, row_i, 1, &mut tmp_row); + self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); + res.set_row(module, row_i, 1, &tmp_row); + }); + + tmp_row.data.zero(); + + (min_rows..res.rows()).for_each(|row_i| { + res.set_row(module, row_i, 0, &tmp_row); + res.set_row(module, row_i, 1, &tmp_row); + }) + } + + pub fn mul_rgsw_inplace(&self, module: &Module, res: &mut RGSWCt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, + { + let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, res.size()); + + let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { + data: tmp_row_data, + log_base2k: res.log_base2k(), + log_k: res.log_k(), + }; + + (0..res.rows()).for_each(|row_i| { + res.get_row(module, row_i, 0, &mut tmp_row); + self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); + res.set_row(module, row_i, 0, &tmp_row); + }); + + (0..res.rows()).for_each(|row_i| { + res.get_row(module, row_i, 1, &mut tmp_row); + self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); + res.set_row(module, row_i, 1, &tmp_row); + }); + } }