diff --git a/core/src/elem.rs b/core/src/elem.rs index ac245ad..192bc74 100644 --- a/core/src/elem.rs +++ b/core/src/elem.rs @@ -1,10 +1,11 @@ use base2k::{ - Backend, FFT64, MatZnxDft, MatZnxDftToMut, MatZnxDftToRef, Module, Scratch, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxDftToMut, - VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero, + Backend, FFT64, MatZnxDft, MatZnxDftToRef, Module, Scratch, VecZnx, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, + VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero, }; use crate::{ grlwe::GRLWECt, + rgsw::RGSWCt, rlwe::{RLWECt, RLWECtDft}, utils::derive_size, }; @@ -65,6 +66,36 @@ pub trait SetRow { VecZnxDft: VecZnxDftToRef; } +pub trait ProdByScratchSpace { + fn prod_by_grlwe_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize; + fn prod_by_rgsw_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize; +} + +pub trait ProdBy { + fn prod_by_grlwe(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef; + + fn prod_by_rgsw(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef; +} + +pub trait FromProdByScratchSpace { + fn from_prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize; + fn from_prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize; +} + +pub trait FromProdBy { + fn from_prod_by_grlwe(&mut self, module: &Module, lhs: &L, rhs: &GRLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef; + + fn from_prod_by_rgsw(&mut self, module: &Module, lhs: &L, rhs: &RGSWCt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef; +} + pub(crate) trait MatZnxDftProducts: Infos where MatZnxDft: MatZnxDftToRef + ZnxInfos, @@ -75,6 +106,31 @@ where VecZnx: VecZnxToMut, VecZnx: VecZnxToRef; + fn mul_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize; + + fn mul_rlwe_inplace_scratch_space(module: &Module, res_size: usize, mat_size: usize) -> usize { + Self::mul_rlwe_scratch_space(module, res_size, res_size, mat_size) + } + + fn mul_rlwe_dft_scratch_space(module: &Module, res_size: usize, a_size: usize, mat_size: usize) -> usize { + (Self::mul_rlwe_scratch_space(module, res_size, a_size, mat_size) | module.vec_znx_idft_tmp_bytes()) + + module.bytes_of_vec_znx(2, a_size) + + module.bytes_of_vec_znx(2, res_size) + } + + fn mul_rlwe_dft_inplace_scratch_space(module: &Module, res_size: usize, mat_size: usize) -> usize { + (Self::mul_rlwe_inplace_scratch_space(module, res_size, mat_size) | module.vec_znx_idft_tmp_bytes()) + + module.bytes_of_vec_znx(2, res_size) + } + + fn mul_mat_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, mat_size: usize) -> usize { + Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, mat_size) + module.bytes_of_vec_znx_dft(2, a_size) + } + + fn mul_mat_rlwe_inplace_scratch_space(module: &Module, res_size: usize, mat_size: usize) -> usize { + Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, mat_size) + module.bytes_of_vec_znx_dft(2, res_size) + } + fn mul_rlwe_inplace(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) where MatZnxDft: MatZnxDftToRef + ZnxInfos, @@ -132,7 +188,6 @@ where fn mul_rlwe_dft_inplace(&self, module: &Module, res: &mut RLWECtDft, scratch: &mut Scratch) where - MatZnxDft: MatZnxDftToRef + ZnxInfos, VecZnxDft: VecZnxDftToRef + VecZnxDftToMut, { let log_base2k: usize = self.log_base2k(); @@ -160,11 +215,10 @@ where module.vec_znx_dft(res, 1, &res_idft, 1); } - fn mul_grlwe(&self, module: &Module, res: &mut GRLWECt, a: &GRLWECt, scratch: &mut Scratch) + fn mul_mat_rlwe(&self, module: &Module, res: &mut R, a: &A, scratch: &mut Scratch) where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, - MatZnxDft: MatZnxDftToRef + ZnxInfos, + A: GetRow + Infos, + R: SetRow + Infos, { let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, a.size()); @@ -176,22 +230,25 @@ where let min_rows: usize = res.rows().min(a.rows()); - (0..min_rows).for_each(|row_i| { - a.get_row(module, row_i, &mut tmp_row); - self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); - res.set_row(module, row_i, &tmp_row); + (0..res.rows()).for_each(|row_i| { + (0..self.cols()).for_each(|col_j| { + a.get_row(module, row_i, col_j, &mut tmp_row); + self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); + res.set_row(module, row_i, col_j, &tmp_row); + }); }); tmp_row.data.zero(); (min_rows..res.rows()).for_each(|row_i| { - res.set_row(module, row_i, &tmp_row); - }) + (0..self.cols()).for_each(|col_j| { + res.set_row(module, row_i, col_j, &tmp_row); + }); + }); } - fn mul_grlwe_inplace(&self, module: &Module, res: &mut R, scratch: &mut Scratch) + fn mul_mat_rlwe_inplace(&self, module: &Module, res: &mut R, scratch: &mut Scratch) where - MatZnxDft: MatZnxDftToRef + ZnxInfos, R: GetRow + SetRow + Infos, { let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, res.size()); @@ -202,12 +259,12 @@ where log_k: res.log_k(), }; - (0..self.cols()).for_each(|col_j| { - (0..res.rows()).for_each(|row_i| { + (0..res.rows()).for_each(|row_i| { + (0..self.cols()).for_each(|col_j| { res.get_row(module, row_i, col_j, &mut tmp_row); self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); res.set_row(module, row_i, col_j, &tmp_row); }); - }) + }); } } diff --git a/core/src/grlwe.rs b/core/src/grlwe.rs index df44a70..9c8c5b8 100644 --- a/core/src/grlwe.rs +++ b/core/src/grlwe.rs @@ -7,8 +7,9 @@ use base2k::{ use sampling::source::Source; use crate::{ - elem::{GetRow, Infos, MatZnxDftProducts, SetRow}, + elem::{FromProdBy, FromProdByScratchSpace, GetRow, Infos, MatZnxDftProducts, ProdBy, ProdByScratchSpace, SetRow}, keys::SecretKeyDft, + rgsw::RGSWCt, rlwe::{RLWECt, RLWECtDft, RLWEPt}, utils::derive_size, }; @@ -41,18 +42,6 @@ where } } -impl GRLWECt -where - MatZnxDft: MatZnxDftToMut, -{ - pub fn set_row(&mut self, module: &Module, row_i: usize, a: &RLWECtDft) - where - VecZnxDft: VecZnxDftToRef, - { - module.vmp_prepare_row(self, row_i, 0, a); - } -} - impl Infos for GRLWECt { type Inner = MatZnxDft; @@ -94,36 +83,6 @@ impl GRLWECt, FFT64> { + module.bytes_of_vec_znx(1, size) + module.bytes_of_vec_znx_dft(2, size) } - - pub fn mul_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { - module.bytes_of_vec_znx_dft(2, grlwe_size) - + (module.vec_znx_big_normalize_tmp_bytes() - | (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 1, 2, grlwe_size) - + module.bytes_of_vec_znx_dft(1, a_size))) - } - - pub fn mul_rlwe_inplace_scratch_space(module: &Module, res_size: usize, grlwe_size: usize) -> usize { - Self::mul_rlwe_scratch_space(module, res_size, res_size, grlwe_size) - } - - pub fn mul_rlwe_dft_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { - (Self::mul_rlwe_scratch_space(module, res_size, a_size, grlwe_size) | module.vec_znx_idft_tmp_bytes()) - + module.bytes_of_vec_znx(2, a_size) - + module.bytes_of_vec_znx(2, res_size) - } - - pub fn mul_rlwe_dft_inplace_scratch_space(module: &Module, res_size: usize, grlwe_size: usize) -> usize { - (Self::mul_rlwe_inplace_scratch_space(module, res_size, grlwe_size) | module.vec_znx_idft_tmp_bytes()) - + module.bytes_of_vec_znx(2, res_size) - } - - pub fn mul_grlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { - Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, grlwe_size) + module.bytes_of_vec_znx_dft(2, a_size) - } - - pub fn mul_grlwe_inplace_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { - Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, grlwe_size) + module.bytes_of_vec_znx_dft(2, a_size) - } } pub fn encrypt_grlwe_sk( @@ -209,67 +168,6 @@ impl GRLWECt { module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch, ) } - - pub fn mul_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef, - VecZnx: VecZnxToMut, - VecZnx: VecZnxToRef, - { - MatZnxDftProducts::mul_rlwe(self, module, res, a, scratch); - } - - pub fn mul_rlwe_inplace(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - VecZnx: VecZnxToMut + VecZnxToRef, - { - MatZnxDftProducts::mul_rlwe_inplace(self, module, res, scratch); - } - - pub fn mul_rlwe_dft( - &self, - module: &Module, - res: &mut RLWECtDft, - a: &RLWECtDft, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, - VecZnxDft: VecZnxDftToRef + ZnxInfos, - { - MatZnxDftProducts::mul_rlwe_dft(self, module, res, a, scratch); - } - - pub fn mul_rlwe_dft_inplace(&self, module: &Module, res: &mut RLWECtDft, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - VecZnxDft: VecZnxDftToRef + VecZnxDftToMut, - { - MatZnxDftProducts::mul_rlwe_dft_inplace(self, module, res, scratch); - } - - pub fn mul_grlwe( - &self, - module: &Module, - res: &mut GRLWECt, - a: &GRLWECt, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, - MatZnxDft: MatZnxDftToRef + ZnxInfos, - { - MatZnxDftProducts::mul_grlwe(self, module, res, a, scratch); - } - - pub fn mul_grlwe_inplace(&self, module: &Module, res: &mut R, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - R: GetRow + SetRow + Infos, - { - MatZnxDftProducts::mul_grlwe_inplace(self, module, res, scratch); - } } impl GetRow for GRLWECt @@ -308,6 +206,13 @@ impl MatZnxDftProducts, C> for GRLWECt where MatZnxDft: MatZnxDftToRef + ZnxInfos, { + fn mul_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { + module.bytes_of_vec_znx_dft(2, grlwe_size) + + (module.vec_znx_big_normalize_tmp_bytes() + | (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 1, 2, grlwe_size) + + module.bytes_of_vec_znx_dft(1, a_size))) + } + fn mul_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) where MatZnxDft: MatZnxDftToRef, @@ -341,3 +246,80 @@ where module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1); } } + +impl ProdByScratchSpace for GRLWECt, FFT64> { + fn prod_by_grlwe_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_mat_rlwe_inplace_scratch_space( + module, lhs, rhs, + ) + } + + fn prod_by_rgsw_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_mat_rlwe_inplace_scratch_space( + module, lhs, rhs, + ) + } +} + +impl FromProdByScratchSpace for GRLWECt, FFT64> { + fn from_prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_mat_rlwe_scratch_space( + module, res_size, lhs, rhs, + ) + } + + fn from_prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_mat_rlwe_scratch_space( + module, res_size, lhs, rhs, + ) + } +} + +impl ProdBy> for GRLWECt +where + GRLWECt: GetRow + SetRow + Infos, +{ + fn prod_by_grlwe(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef, + { + rhs.mul_mat_rlwe_inplace(module, self, scratch); + } + + fn prod_by_rgsw(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef, + { + rhs.mul_mat_rlwe_inplace(module, self, scratch); + } +} + +impl FromProdBy, GRLWECt> for GRLWECt +where + GRLWECt: GetRow + SetRow + Infos, + GRLWECt: GetRow + Infos, +{ + fn from_prod_by_grlwe( + &mut self, + module: &Module, + lhs: &GRLWECt, + rhs: &GRLWECt, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + rhs.mul_mat_rlwe(module, self, lhs, scratch); + } + + fn from_prod_by_rgsw( + &mut self, + module: &Module, + lhs: &GRLWECt, + rhs: &RGSWCt, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + rhs.mul_mat_rlwe(module, self, lhs, scratch); + } +} diff --git a/core/src/rgsw.rs b/core/src/rgsw.rs index f271c15..c4c7c1c 100644 --- a/core/src/rgsw.rs +++ b/core/src/rgsw.rs @@ -7,7 +7,7 @@ use base2k::{ use sampling::source::Source; use crate::{ - elem::{GetRow, Infos, MatZnxDftProducts, SetRow}, + elem::{FromProdBy, FromProdByScratchSpace, GetRow, Infos, MatZnxDftProducts, ProdBy, ProdByScratchSpace, SetRow}, grlwe::GRLWECt, keys::SecretKeyDft, rlwe::{RLWECt, RLWECtDft, RLWEPt, encrypt_rlwe_sk}, @@ -71,43 +71,6 @@ impl RGSWCt, FFT64> { + module.bytes_of_vec_znx(1, size) + module.bytes_of_vec_znx_dft(2, size) } - - pub fn mul_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, rgsw_size: usize) -> usize { - module.bytes_of_vec_znx_dft(2, rgsw_size) - + ((module.bytes_of_vec_znx_dft(2, a_size) + module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 2, 2, rgsw_size)) - | module.vec_znx_big_normalize_tmp_bytes()) - } - - pub fn mul_rlwe_inplace_scratch_space(module: &Module, res_size: usize, rgsw_size: usize) -> usize { - Self::mul_rlwe_scratch_space(module, res_size, res_size, rgsw_size) - } - - pub fn mul_rlwe_dft_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { - (Self::mul_rlwe_scratch_space(module, res_size, a_size, grlwe_size) | module.vec_znx_idft_tmp_bytes()) - + module.bytes_of_vec_znx(2, a_size) - + module.bytes_of_vec_znx(2, res_size) - } - - pub fn mul_rlwe_dft_inplace_scratch_space(module: &Module, res_size: usize, grlwe_size: usize) -> usize { - (Self::mul_rlwe_inplace_scratch_space(module, res_size, grlwe_size) | module.vec_znx_idft_tmp_bytes()) - + module.bytes_of_vec_znx(2, res_size) - } - - pub fn mul_grlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { - Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, grlwe_size) + module.bytes_of_vec_znx_dft(2, a_size) - } - - pub fn mul_grlwe_inplace_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { - Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, grlwe_size) + module.bytes_of_vec_znx_dft(2, a_size) - } - - pub fn mul_rgsw_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { - Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, grlwe_size) + module.bytes_of_vec_znx_dft(2, a_size) - } - - pub fn mul_rgsw_inplace_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { - Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, grlwe_size) + module.bytes_of_vec_znx_dft(2, a_size) - } } pub fn encrypt_rgsw_sk( @@ -195,67 +158,6 @@ impl RGSWCt { module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch, ) } - - pub fn mul_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef, - VecZnx: VecZnxToMut, - VecZnx: VecZnxToRef, - { - MatZnxDftProducts::mul_rlwe(self, module, res, a, scratch); - } - - pub fn mul_rlwe_inplace(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - VecZnx: VecZnxToMut + VecZnxToRef, - { - MatZnxDftProducts::mul_rlwe_inplace(self, module, res, scratch); - } - - pub fn mul_rlwe_dft( - &self, - module: &Module, - res: &mut RLWECtDft, - a: &RLWECtDft, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, - VecZnxDft: VecZnxDftToRef + ZnxInfos, - { - MatZnxDftProducts::mul_rlwe_dft(self, module, res, a, scratch); - } - - pub fn mul_rlwe_dft_inplace(&self, module: &Module, res: &mut RLWECtDft, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - VecZnxDft: VecZnxDftToRef + VecZnxDftToMut, - { - MatZnxDftProducts::mul_rlwe_dft_inplace(self, module, res, scratch); - } - - pub fn mul_grlwe( - &self, - module: &Module, - res: &mut GRLWECt, - a: &GRLWECt, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, - MatZnxDft: MatZnxDftToRef + ZnxInfos, - { - MatZnxDftProducts::mul_grlwe(self, module, res, a, scratch); - } - - pub fn mul_grlwe_inplace(&self, module: &Module, res: &mut R, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - R: GetRow + SetRow + Infos, - { - MatZnxDftProducts::mul_grlwe_inplace(self, module, res, scratch); - } } impl GetRow for RGSWCt @@ -286,6 +188,12 @@ impl MatZnxDftProducts, C> for RGSWCt where MatZnxDft: MatZnxDftToRef + ZnxInfos, { + fn mul_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, rgsw_size: usize) -> usize { + module.bytes_of_vec_znx_dft(2, rgsw_size) + + ((module.bytes_of_vec_znx_dft(2, a_size) + module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 2, 2, rgsw_size)) + | module.vec_znx_big_normalize_tmp_bytes()) + } + fn mul_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) where MatZnxDft: MatZnxDftToRef, @@ -318,3 +226,80 @@ where module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1); } } + +impl ProdByScratchSpace for RGSWCt, FFT64> { + fn prod_by_grlwe_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_mat_rlwe_inplace_scratch_space( + module, lhs, rhs, + ) + } + + fn prod_by_rgsw_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_mat_rlwe_inplace_scratch_space( + module, lhs, rhs, + ) + } +} + +impl FromProdByScratchSpace for RGSWCt, FFT64> { + fn from_prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_mat_rlwe_scratch_space( + module, res_size, lhs, rhs, + ) + } + + fn from_prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_mat_rlwe_scratch_space( + module, res_size, lhs, rhs, + ) + } +} + +impl ProdBy> for RGSWCt +where + RGSWCt: GetRow + SetRow + Infos, +{ + fn prod_by_grlwe(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef, + { + rhs.mul_mat_rlwe_inplace(module, self, scratch); + } + + fn prod_by_rgsw(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef, + { + rhs.mul_mat_rlwe_inplace(module, self, scratch); + } +} + +impl FromProdBy, RGSWCt> for RGSWCt +where + RGSWCt: GetRow + SetRow + Infos, + RGSWCt: GetRow + Infos, +{ + fn from_prod_by_grlwe( + &mut self, + module: &Module, + lhs: &RGSWCt, + rhs: &GRLWECt, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + rhs.mul_mat_rlwe(module, self, lhs, scratch); + } + + fn from_prod_by_rgsw( + &mut self, + module: &Module, + lhs: &RGSWCt, + rhs: &RGSWCt, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + rhs.mul_mat_rlwe(module, self, lhs, scratch); + } +} diff --git a/core/src/rlwe.rs b/core/src/rlwe.rs index b52d56d..ef1be64 100644 --- a/core/src/rlwe.rs +++ b/core/src/rlwe.rs @@ -6,9 +6,10 @@ use base2k::{ use sampling::source::Source; use crate::{ - elem::Infos, + elem::{FromProdBy, FromProdByScratchSpace, Infos, MatZnxDftProducts, ProdBy, ProdByScratchSpace}, grlwe::GRLWECt, keys::{PublicKey, SecretDistribution, SecretKeyDft}, + rgsw::RGSWCt, utils::derive_size, }; @@ -83,134 +84,70 @@ where } } -pub struct RLWEPt { - pub data: VecZnx, - pub log_base2k: usize, - pub log_k: usize, -} - -impl Infos for RLWEPt { - type Inner = VecZnx; - - fn inner(&self) -> &Self::Inner { - &self.data +impl ProdByScratchSpace for RLWECt> { + fn prod_by_grlwe_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_rlwe_inplace_scratch_space( + module, lhs, rhs, + ) } - fn log_base2k(&self) -> usize { - self.log_base2k - } - - fn log_k(&self) -> usize { - self.log_k + fn prod_by_rgsw_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_rlwe_inplace_scratch_space( + module, lhs, rhs, + ) } } -impl VecZnxToMut for RLWEPt +impl FromProdByScratchSpace for RLWECt> { + fn from_prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_rlwe_scratch_space( + module, res_size, lhs, rhs, + ) + } + + fn from_prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_rlwe_scratch_space( + module, res_size, lhs, rhs, + ) + } +} + +impl ProdBy> for RLWECt where - VecZnx: VecZnxToMut, + VecZnx: VecZnxToMut + VecZnxToRef, { - fn to_mut(&mut self) -> VecZnx<&mut [u8]> { - self.data.to_mut() - } -} - -impl VecZnxToRef for RLWEPt -where - VecZnx: VecZnxToRef, -{ - fn to_ref(&self) -> VecZnx<&[u8]> { - self.data.to_ref() - } -} - -impl RLWEPt> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { - Self { - data: module.new_vec_znx(1, derive_size(log_base2k, log_k)), - log_base2k: log_base2k, - log_k: log_k, - } - } -} - -pub struct RLWECtDft { - pub data: VecZnxDft, - pub log_base2k: usize, - pub log_k: usize, -} - -impl RLWECtDft, B> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { - Self { - data: module.new_vec_znx_dft(2, derive_size(log_base2k, log_k)), - log_base2k: log_base2k, - log_k: log_k, - } - } -} - -impl Infos for RLWECtDft { - type Inner = VecZnxDft; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn log_base2k(&self) -> usize { - self.log_base2k - } - - fn log_k(&self) -> usize { - self.log_k - } -} - -impl VecZnxDftToMut for RLWECtDft -where - VecZnxDft: VecZnxDftToMut, -{ - fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { - self.data.to_mut() - } -} - -impl VecZnxDftToRef for RLWECtDft -where - VecZnxDft: VecZnxDftToRef, -{ - fn to_ref(&self) -> VecZnxDft<&[u8], B> { - self.data.to_ref() - } -} - -impl RLWECtDft -where - VecZnxDft: VecZnxDftToRef, -{ - #[allow(dead_code)] - pub(crate) fn idft_scratch_space(module: &Module, size: usize) -> usize { - module.bytes_of_vec_znx(2, size) + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()) - } - - pub(crate) fn idft(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) + fn prod_by_grlwe(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) where - VecZnx: VecZnxToMut, + MatZnxDft: MatZnxDftToRef, { - #[cfg(debug_assertions)] - { - assert_eq!(self.cols(), 2); - assert_eq!(res.cols(), 2); - assert_eq!(self.log_base2k(), res.log_base2k()) - } + rhs.mul_rlwe_inplace(module, self, scratch); + } - let min_size: usize = self.size().min(res.size()); + fn prod_by_rgsw(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef, + { + rhs.mul_rlwe_inplace(module, self, scratch); + } +} - let (mut res_big, scratch1) = scratch.tmp_vec_znx_big(module, 2, min_size); +impl FromProdBy, RLWECt> for RLWECt +where + VecZnx: VecZnxToMut + VecZnxToRef, + VecZnx: VecZnxToRef, +{ + fn from_prod_by_grlwe(&mut self, module: &Module, lhs: &RLWECt, rhs: &GRLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef, + { + rhs.mul_rlwe(module, self, lhs, scratch); + } - module.vec_znx_idft(&mut res_big, 0, &self.data, 0, scratch1); - module.vec_znx_idft(&mut res_big, 1, &self.data, 1, scratch1); - module.vec_znx_big_normalize(self.log_base2k(), res, 0, &res_big, 0, scratch1); - module.vec_znx_big_normalize(self.log_base2k(), res, 1, &res_big, 1, scratch1); + fn from_prod_by_rgsw(&mut self, module: &Module, lhs: &RLWECt, rhs: &RGSWCt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef, + { + rhs.mul_rlwe(module, self, lhs, scratch); } } @@ -390,6 +327,204 @@ impl RLWECt { } } +pub(crate) fn encrypt_rlwe_pk( + module: &Module, + ct: &mut RLWECt, + pt: Option<&RLWEPt

>, + pk: &PublicKey, + source_xu: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, +) where + VecZnx: VecZnxToMut + VecZnxToRef, + VecZnx

: VecZnxToRef, + VecZnxDft: VecZnxDftToRef, +{ + #[cfg(debug_assertions)] + { + assert_eq!(ct.log_base2k(), pk.log_base2k()); + assert_eq!(ct.n(), module.n()); + assert_eq!(pk.n(), module.n()); + if let Some(pt) = pt { + assert_eq!(pt.log_base2k(), pk.log_base2k()); + assert_eq!(pt.n(), module.n()); + } + } + + let log_base2k: usize = pk.log_base2k(); + let size_pk: usize = pk.size(); + + // Generates u according to the underlying secret distribution. + let (mut u_dft, scratch_1) = scratch.tmp_scalar_znx_dft(module, 1); + + { + let (mut u, _) = scratch_1.tmp_scalar_znx(module, 1); + match pk.dist { + SecretDistribution::NONE => panic!( + "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through Self::generate" + ), + SecretDistribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu), + SecretDistribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu), + SecretDistribution::ZERO => {} + } + + module.svp_prepare(&mut u_dft, 0, &u, 0); + } + + let (mut tmp_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, size_pk); // TODO optimize size (e.g. when encrypting at low homomorphic capacity) + let (mut tmp_dft, scratch_3) = scratch_2.tmp_vec_znx_dft(module, 1, size_pk); // TODO optimize size (e.g. when encrypting at low homomorphic capacity) + + // ct[0] = pk[0] * u + m + e0 + module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 0); + module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0); + tmp_big.add_normal(log_base2k, 0, pk.log_k(), source_xe, sigma, bound); + + if let Some(pt) = pt { + module.vec_znx_big_add_small_inplace(&mut tmp_big, 0, pt, 0); + } + + module.vec_znx_big_normalize(log_base2k, ct, 0, &tmp_big, 0, scratch_3); + + // ct[1] = pk[1] * u + e1 + module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 1); + module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0); + tmp_big.add_normal(log_base2k, 0, pk.log_k(), source_xe, sigma, bound); + module.vec_znx_big_normalize(log_base2k, ct, 1, &tmp_big, 0, scratch_3); +} + +pub struct RLWEPt { + pub data: VecZnx, + pub log_base2k: usize, + pub log_k: usize, +} + +impl Infos for RLWEPt { + type Inner = VecZnx; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn log_base2k(&self) -> usize { + self.log_base2k + } + + fn log_k(&self) -> usize { + self.log_k + } +} + +impl VecZnxToMut for RLWEPt +where + VecZnx: VecZnxToMut, +{ + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + self.data.to_mut() + } +} + +impl VecZnxToRef for RLWEPt +where + VecZnx: VecZnxToRef, +{ + fn to_ref(&self) -> VecZnx<&[u8]> { + self.data.to_ref() + } +} + +impl RLWEPt> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { + Self { + data: module.new_vec_znx(1, derive_size(log_base2k, log_k)), + log_base2k: log_base2k, + log_k: log_k, + } + } +} + +pub struct RLWECtDft { + pub data: VecZnxDft, + pub log_base2k: usize, + pub log_k: usize, +} + +impl RLWECtDft, B> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { + Self { + data: module.new_vec_znx_dft(2, derive_size(log_base2k, log_k)), + log_base2k: log_base2k, + log_k: log_k, + } + } +} + +impl Infos for RLWECtDft { + type Inner = VecZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn log_base2k(&self) -> usize { + self.log_base2k + } + + fn log_k(&self) -> usize { + self.log_k + } +} + +impl VecZnxDftToMut for RLWECtDft +where + VecZnxDft: VecZnxDftToMut, +{ + fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { + self.data.to_mut() + } +} + +impl VecZnxDftToRef for RLWECtDft +where + VecZnxDft: VecZnxDftToRef, +{ + fn to_ref(&self) -> VecZnxDft<&[u8], B> { + self.data.to_ref() + } +} + +impl RLWECtDft +where + VecZnxDft: VecZnxDftToRef, +{ + #[allow(dead_code)] + pub(crate) fn idft_scratch_space(module: &Module, size: usize) -> usize { + module.bytes_of_vec_znx(2, size) + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()) + } + + pub(crate) fn idft(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) + where + VecZnx: VecZnxToMut, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.cols(), 2); + assert_eq!(res.cols(), 2); + assert_eq!(self.log_base2k(), res.log_base2k()) + } + + let min_size: usize = self.size().min(res.size()); + + let (mut res_big, scratch1) = scratch.tmp_vec_znx_big(module, 2, min_size); + + module.vec_znx_idft(&mut res_big, 0, &self.data, 0, scratch1); + module.vec_znx_idft(&mut res_big, 1, &self.data, 1, scratch1); + module.vec_znx_big_normalize(self.log_base2k(), res, 0, &res_big, 0, scratch1); + module.vec_znx_big_normalize(self.log_base2k(), res, 1, &res_big, 1, scratch1); + } +} + pub(crate) fn encrypt_zero_rlwe_dft_sk( module: &Module, ct: &mut RLWECtDft, @@ -528,79 +663,81 @@ impl RLWECtDft { { decrypt_rlwe_dft(module, pt, self, sk_dft, scratch); } +} - pub fn mul_grlwe_assign(&mut self, module: &Module, a: &GRLWECt, scratch: &mut Scratch) - where - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, - MatZnxDft: MatZnxDftToRef, - { - a.mul_rlwe_dft_inplace(module, self, scratch); +impl ProdByScratchSpace for RLWECtDft, FFT64> { + fn prod_by_grlwe_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_rlwe_dft_inplace_scratch_space( + module, lhs, rhs, + ) + } + + fn prod_by_rgsw_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_rlwe_dft_inplace_scratch_space( + module, lhs, rhs, + ) } } -pub(crate) fn encrypt_rlwe_pk( - module: &Module, - ct: &mut RLWECt, - pt: Option<&RLWEPt

>, - pk: &PublicKey, - source_xu: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, -) where - VecZnx: VecZnxToMut + VecZnxToRef, - VecZnx

: VecZnxToRef, - VecZnxDft: VecZnxDftToRef, +impl FromProdByScratchSpace for RLWECtDft, FFT64> { + fn from_prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_rlwe_dft_scratch_space( + module, res_size, lhs, rhs, + ) + } + + fn from_prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_rlwe_dft_scratch_space( + module, res_size, lhs, rhs, + ) + } +} + +impl ProdBy> for RLWECtDft +where + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, { - #[cfg(debug_assertions)] + fn prod_by_grlwe(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef, { - assert_eq!(ct.log_base2k(), pk.log_base2k()); - assert_eq!(ct.n(), module.n()); - assert_eq!(pk.n(), module.n()); - if let Some(pt) = pt { - assert_eq!(pt.log_base2k(), pk.log_base2k()); - assert_eq!(pt.n(), module.n()); - } + rhs.mul_rlwe_dft_inplace(module, self, scratch); } - let log_base2k: usize = pk.log_base2k(); - let size_pk: usize = pk.size(); - - // Generates u according to the underlying secret distribution. - let (mut u_dft, scratch_1) = scratch.tmp_scalar_znx_dft(module, 1); - + fn prod_by_rgsw(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef, { - let (mut u, _) = scratch_1.tmp_scalar_znx(module, 1); - match pk.dist { - SecretDistribution::NONE => panic!( - "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through Self::generate" - ), - SecretDistribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu), - SecretDistribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu), - SecretDistribution::ZERO => {} - } - - module.svp_prepare(&mut u_dft, 0, &u, 0); + rhs.mul_rlwe_dft_inplace(module, self, scratch); + } +} + +impl FromProdBy, RLWECtDft> for RLWECtDft +where + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, + VecZnxDft: VecZnxDftToRef, +{ + fn from_prod_by_grlwe( + &mut self, + module: &Module, + lhs: &RLWECtDft, + rhs: &GRLWECt, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + rhs.mul_rlwe_dft(module, self, lhs, scratch); + } + + fn from_prod_by_rgsw( + &mut self, + module: &Module, + lhs: &RLWECtDft, + rhs: &RGSWCt, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + rhs.mul_rlwe_dft(module, self, lhs, scratch); } - - let (mut tmp_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, size_pk); // TODO optimize size (e.g. when encrypting at low homomorphic capacity) - let (mut tmp_dft, scratch_3) = scratch_2.tmp_vec_znx_dft(module, 1, size_pk); // TODO optimize size (e.g. when encrypting at low homomorphic capacity) - - // ct[0] = pk[0] * u + m + e0 - module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 0); - module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0); - tmp_big.add_normal(log_base2k, 0, pk.log_k(), source_xe, sigma, bound); - - if let Some(pt) = pt { - module.vec_znx_big_add_small_inplace(&mut tmp_big, 0, pt, 0); - } - - module.vec_znx_big_normalize(log_base2k, ct, 0, &tmp_big, 0, scratch_3); - - // ct[1] = pk[1] * u + e1 - module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 1); - module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0); - tmp_big.add_normal(log_base2k, 0, pk.log_k(), source_xe, sigma, bound); - module.vec_znx_big_normalize(log_base2k, ct, 1, &tmp_big, 0, scratch_3); } diff --git a/core/src/test_fft64/grlwe.rs b/core/src/test_fft64/grlwe.rs index 86c13ec..294411b 100644 --- a/core/src/test_fft64/grlwe.rs +++ b/core/src/test_fft64/grlwe.rs @@ -1,14 +1,14 @@ #[cfg(test)] mod tests { - use base2k::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps}; + use base2k::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps}; use sampling::source::Source; use crate::{ - elem::Infos, + elem::{FromProdBy, FromProdByScratchSpace, Infos, ProdBy, ProdByScratchSpace}, grlwe::GRLWECt, keys::{SecretKey, SecretKeyDft}, - rlwe::{RLWECt, RLWECtDft, RLWEPt}, + rlwe::{RLWECtDft, RLWEPt}, test_fft64::grlwe::noise_grlwe_rlwe_product, }; @@ -67,413 +67,7 @@ mod tests { } #[test] - fn mul_rlwe() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe_in: usize = 45; - let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) - | GRLWECt::mul_rlwe_scratch_space( - &module, - ct_rlwe_out.size(), - ct_rlwe_in.size(), - ct_grlwe.size(), - ), - ); - - let mut sk0: SecretKey> = SecretKey::new(&module); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk0_dft.dft(&module, &sk0); - - let mut sk1: SecretKey> = SecretKey::new(&module); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk1_dft.dft(&module, &sk1); - - ct_grlwe.encrypt_sk( - &module, - &sk0.data, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe_in.encrypt_sk( - &module, - Some(&pt_want), - &sk0_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_grlwe.mul_rlwe(&module, &mut ct_rlwe_out, &ct_rlwe_in, scratch.borrow()); - - ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - log_k_rlwe_in, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - - module.free(); - } - - #[test] - fn mul_rlwe_inplace() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe: usize = 45; - let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) - | GRLWECt::mul_rlwe_scratch_space(&module, ct_rlwe.size(), ct_rlwe.size(), ct_grlwe.size()), - ); - - let mut sk0: SecretKey> = SecretKey::new(&module); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk0_dft.dft(&module, &sk0); - - let mut sk1: SecretKey> = SecretKey::new(&module); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk1_dft.dft(&module, &sk1); - - ct_grlwe.encrypt_sk( - &module, - &sk0.data, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe.encrypt_sk( - &module, - Some(&pt_want), - &sk0_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_grlwe.mul_rlwe_inplace(&module, &mut ct_rlwe, scratch.borrow()); - - ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - log_k_rlwe, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - - module.free(); - } - - #[test] - fn mul_rlwe_dft() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe_in: usize = 45; - let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_in_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); - let mut ct_rlwe_out_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_out); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) - | GRLWECt::mul_rlwe_scratch_space( - &module, - ct_rlwe_out.size(), - ct_rlwe_in.size(), - ct_grlwe.size(), - ), - ); - - let mut sk0: SecretKey> = SecretKey::new(&module); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk0_dft.dft(&module, &sk0); - - let mut sk1: SecretKey> = SecretKey::new(&module); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk1_dft.dft(&module, &sk1); - - ct_grlwe.encrypt_sk( - &module, - &sk0.data, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe_in.encrypt_sk( - &module, - Some(&pt_want), - &sk0_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe_in.dft(&module, &mut ct_rlwe_in_dft); - ct_grlwe.mul_rlwe_dft( - &module, - &mut ct_rlwe_out_dft, - &ct_rlwe_in_dft, - scratch.borrow(), - ); - ct_rlwe_out_dft.idft(&module, &mut ct_rlwe_out, scratch.borrow()); - - ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - log_k_rlwe_in, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - - module.free(); - } - - #[test] - fn mul_rlwe_dft_inplace() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe: usize = 45; - let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe); - let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) - | GRLWECt::mul_rlwe_scratch_space(&module, ct_rlwe.size(), ct_rlwe.size(), ct_grlwe.size()), - ); - - let mut sk0: SecretKey> = SecretKey::new(&module); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk0_dft.dft(&module, &sk0); - - let mut sk1: SecretKey> = SecretKey::new(&module); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk1_dft.dft(&module, &sk1); - - ct_grlwe.encrypt_sk( - &module, - &sk0.data, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe.encrypt_sk( - &module, - Some(&pt_want), - &sk0_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe.dft(&module, &mut ct_rlwe_dft); - ct_grlwe.mul_rlwe_dft_inplace(&module, &mut ct_rlwe_dft, scratch.borrow()); - ct_rlwe_dft.idft(&module, &mut ct_rlwe, scratch.borrow()); - - ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - log_k_rlwe, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - - module.free(); - } - - #[test] - fn mul_grlwe() { + fn from_prod_by_grlwe() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; @@ -493,7 +87,7 @@ mod tests { let mut scratch: ScratchOwned = ScratchOwned::new( GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe_s0s1.size()) | RLWECtDft::decrypt_scratch_space(&module, ct_grlwe_s0s2.size()) - | GRLWECt::mul_grlwe_scratch_space( + | GRLWECt::from_prod_by_grlwe_scratch_space( &module, ct_grlwe_s0s2.size(), ct_grlwe_s0s1.size(), @@ -544,12 +138,7 @@ mod tests { ); // GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) - ct_grlwe_s1s2.mul_grlwe( - &module, - &mut ct_grlwe_s0s2, - &ct_grlwe_s0s1, - scratch.borrow(), - ); + ct_grlwe_s0s2.from_prod_by_grlwe(&module, &ct_grlwe_s0s1, &ct_grlwe_s1s2, scratch.borrow()); let mut ct_rlwe_dft_s0s2: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe); let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_grlwe); @@ -584,7 +173,7 @@ mod tests { } #[test] - fn mul_grlwe_inplace() { + fn prod_by_grlwe() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; @@ -603,12 +192,7 @@ mod tests { let mut scratch: ScratchOwned = ScratchOwned::new( GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe_s0s1.size()) | RLWECtDft::decrypt_scratch_space(&module, ct_grlwe_s0s1.size()) - | GRLWECt::mul_grlwe_scratch_space( - &module, - ct_grlwe_s0s1.size(), - ct_grlwe_s0s1.size(), - ct_grlwe_s1s2.size(), - ), + | GRLWECt::prod_by_grlwe_scratch_space(&module, ct_grlwe_s0s1.size(), ct_grlwe_s1s2.size()), ); let mut sk0: SecretKey> = SecretKey::new(&module); @@ -654,7 +238,7 @@ mod tests { ); // GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) - ct_grlwe_s1s2.mul_grlwe_inplace(&module, &mut ct_grlwe_s0s1, scratch.borrow()); + ct_grlwe_s0s1.prod_by_grlwe(&module, &ct_grlwe_s1s2, scratch.borrow()); let ct_grlwe_s0s2: GRLWECt, FFT64> = ct_grlwe_s0s1; diff --git a/core/src/test_fft64/mod.rs b/core/src/test_fft64/mod.rs index 36d380c..59e2895 100644 --- a/core/src/test_fft64/mod.rs +++ b/core/src/test_fft64/mod.rs @@ -1,3 +1,4 @@ mod grlwe; mod rgsw; mod rlwe; +mod rlwe_dft; diff --git a/core/src/test_fft64/rgsw.rs b/core/src/test_fft64/rgsw.rs index 651f6b1..83df85b 100644 --- a/core/src/test_fft64/rgsw.rs +++ b/core/src/test_fft64/rgsw.rs @@ -2,7 +2,7 @@ mod tests { use base2k::{ FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, - VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, ZnxViewMut, ZnxZero, + VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxZero, }; use sampling::source::Source; @@ -86,120 +86,6 @@ mod tests { module.free(); } - - #[test] - fn mul_rlwe() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe_in: usize = 45; - let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - // pt_want - // .data - // .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - pt_want.to_mut().at_mut(0, 0)[1] = 1; - - let k: usize = 1; - - pt_rgsw.raw_mut()[k] = 1; // X^{k} - - let mut scratch: ScratchOwned = ScratchOwned::new( - RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) - | RGSWCt::mul_rlwe_scratch_space( - &module, - ct_rlwe_out.size(), - ct_rlwe_in.size(), - ct_rgsw.size(), - ), - ); - - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); - - ct_rgsw.encrypt_sk( - &module, - &pt_rgsw, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe_in.encrypt_sk( - &module, - Some(&pt_want), - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rgsw.mul_rlwe(&module, &mut ct_rlwe_out, &ct_rlwe_in, scratch.borrow()); - - ct_rlwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_rgsw_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - log_k_rlwe_in, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - - module.free(); - } } #[allow(dead_code)] diff --git a/core/src/test_fft64/rlwe.rs b/core/src/test_fft64/rlwe.rs index e735aa6..acc10a1 100644 --- a/core/src/test_fft64/rlwe.rs +++ b/core/src/test_fft64/rlwe.rs @@ -1,13 +1,19 @@ #[cfg(test)] -mod tests { - use base2k::{Decoding, Encoding, FFT64, Module, ScratchOwned, Stats, VecZnxOps, ZnxZero}; +mod tests_rlwe { + use base2k::{ + Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, + ZnxViewMut, ZnxZero, + }; use itertools::izip; use sampling::source::Source; use crate::{ - elem::Infos, + elem::{FromProdBy, FromProdByScratchSpace, Infos, ProdBy, ProdByScratchSpace}, + grlwe::GRLWECt, keys::{PublicKey, SecretKey, SecretKeyDft}, + rgsw::RGSWCt, rlwe::{RLWECt, RLWECtDft, RLWEPt}, + test_fft64::{grlwe::noise_grlwe_rlwe_product, rgsw::noise_rgsw_rlwe_product}, }; #[test] @@ -193,4 +199,423 @@ mod tests { module.free(); } + + #[test] + fn from_prod_by_grlwe() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe_in: usize = 45; + let log_k_rlwe_out: usize = 60; + let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | RLWECt::from_prod_by_grlwe_scratch_space( + &module, + ct_rlwe_out.size(), + ct_rlwe_in.size(), + ct_grlwe.size(), + ), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + ct_grlwe.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_in.encrypt_sk( + &module, + Some(&pt_want), + &sk0_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_out.from_prod_by_grlwe(&module, &ct_rlwe_in, &ct_grlwe, scratch.borrow()); + + ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_rlwe_in, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); + } + + #[test] + fn prod_grlwe() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe: usize = 45; + let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | RLWECt::prod_by_grlwe_scratch_space(&module, ct_rlwe.size(), ct_grlwe.size()), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + ct_grlwe.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.encrypt_sk( + &module, + Some(&pt_want), + &sk0_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.prod_by_grlwe(&module, &ct_grlwe, scratch.borrow()); + + ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_rlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); + } + + #[test] + fn from_prod_by_rgsw() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe_in: usize = 45; + let log_k_rlwe_out: usize = 60; + let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + pt_want.to_mut().at_mut(0, 0)[1] = 1; + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | RLWECt::from_prod_by_rgsw_scratch_space( + &module, + ct_rlwe_out.size(), + ct_rlwe_in.size(), + ct_rgsw.size(), + ), + ); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_in.encrypt_sk( + &module, + Some(&pt_want), + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_out.from_prod_by_rgsw(&module, &ct_rlwe_in, &ct_rgsw, scratch.borrow()); + + ct_rlwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_rgsw_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + log_k_rlwe_in, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); + } + + #[test] + fn prod_by_rgsw() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe_in: usize = 45; + let log_k_rlwe_out: usize = 60; + let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + pt_want.to_mut().at_mut(0, 0)[1] = 1; + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | RLWECt::prod_by_rgsw_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), + ); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.encrypt_sk( + &module, + Some(&pt_want), + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.prod_by_rgsw(&module, &ct_rgsw, scratch.borrow()); + + ct_rlwe.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_rgsw_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + log_k_rlwe_in, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); + } } diff --git a/core/src/test_fft64/rlwe_dft.rs b/core/src/test_fft64/rlwe_dft.rs new file mode 100644 index 0000000..fe0038d --- /dev/null +++ b/core/src/test_fft64/rlwe_dft.rs @@ -0,0 +1,216 @@ +#[cfg(test)] +mod tests { + use crate::{ + elem::{FromProdBy, FromProdByScratchSpace, Infos, ProdBy, ProdByScratchSpace}, + grlwe::GRLWECt, + keys::{SecretKey, SecretKeyDft}, + rlwe::{RLWECt, RLWECtDft, RLWEPt}, + test_fft64::grlwe::noise_grlwe_rlwe_product, + }; + use base2k::{FFT64, FillUniform, Module, ScratchOwned, Stats, VecZnxOps}; + use sampling::source::Source; + + #[test] + fn from_prod_by_grlwe() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe_in: usize = 45; + let log_k_rlwe_out: usize = 60; + let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_in_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); + let mut ct_rlwe_out_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | RLWECtDft::from_prod_by_grlwe_scratch_space( + &module, + ct_rlwe_out.size(), + ct_rlwe_in.size(), + ct_grlwe.size(), + ), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + ct_grlwe.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_in.encrypt_sk( + &module, + Some(&pt_want), + &sk0_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_in.dft(&module, &mut ct_rlwe_in_dft); + ct_rlwe_out_dft.from_prod_by_grlwe(&module, &ct_rlwe_in_dft, &ct_grlwe, scratch.borrow()); + ct_rlwe_out_dft.idft(&module, &mut ct_rlwe_out, scratch.borrow()); + + ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_rlwe_in, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); + } + + #[test] + fn prod_by_grlwe() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe: usize = 45; + let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe); + let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | RLWECtDft::prod_by_grlwe_scratch_space(&module, ct_rlwe_dft.size(), ct_grlwe.size()), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + ct_grlwe.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.encrypt_sk( + &module, + Some(&pt_want), + &sk0_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.dft(&module, &mut ct_rlwe_dft); + ct_rlwe_dft.prod_by_grlwe(&module, &ct_grlwe, scratch.borrow()); + ct_rlwe_dft.idft(&module, &mut ct_rlwe, scratch.borrow()); + + ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_rlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); + } +}