From d8a7d6cdaf16b016d623f2dadb8cb91195058031 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Mon, 12 May 2025 14:40:17 +0200 Subject: [PATCH] Some traits updates + added missing tests for products on RGSWCt --- core/src/elem.rs | 150 ++-- core/src/grlwe.rs | 108 +-- core/src/lib.rs | 1 + core/src/rgsw.rs | 96 +-- core/src/rlwe.rs | 148 ++-- core/src/test_fft64/grlwe.rs | 993 +++++++++++++------------ core/src/test_fft64/rgsw.rs | 627 ++++++++++++++-- core/src/test_fft64/rlwe.rs | 1197 +++++++++++++++---------------- core/src/test_fft64/rlwe_dft.rs | 889 ++++++++++++----------- 9 files changed, 2295 insertions(+), 1914 deletions(-) diff --git a/core/src/elem.rs b/core/src/elem.rs index 94311bc..b66c86d 100644 --- a/core/src/elem.rs +++ b/core/src/elem.rs @@ -66,92 +66,88 @@ pub trait SetRow { VecZnxDft: VecZnxDftToRef; } -pub trait ProdByScratchSpace { - fn prod_by_grlwe_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize; - fn prod_by_rgsw_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize; +pub trait ProdInplaceScratchSpace { + fn prod_by_grlwe_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize; + fn prod_by_rgsw_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize; } -pub trait ProdBy { - fn prod_by_grlwe(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef; - - fn prod_by_rgsw(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef; -} - -pub trait FromProdByScratchSpace { - fn from_prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize; - fn from_prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize; -} - -pub trait FromProdBy { - fn from_prod_by_grlwe(&mut self, module: &Module, lhs: &L, rhs: &GRLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef; - - fn from_prod_by_rgsw(&mut self, module: &Module, lhs: &L, rhs: &RGSWCt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef; -} - -pub(crate) trait MatZnxDftProducts: Infos +pub trait ProdInplace where - MatZnxDft: MatZnxDftToRef + ZnxInfos, + MatZnxDft: MatZnxDftToRef, { - fn mul_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef, - VecZnx: VecZnxToMut, - VecZnx: VecZnxToRef; + fn prod_by_grlwe_inplace(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch); + fn prod_by_rgsw_inplace(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch); +} - fn mul_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize; +pub trait ProdScratchSpace { + fn prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize; + fn prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize; +} - fn mul_rlwe_inplace_scratch_space(module: &Module, res_size: usize, mat_size: usize) -> usize { - Self::mul_rlwe_scratch_space(module, res_size, res_size, mat_size) +pub trait Product +where + MatZnxDft: MatZnxDftToRef, +{ + type Lhs; + + fn prod_by_grlwe(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &GRLWECt, scratch: &mut Scratch); + fn prod_by_rgsw(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &RGSWCt, scratch: &mut Scratch); +} + +pub(crate) trait MatRLWEProductScratchSpace { + fn prod_with_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize; + + fn prod_with_rlwe_inplace_scratch_space(module: &Module, res_size: usize, mat_size: usize) -> usize { + Self::prod_with_rlwe_scratch_space(module, res_size, res_size, mat_size) } - fn mul_rlwe_dft_scratch_space(module: &Module, res_size: usize, a_size: usize, mat_size: usize) -> usize { - (Self::mul_rlwe_scratch_space(module, res_size, a_size, mat_size) | module.vec_znx_idft_tmp_bytes()) + fn prod_with_rlwe_dft_scratch_space(module: &Module, res_size: usize, a_size: usize, mat_size: usize) -> usize { + (Self::prod_with_rlwe_scratch_space(module, res_size, a_size, mat_size) | module.vec_znx_idft_tmp_bytes()) + module.bytes_of_vec_znx(2, a_size) + module.bytes_of_vec_znx(2, res_size) } - fn mul_rlwe_dft_inplace_scratch_space(module: &Module, res_size: usize, mat_size: usize) -> usize { - (Self::mul_rlwe_inplace_scratch_space(module, res_size, mat_size) | module.vec_znx_idft_tmp_bytes()) + fn prod_with_rlwe_dft_inplace_scratch_space(module: &Module, res_size: usize, mat_size: usize) -> usize { + (Self::prod_with_rlwe_inplace_scratch_space(module, res_size, mat_size) | module.vec_znx_idft_tmp_bytes()) + module.bytes_of_vec_znx(2, res_size) } - fn mul_mat_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, mat_size: usize) -> usize { - Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, mat_size) + module.bytes_of_vec_znx_dft(2, a_size) + fn prod_with_mat_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, mat_size: usize) -> usize { + Self::prod_with_rlwe_dft_scratch_space(module, res_size, a_size, mat_size) + + module.bytes_of_vec_znx_dft(2, a_size) + + module.bytes_of_vec_znx_dft(2, res_size) } - fn mul_mat_rlwe_inplace_scratch_space(module: &Module, res_size: usize, mat_size: usize) -> usize { - Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, mat_size) + module.bytes_of_vec_znx_dft(2, res_size) + fn prod_with_mat_rlwe_inplace_scratch_space(module: &Module, res_size: usize, mat_size: usize) -> usize { + Self::prod_with_rlwe_dft_inplace_scratch_space(module, res_size, mat_size) + module.bytes_of_vec_znx_dft(2, res_size) } +} - fn mul_rlwe_inplace(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) +pub(crate) trait MatRLWEProduct: Infos { + fn prod_with_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - VecZnx: VecZnxToMut + VecZnxToRef, + VecZnx: VecZnxToMut, + VecZnx: VecZnxToRef; + + fn prod_with_rlwe_inplace(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) + where + VecZnx: VecZnxToMut + VecZnxToRef, { unsafe { - let res_ptr: *mut RLWECt = res as *mut RLWECt; // This is ok because [Self::mul_rlwe] only updates res at the end. - self.mul_rlwe(&module, &mut *res_ptr, &*res_ptr, scratch); + let res_ptr: *mut RLWECt = res as *mut RLWECt; // This is ok because [Self::mul_rlwe] only updates res at the end. + self.prod_with_rlwe(&module, &mut *res_ptr, &*res_ptr, scratch); } } - fn mul_rlwe_dft( + fn prod_with_rlwe_dft( &self, module: &Module, - res: &mut RLWECtDft, - a: &RLWECtDft, + res: &mut RLWECtDft, + a: &RLWECtDft, scratch: &mut Scratch, ) where - MatZnxDft: MatZnxDftToRef + ZnxInfos, - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, - VecZnxDft: VecZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToRef + ZnxInfos, { let log_base2k: usize = self.log_base2k(); @@ -180,15 +176,15 @@ where log_k: res.log_k(), }; - self.mul_rlwe(module, &mut res_idft, &a_idft, scratch_2); + self.prod_with_rlwe(module, &mut res_idft, &a_idft, scratch_2); module.vec_znx_dft(res, 0, &res_idft, 0); module.vec_znx_dft(res, 1, &res_idft, 1); } - fn mul_rlwe_dft_inplace(&self, module: &Module, res: &mut RLWECtDft, scratch: &mut Scratch) + fn prod_with_rlwe_dft_inplace(&self, module: &Module, res: &mut RLWECtDft, scratch: &mut Scratch) where - VecZnxDft: VecZnxDftToRef + VecZnxDftToMut, + VecZnxDft: VecZnxDftToRef + VecZnxDftToMut, { let log_base2k: usize = self.log_base2k(); @@ -209,47 +205,55 @@ where res.idft(module, &mut res_idft, scratch_1); - self.mul_rlwe_inplace(module, &mut res_idft, scratch_1); + self.prod_with_rlwe_inplace(module, &mut res_idft, scratch_1); module.vec_znx_dft(res, 0, &res_idft, 0); module.vec_znx_dft(res, 1, &res_idft, 1); } - fn mul_mat_rlwe(&self, module: &Module, res: &mut R, a: &A, scratch: &mut Scratch) + fn prod_with_mat_rlwe(&self, module: &Module, res: &mut RES, a: &LHS, scratch: &mut Scratch) where - A: GetRow + Infos, - R: SetRow + Infos, + LHS: GetRow + Infos, + RES: SetRow + Infos, { let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, a.size()); - let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { + let mut tmp_a_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { data: tmp_row_data, log_base2k: a.log_base2k(), log_k: a.log_k(), }; + let (tmp_res_data, scratch2) = scratch1.tmp_vec_znx_dft(module, 2, res.size()); + + let mut tmp_res_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { + data: tmp_res_data, + log_base2k: res.log_base2k(), + log_k: res.log_k(), + }; + let min_rows: usize = res.rows().min(a.rows()); (0..res.rows()).for_each(|row_i| { (0..res.cols()).for_each(|col_j| { - a.get_row(module, row_i, col_j, &mut tmp_row); - self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); - res.set_row(module, row_i, col_j, &tmp_row); + a.get_row(module, row_i, col_j, &mut tmp_a_row); + self.prod_with_rlwe_dft(module, &mut tmp_res_row, &tmp_a_row, scratch2); + res.set_row(module, row_i, col_j, &tmp_res_row); }); }); - tmp_row.data.zero(); + tmp_res_row.data.zero(); (min_rows..res.rows()).for_each(|row_i| { (0..self.cols()).for_each(|col_j| { - res.set_row(module, row_i, col_j, &tmp_row); + res.set_row(module, row_i, col_j, &tmp_res_row); }); }); } - fn mul_mat_rlwe_inplace(&self, module: &Module, res: &mut R, scratch: &mut Scratch) + fn prod_with_mat_rlwe_inplace(&self, module: &Module, res: &mut RES, scratch: &mut Scratch) where - R: GetRow + SetRow + Infos, + RES: GetRow + SetRow + Infos, { let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, res.size()); @@ -262,7 +266,7 @@ where (0..res.rows()).for_each(|row_i| { (0..res.cols()).for_each(|col_j| { res.get_row(module, row_i, col_j, &mut tmp_row); - self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); + self.prod_with_rlwe_dft_inplace(module, &mut tmp_row, scratch1); res.set_row(module, row_i, col_j, &tmp_row); }); }); diff --git a/core/src/grlwe.rs b/core/src/grlwe.rs index 9c8c5b8..80c976d 100644 --- a/core/src/grlwe.rs +++ b/core/src/grlwe.rs @@ -7,7 +7,10 @@ use base2k::{ use sampling::source::Source; use crate::{ - elem::{FromProdBy, FromProdByScratchSpace, GetRow, Infos, MatZnxDftProducts, ProdBy, ProdByScratchSpace, SetRow}, + elem::{ + GetRow, Infos, MatRLWEProduct, MatRLWEProductScratchSpace, ProdInplace, ProdInplaceScratchSpace, ProdScratchSpace, + Product, SetRow, + }, keys::SecretKeyDft, rgsw::RGSWCt, rlwe::{RLWECt, RLWECtDft, RLWEPt}, @@ -30,18 +33,6 @@ impl GRLWECt, B> { } } -impl GRLWECt -where - MatZnxDft: MatZnxDftToRef, -{ - pub fn get_row(&self, module: &Module, row_i: usize, res: &mut RLWECtDft) - where - VecZnxDft: VecZnxDftToMut, - { - module.vmp_extract_row(res, self, row_i, 0); - } -} - impl Infos for GRLWECt { type Inner = MatZnxDft; @@ -202,18 +193,20 @@ where } } -impl MatZnxDftProducts, C> for GRLWECt -where - MatZnxDft: MatZnxDftToRef + ZnxInfos, -{ - fn mul_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { +impl MatRLWEProductScratchSpace for GRLWECt, FFT64> { + fn prod_with_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { module.bytes_of_vec_znx_dft(2, grlwe_size) + (module.vec_znx_big_normalize_tmp_bytes() | (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 1, 2, grlwe_size) + module.bytes_of_vec_znx_dft(1, a_size))) } +} - fn mul_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) +impl MatRLWEProduct for GRLWECt +where + MatZnxDft: MatZnxDftToRef + ZnxInfos, +{ + fn prod_with_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) where MatZnxDft: MatZnxDftToRef, VecZnx: VecZnxToMut, @@ -247,79 +240,52 @@ where } } -impl ProdByScratchSpace for GRLWECt, FFT64> { - fn prod_by_grlwe_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_mat_rlwe_inplace_scratch_space( - module, lhs, rhs, - ) +impl ProdInplaceScratchSpace for GRLWECt, FFT64> { + fn prod_by_grlwe_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatRLWEProductScratchSpace>::prod_with_mat_rlwe_inplace_scratch_space(module, lhs, rhs) } - fn prod_by_rgsw_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_mat_rlwe_inplace_scratch_space( - module, lhs, rhs, - ) + fn prod_by_rgsw_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatRLWEProductScratchSpace>::prod_with_mat_rlwe_inplace_scratch_space(module, lhs, rhs) } } -impl FromProdByScratchSpace for GRLWECt, FFT64> { - fn from_prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_mat_rlwe_scratch_space( - module, res_size, lhs, rhs, - ) +impl ProdScratchSpace for GRLWECt, FFT64> { + fn prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatRLWEProductScratchSpace>::prod_with_mat_rlwe_scratch_space(module, res_size, lhs, rhs) } - fn from_prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_mat_rlwe_scratch_space( - module, res_size, lhs, rhs, - ) + fn prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatRLWEProductScratchSpace>::prod_with_mat_rlwe_scratch_space(module, res_size, lhs, rhs) } } -impl ProdBy> for GRLWECt +impl ProdInplace for GRLWECt where GRLWECt: GetRow + SetRow + Infos, + MatZnxDft: MatZnxDftToRef, { - fn prod_by_grlwe(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef, - { - rhs.mul_mat_rlwe_inplace(module, self, scratch); + fn prod_by_grlwe_inplace(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) { + rhs.prod_with_mat_rlwe_inplace(module, self, scratch); } - fn prod_by_rgsw(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef, - { - rhs.mul_mat_rlwe_inplace(module, self, scratch); + fn prod_by_rgsw_inplace(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) { + rhs.prod_with_mat_rlwe_inplace(module, self, scratch); } } -impl FromProdBy, GRLWECt> for GRLWECt +impl Product for GRLWECt where - GRLWECt: GetRow + SetRow + Infos, - GRLWECt: GetRow + Infos, + MatZnxDft: MatZnxDftToRef + MatZnxDftToMut, + MatZnxDft: MatZnxDftToRef, { - fn from_prod_by_grlwe( - &mut self, - module: &Module, - lhs: &GRLWECt, - rhs: &GRLWECt, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef, - { - rhs.mul_mat_rlwe(module, self, lhs, scratch); + type Lhs = GRLWECt; + + fn prod_by_grlwe(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &GRLWECt, scratch: &mut Scratch) { + rhs.prod_with_mat_rlwe(module, self, lhs, scratch); } - fn from_prod_by_rgsw( - &mut self, - module: &Module, - lhs: &GRLWECt, - rhs: &RGSWCt, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef, - { - rhs.mul_mat_rlwe(module, self, lhs, scratch); + fn prod_by_rgsw(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &RGSWCt, scratch: &mut Scratch) { + rhs.prod_with_mat_rlwe(module, self, lhs, scratch); } } diff --git a/core/src/lib.rs b/core/src/lib.rs index a93d44e..bed71cc 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -3,5 +3,6 @@ pub mod grlwe; pub mod keys; pub mod rgsw; pub mod rlwe; +#[cfg(test)] mod test_fft64; mod utils; diff --git a/core/src/rgsw.rs b/core/src/rgsw.rs index c4c7c1c..b866252 100644 --- a/core/src/rgsw.rs +++ b/core/src/rgsw.rs @@ -7,7 +7,10 @@ use base2k::{ use sampling::source::Source; use crate::{ - elem::{FromProdBy, FromProdByScratchSpace, GetRow, Infos, MatZnxDftProducts, ProdBy, ProdByScratchSpace, SetRow}, + elem::{ + GetRow, Infos, MatRLWEProduct, MatRLWEProductScratchSpace, ProdInplace, ProdInplaceScratchSpace, ProdScratchSpace, + Product, SetRow, + }, grlwe::GRLWECt, keys::SecretKeyDft, rlwe::{RLWECt, RLWECtDft, RLWEPt, encrypt_rlwe_sk}, @@ -184,17 +187,19 @@ where } } -impl MatZnxDftProducts, C> for RGSWCt -where - MatZnxDft: MatZnxDftToRef + ZnxInfos, -{ - fn mul_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, rgsw_size: usize) -> usize { +impl MatRLWEProductScratchSpace for RGSWCt, FFT64> { + fn prod_with_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, rgsw_size: usize) -> usize { module.bytes_of_vec_znx_dft(2, rgsw_size) + ((module.bytes_of_vec_znx_dft(2, a_size) + module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 2, 2, rgsw_size)) | module.vec_znx_big_normalize_tmp_bytes()) } +} - fn mul_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) +impl MatRLWEProduct for RGSWCt +where + MatZnxDft: MatZnxDftToRef + ZnxInfos, +{ + fn prod_with_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) where MatZnxDft: MatZnxDftToRef, VecZnx: VecZnxToMut, @@ -227,79 +232,52 @@ where } } -impl ProdByScratchSpace for RGSWCt, FFT64> { - fn prod_by_grlwe_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_mat_rlwe_inplace_scratch_space( - module, lhs, rhs, - ) +impl ProdInplaceScratchSpace for RGSWCt, FFT64> { + fn prod_by_grlwe_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatRLWEProductScratchSpace>::prod_with_mat_rlwe_inplace_scratch_space(module, lhs, rhs) } - fn prod_by_rgsw_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_mat_rlwe_inplace_scratch_space( - module, lhs, rhs, - ) + fn prod_by_rgsw_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatRLWEProductScratchSpace>::prod_with_mat_rlwe_inplace_scratch_space(module, lhs, rhs) } } -impl FromProdByScratchSpace for RGSWCt, FFT64> { - fn from_prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_mat_rlwe_scratch_space( - module, res_size, lhs, rhs, - ) +impl ProdScratchSpace for RGSWCt, FFT64> { + fn prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatRLWEProductScratchSpace>::prod_with_mat_rlwe_scratch_space(module, res_size, lhs, rhs) } - fn from_prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_mat_rlwe_scratch_space( - module, res_size, lhs, rhs, - ) + fn prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatRLWEProductScratchSpace>::prod_with_mat_rlwe_scratch_space(module, res_size, lhs, rhs) } } -impl ProdBy> for RGSWCt +impl ProdInplace for RGSWCt where RGSWCt: GetRow + SetRow + Infos, + MatZnxDft: MatZnxDftToRef, { - fn prod_by_grlwe(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef, - { - rhs.mul_mat_rlwe_inplace(module, self, scratch); + fn prod_by_grlwe_inplace(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) { + rhs.prod_with_mat_rlwe_inplace(module, self, scratch); } - fn prod_by_rgsw(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef, - { - rhs.mul_mat_rlwe_inplace(module, self, scratch); + fn prod_by_rgsw_inplace(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) { + rhs.prod_with_mat_rlwe_inplace(module, self, scratch); } } -impl FromProdBy, RGSWCt> for RGSWCt +impl Product for RGSWCt where - RGSWCt: GetRow + SetRow + Infos, - RGSWCt: GetRow + Infos, + MatZnxDft: MatZnxDftToRef + MatZnxDftToMut, + MatZnxDft: MatZnxDftToRef, { - fn from_prod_by_grlwe( - &mut self, - module: &Module, - lhs: &RGSWCt, - rhs: &GRLWECt, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef, - { - rhs.mul_mat_rlwe(module, self, lhs, scratch); + type Lhs = RGSWCt; + + fn prod_by_grlwe(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &GRLWECt, scratch: &mut Scratch) { + rhs.prod_with_mat_rlwe(module, self, lhs, scratch); } - fn from_prod_by_rgsw( - &mut self, - module: &Module, - lhs: &RGSWCt, - rhs: &RGSWCt, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef, - { - rhs.mul_mat_rlwe(module, self, lhs, scratch); + fn prod_by_rgsw(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &RGSWCt, scratch: &mut Scratch) { + rhs.prod_with_mat_rlwe(module, self, lhs, scratch); } } diff --git a/core/src/rlwe.rs b/core/src/rlwe.rs index ef1be64..2dab803 100644 --- a/core/src/rlwe.rs +++ b/core/src/rlwe.rs @@ -6,7 +6,7 @@ use base2k::{ use sampling::source::Source; use crate::{ - elem::{FromProdBy, FromProdByScratchSpace, Infos, MatZnxDftProducts, ProdBy, ProdByScratchSpace}, + elem::{Infos, MatRLWEProduct, MatRLWEProductScratchSpace, ProdInplace, ProdInplaceScratchSpace, ProdScratchSpace, Product}, grlwe::GRLWECt, keys::{PublicKey, SecretDistribution, SecretKeyDft}, rgsw::RGSWCt, @@ -84,70 +84,54 @@ where } } -impl ProdByScratchSpace for RLWECt> { - fn prod_by_grlwe_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_rlwe_inplace_scratch_space( - module, lhs, rhs, - ) +impl ProdInplaceScratchSpace for RLWECt> { + fn prod_by_grlwe_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatRLWEProductScratchSpace>::prod_with_rlwe_inplace_scratch_space(module, lhs, rhs) } - fn prod_by_rgsw_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_rlwe_inplace_scratch_space( - module, lhs, rhs, - ) + fn prod_by_rgsw_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatRLWEProductScratchSpace>::prod_with_rlwe_inplace_scratch_space(module, lhs, rhs) } } -impl FromProdByScratchSpace for RLWECt> { - fn from_prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_rlwe_scratch_space( - module, res_size, lhs, rhs, - ) +impl ProdScratchSpace for RLWECt> { + fn prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatRLWEProductScratchSpace>::prod_with_rlwe_scratch_space(module, res_size, lhs, rhs) } - fn from_prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_rlwe_scratch_space( - module, res_size, lhs, rhs, - ) + fn prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatRLWEProductScratchSpace>::prod_with_rlwe_scratch_space(module, res_size, lhs, rhs) } } -impl ProdBy> for RLWECt +impl ProdInplace for RLWECt where VecZnx: VecZnxToMut + VecZnxToRef, + MatZnxDft: MatZnxDftToRef, { - fn prod_by_grlwe(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef, - { - rhs.mul_rlwe_inplace(module, self, scratch); + fn prod_by_grlwe_inplace(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) { + rhs.prod_with_rlwe_inplace(module, self, scratch); } - fn prod_by_rgsw(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef, - { - rhs.mul_rlwe_inplace(module, self, scratch); + fn prod_by_rgsw_inplace(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) { + rhs.prod_with_rlwe_inplace(module, self, scratch); } } -impl FromProdBy, RLWECt> for RLWECt +impl Product for RLWECt where VecZnx: VecZnxToMut + VecZnxToRef, VecZnx: VecZnxToRef, + MatZnxDft: MatZnxDftToRef, { - fn from_prod_by_grlwe(&mut self, module: &Module, lhs: &RLWECt, rhs: &GRLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef, - { - rhs.mul_rlwe(module, self, lhs, scratch); + type Lhs = RLWECt; + + fn prod_by_grlwe(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &GRLWECt, scratch: &mut Scratch) { + rhs.prod_with_rlwe(module, self, lhs, scratch); } - fn from_prod_by_rgsw(&mut self, module: &Module, lhs: &RLWECt, rhs: &RGSWCt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef, - { - rhs.mul_rlwe(module, self, lhs, scratch); + fn prod_by_rgsw(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &RGSWCt, scratch: &mut Scratch) { + rhs.prod_with_rlwe(module, self, lhs, scratch); } } @@ -496,7 +480,7 @@ where impl RLWECtDft where - VecZnxDft: VecZnxDftToRef, + RLWECtDft: VecZnxDftToRef, { #[allow(dead_code)] pub(crate) fn idft_scratch_space(module: &Module, size: usize) -> usize { @@ -505,7 +489,7 @@ where pub(crate) fn idft(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) where - VecZnx: VecZnxToMut, + RLWECt: VecZnxToMut, { #[cfg(debug_assertions)] { @@ -518,8 +502,8 @@ where let (mut res_big, scratch1) = scratch.tmp_vec_znx_big(module, 2, min_size); - module.vec_znx_idft(&mut res_big, 0, &self.data, 0, scratch1); - module.vec_znx_idft(&mut res_big, 1, &self.data, 1, scratch1); + module.vec_znx_idft(&mut res_big, 0, self, 0, scratch1); + module.vec_znx_idft(&mut res_big, 1, self, 1, scratch1); module.vec_znx_big_normalize(self.log_base2k(), res, 0, &res_big, 0, scratch1); module.vec_znx_big_normalize(self.log_base2k(), res, 1, &res_big, 1, scratch1); } @@ -665,79 +649,53 @@ impl RLWECtDft { } } -impl ProdByScratchSpace for RLWECtDft, FFT64> { - fn prod_by_grlwe_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_rlwe_dft_inplace_scratch_space( - module, lhs, rhs, - ) +impl ProdInplaceScratchSpace for RLWECtDft, FFT64> { + fn prod_by_grlwe_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatRLWEProductScratchSpace>::prod_with_rlwe_dft_inplace_scratch_space(module, lhs, rhs) } - fn prod_by_rgsw_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_rlwe_dft_inplace_scratch_space( - module, lhs, rhs, - ) + fn prod_by_rgsw_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatRLWEProductScratchSpace>::prod_with_rlwe_dft_inplace_scratch_space(module, lhs, rhs) } } -impl FromProdByScratchSpace for RLWECtDft, FFT64> { - fn from_prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_rlwe_dft_scratch_space( - module, res_size, lhs, rhs, - ) +impl ProdScratchSpace for RLWECtDft, FFT64> { + fn prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatRLWEProductScratchSpace>::prod_with_rlwe_dft_scratch_space(module, res_size, lhs, rhs) } - fn from_prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatZnxDftProducts, FFT64>, Vec>>::mul_rlwe_dft_scratch_space( - module, res_size, lhs, rhs, - ) + fn prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as MatRLWEProductScratchSpace>::prod_with_rlwe_dft_scratch_space(module, res_size, lhs, rhs) } } -impl ProdBy> for RLWECtDft +impl ProdInplace for RLWECtDft where VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, + MatZnxDft: MatZnxDftToRef, { - fn prod_by_grlwe(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef, - { - rhs.mul_rlwe_dft_inplace(module, self, scratch); + fn prod_by_grlwe_inplace(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) { + rhs.prod_with_rlwe_dft_inplace(module, self, scratch); } - fn prod_by_rgsw(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef, - { - rhs.mul_rlwe_dft_inplace(module, self, scratch); + fn prod_by_rgsw_inplace(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) { + rhs.prod_with_rlwe_dft_inplace(module, self, scratch); } } -impl FromProdBy, RLWECtDft> for RLWECtDft +impl Product for RLWECtDft where VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, VecZnxDft: VecZnxDftToRef, + MatZnxDft: MatZnxDftToRef, { - fn from_prod_by_grlwe( - &mut self, - module: &Module, - lhs: &RLWECtDft, - rhs: &GRLWECt, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef, - { - rhs.mul_rlwe_dft(module, self, lhs, scratch); + type Lhs = RLWECtDft; + + fn prod_by_grlwe(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &GRLWECt, scratch: &mut Scratch) { + rhs.prod_with_rlwe_dft(module, self, lhs, scratch); } - fn from_prod_by_rgsw( - &mut self, - module: &Module, - lhs: &RLWECtDft, - rhs: &RGSWCt, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef, - { - rhs.mul_rlwe_dft(module, self, lhs, scratch); + fn prod_by_rgsw(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &RGSWCt, scratch: &mut Scratch) { + rhs.prod_with_rlwe_dft(module, self, lhs, scratch); } } diff --git a/core/src/test_fft64/grlwe.rs b/core/src/test_fft64/grlwe.rs index 44fefd6..81c1023 100644 --- a/core/src/test_fft64/grlwe.rs +++ b/core/src/test_fft64/grlwe.rs @@ -1,504 +1,499 @@ -#[cfg(test)] - -mod tests { - use base2k::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, ZnxViewMut}; - use sampling::source::Source; - - use crate::{ - elem::{FromProdBy, FromProdByScratchSpace, Infos, ProdBy, ProdByScratchSpace}, - grlwe::GRLWECt, - keys::{SecretKey, SecretKeyDft}, - rgsw::RGSWCt, - rlwe::{RLWECtDft, RLWEPt}, - test_fft64::{grlwe::noise_grlwe_rlwe_product, rgsw::noise_rgsw_rlwe_product}, - }; - - #[test] - fn encrypt_sk() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 8; - let log_k_ct: usize = 54; - let rows: usize = 4; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_ct, rows); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); - let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct.size()) | RLWECtDft::decrypt_scratch_space(&module, ct.size()), - ); - - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); - - ct.encrypt_sk( - &module, - &pt_scalar, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct); - - (0..ct.rows()).for_each(|row_i| { - ct.get_row(&module, row_i, &mut ct_rlwe_dft); - ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_scalar, 0); - let std_pt: f64 = pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); - assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); - }); - - module.free(); - } - - #[test] - fn from_prod_by_grlwe() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_grlwe_s0s1: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_grlwe_s1s2: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_grlwe_s0s2: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe_s0s1.size()) - | RLWECtDft::decrypt_scratch_space(&module, ct_grlwe_s0s2.size()) - | GRLWECt::from_prod_by_grlwe_scratch_space( - &module, - ct_grlwe_s0s2.size(), - ct_grlwe_s0s1.size(), - ct_grlwe_s1s2.size(), - ), - ); - - let mut sk0: SecretKey> = SecretKey::new(&module); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk0_dft.dft(&module, &sk0); - - let mut sk1: SecretKey> = SecretKey::new(&module); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk1_dft.dft(&module, &sk1); - - let mut sk2: SecretKey> = SecretKey::new(&module); - sk2.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk2_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk2_dft.dft(&module, &sk2); - - // GRLWE_{s1}(s0) = s0 -> s1 - ct_grlwe_s0s1.encrypt_sk( - &module, - &sk0.data, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - // GRLWE_{s2}(s1) -> s1 -> s2 - ct_grlwe_s1s2.encrypt_sk( - &module, - &sk1.data, - &sk2_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - // GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) - ct_grlwe_s0s2.from_prod_by_grlwe(&module, &ct_grlwe_s0s1, &ct_grlwe_s1s2, scratch.borrow()); - - let mut ct_rlwe_dft_s0s2: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_grlwe); - - (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { - ct_grlwe_s0s2.get_row(&module, row_i, &mut ct_rlwe_dft_s0s2); - ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, 0); - - let noise_have: f64 = pt.data.std(0, log_base2k).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - log_k_grlwe, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - }); - - module.free(); - } - - #[test] - fn prod_by_grlwe() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_grlwe_s0s1: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_grlwe_s1s2: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe_s0s1.size()) - | RLWECtDft::decrypt_scratch_space(&module, ct_grlwe_s0s1.size()) - | GRLWECt::prod_by_grlwe_scratch_space(&module, ct_grlwe_s0s1.size(), ct_grlwe_s1s2.size()), - ); - - let mut sk0: SecretKey> = SecretKey::new(&module); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk0_dft.dft(&module, &sk0); - - let mut sk1: SecretKey> = SecretKey::new(&module); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk1_dft.dft(&module, &sk1); - - let mut sk2: SecretKey> = SecretKey::new(&module); - sk2.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk2_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk2_dft.dft(&module, &sk2); - - // GRLWE_{s1}(s0) = s0 -> s1 - ct_grlwe_s0s1.encrypt_sk( - &module, - &sk0.data, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - // GRLWE_{s2}(s1) -> s1 -> s2 - ct_grlwe_s1s2.encrypt_sk( - &module, - &sk1.data, - &sk2_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - // GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) - ct_grlwe_s0s1.prod_by_grlwe(&module, &ct_grlwe_s1s2, scratch.borrow()); - - let ct_grlwe_s0s2: GRLWECt, FFT64> = ct_grlwe_s0s1; - - let mut ct_rlwe_dft_s0s2: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_grlwe); - - (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { - ct_grlwe_s0s2.get_row(&module, row_i, &mut ct_rlwe_dft_s0s2); - ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, 0); - - let noise_have: f64 = pt.data.std(0, log_base2k).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - log_k_grlwe, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - }); - - module.free(); - } - - #[test] - fn from_prod_by_rgsw() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_grlwe_in: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_grlwe_out: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); - - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe_in.size()) - | RLWECtDft::decrypt_scratch_space(&module, ct_grlwe_out.size()) - | GRLWECt::from_prod_by_rgsw_scratch_space( - &module, - ct_grlwe_out.size(), - ct_grlwe_in.size(), - ct_rgsw.size(), - ) - | RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()), - ); - - let k: usize = 1; - - pt_rgsw.raw_mut()[k] = 1; // X^{k} - - pt_grlwe.fill_ternary_prob(0, 0.5, &mut source_xs); - - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); - - // GRLWE_{s1}(s0) = s0 -> s1 - ct_grlwe_in.encrypt_sk( - &module, - &pt_grlwe, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rgsw.encrypt_sk( - &module, - &pt_rgsw, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - // GRLWE_(m) (x) RGSW_(X^k) = GRLWE_(m * X^k) - ct_grlwe_out.from_prod_by_rgsw(&module, &ct_grlwe_in, &ct_rgsw, scratch.borrow()); - - let mut ct_rlwe_dft_s0s2: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_grlwe); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); - - (0..ct_grlwe_out.rows()).for_each(|row_i| { - ct_grlwe_out.get_row(&module, row_i, &mut ct_rlwe_dft_s0s2); - ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_grlwe, 0); - - let noise_have: f64 = pt.data.std(0, log_base2k).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_rgsw_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - log_k_grlwe, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - }); - - module.free(); - } - - #[test] - fn prod_by_rgsw() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); - - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | RLWECtDft::decrypt_scratch_space(&module, ct_grlwe.size()) - | GRLWECt::prod_by_rgsw_scratch_space(&module, ct_grlwe.size(), ct_rgsw.size()) - | RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()), - ); - - let k: usize = 1; - - pt_rgsw.raw_mut()[k] = 1; // X^{k} - - pt_grlwe.fill_ternary_prob(0, 0.5, &mut source_xs); - - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); - - // GRLWE_{s1}(s0) = s0 -> s1 - ct_grlwe.encrypt_sk( - &module, - &pt_grlwe, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rgsw.encrypt_sk( - &module, - &pt_rgsw, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - // GRLWE_(m) (x) RGSW_(X^k) = GRLWE_(m * X^k) - ct_grlwe.prod_by_rgsw(&module, &ct_rgsw, scratch.borrow()); - - let mut ct_rlwe_dft_s0s2: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_grlwe); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); - - (0..ct_grlwe.rows()).for_each(|row_i| { - ct_grlwe.get_row(&module, row_i, &mut ct_rlwe_dft_s0s2); - ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_grlwe, 0); - - let noise_have: f64 = pt.data.std(0, log_base2k).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_rgsw_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - log_k_grlwe, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - }); - - module.free(); - } +use base2k::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, ZnxViewMut}; +use sampling::source::Source; + +use crate::{ + elem::{GetRow, Infos, ProdInplace, ProdInplaceScratchSpace, ProdScratchSpace, Product}, + grlwe::GRLWECt, + keys::{SecretKey, SecretKeyDft}, + rgsw::RGSWCt, + rlwe::{RLWECtDft, RLWEPt}, + test_fft64::rgsw::noise_rgsw_product, +}; + +#[test] +fn encrypt_sk() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 8; + let log_k_ct: usize = 54; + let rows: usize = 4; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_ct, rows); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct.size()) | RLWECtDft::decrypt_scratch_space(&module, ct.size()), + ); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + ct.encrypt_sk( + &module, + &pt_scalar, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct); + + (0..ct.rows()).for_each(|row_i| { + ct.get_row(&module, row_i, 0, &mut ct_rlwe_dft); + ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_scalar, 0); + let std_pt: f64 = pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); + assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); + }); + + module.free(); +} + +#[test] +fn from_prod_by_grlwe() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe_s0s1: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_s1s2: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_s0s2: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe_s0s1.size()) + | RLWECtDft::decrypt_scratch_space(&module, ct_grlwe_s0s2.size()) + | GRLWECt::prod_by_grlwe_scratch_space( + &module, + ct_grlwe_s0s2.size(), + ct_grlwe_s0s1.size(), + ct_grlwe_s1s2.size(), + ), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + let mut sk2: SecretKey> = SecretKey::new(&module); + sk2.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk2_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk2_dft.dft(&module, &sk2); + + // GRLWE_{s1}(s0) = s0 -> s1 + ct_grlwe_s0s1.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + // GRLWE_{s2}(s1) -> s1 -> s2 + ct_grlwe_s1s2.encrypt_sk( + &module, + &sk1.data, + &sk2_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + // GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) + ct_grlwe_s0s2.prod_by_grlwe(&module, &ct_grlwe_s0s1, &ct_grlwe_s1s2, scratch.borrow()); + + let mut ct_rlwe_dft_s0s2: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_grlwe); + + (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { + ct_grlwe_s0s2.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); + ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, 0); + + let noise_have: f64 = pt.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_grlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); + + module.free(); +} + +#[test] +fn prod_by_grlwe() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe_s0s1: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_s1s2: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe_s0s1.size()) + | RLWECtDft::decrypt_scratch_space(&module, ct_grlwe_s0s1.size()) + | GRLWECt::prod_by_grlwe_inplace_scratch_space(&module, ct_grlwe_s0s1.size(), ct_grlwe_s1s2.size()), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + let mut sk2: SecretKey> = SecretKey::new(&module); + sk2.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk2_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk2_dft.dft(&module, &sk2); + + // GRLWE_{s1}(s0) = s0 -> s1 + ct_grlwe_s0s1.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + // GRLWE_{s2}(s1) -> s1 -> s2 + ct_grlwe_s1s2.encrypt_sk( + &module, + &sk1.data, + &sk2_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + // GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) + ct_grlwe_s0s1.prod_by_grlwe_inplace(&module, &ct_grlwe_s1s2, scratch.borrow()); + + let ct_grlwe_s0s2: GRLWECt, FFT64> = ct_grlwe_s0s1; + + let mut ct_rlwe_dft_s0s2: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_grlwe); + + (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { + ct_grlwe_s0s2.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); + ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, 0); + + let noise_have: f64 = pt.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_grlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); + + module.free(); +} + +#[test] +fn from_prod_by_rgsw() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe_in: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_out: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); + + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe_in.size()) + | RLWECtDft::decrypt_scratch_space(&module, ct_grlwe_out.size()) + | GRLWECt::prod_by_rgsw_scratch_space( + &module, + ct_grlwe_out.size(), + ct_grlwe_in.size(), + ct_rgsw.size(), + ) + | RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()), + ); + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + pt_grlwe.fill_ternary_prob(0, 0.5, &mut source_xs); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + // GRLWE_{s1}(s0) = s0 -> s1 + ct_grlwe_in.encrypt_sk( + &module, + &pt_grlwe, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + // GRLWE_(m) (x) RGSW_(X^k) = GRLWE_(m * X^k) + ct_grlwe_out.prod_by_rgsw(&module, &ct_grlwe_in, &ct_rgsw, scratch.borrow()); + + let mut ct_rlwe_dft_s0s2: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_grlwe); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); + + (0..ct_grlwe_out.rows()).for_each(|row_i| { + ct_grlwe_out.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); + ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_grlwe, 0); + + let noise_have: f64 = pt.data.std(0, log_base2k).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_rgsw_product( + module.n() as f64, + log_base2k, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + log_k_grlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); + + module.free(); +} + +#[test] +fn prod_by_rgsw() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); + + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | RLWECtDft::decrypt_scratch_space(&module, ct_grlwe.size()) + | GRLWECt::prod_by_rgsw_inplace_scratch_space(&module, ct_grlwe.size(), ct_rgsw.size()) + | RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()), + ); + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + pt_grlwe.fill_ternary_prob(0, 0.5, &mut source_xs); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + // GRLWE_{s1}(s0) = s0 -> s1 + ct_grlwe.encrypt_sk( + &module, + &pt_grlwe, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + // GRLWE_(m) (x) RGSW_(X^k) = GRLWE_(m * X^k) + ct_grlwe.prod_by_rgsw_inplace(&module, &ct_rgsw, scratch.borrow()); + + let mut ct_rlwe_dft_s0s2: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_grlwe); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); + + (0..ct_grlwe.rows()).for_each(|row_i| { + ct_grlwe.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); + ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_grlwe, 0); + + let noise_have: f64 = pt.data.std(0, log_base2k).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_rgsw_product( + module.n() as f64, + log_base2k, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + log_k_grlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); + + module.free(); } -#[allow(dead_code)] pub(crate) fn noise_grlwe_rlwe_product( n: f64, log_base2k: usize, diff --git a/core/src/test_fft64/rgsw.rs b/core/src/test_fft64/rgsw.rs index 83df85b..50cd356 100644 --- a/core/src/test_fft64/rgsw.rs +++ b/core/src/test_fft64/rgsw.rs @@ -1,95 +1,582 @@ -#[cfg(test)] -mod tests { - use base2k::{ - FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, - VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxZero, - }; - use sampling::source::Source; +use base2k::{ + FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, + VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, ZnxViewMut, ZnxZero, +}; +use sampling::source::Source; - use crate::{ - elem::{GetRow, Infos}, - keys::{SecretKey, SecretKeyDft}, - rgsw::RGSWCt, - rlwe::{RLWECt, RLWECtDft, RLWEPt}, - test_fft64::rgsw::noise_rgsw_rlwe_product, - }; +use crate::{ + elem::{GetRow, Infos, ProdInplace, ProdInplaceScratchSpace, ProdScratchSpace, Product}, + grlwe::GRLWECt, + keys::{SecretKey, SecretKeyDft}, + rgsw::RGSWCt, + rlwe::{RLWECtDft, RLWEPt}, + test_fft64::grlwe::noise_grlwe_rlwe_product, +}; - #[test] - fn encrypt_rgsw_sk() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 8; - let log_k_ct: usize = 54; - let rows: usize = 4; +#[test] +fn encrypt_rgsw_sk() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 8; + let log_k_ct: usize = 54; + let rows: usize = 4; - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; - let mut ct: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_ct, rows); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); - let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); + let mut ct: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_ct, rows); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); - pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); - let mut scratch: ScratchOwned = ScratchOwned::new( - RGSWCt::encrypt_sk_scratch_space(&module, ct.size()) | RLWECtDft::decrypt_scratch_space(&module, ct.size()), - ); + let mut scratch: ScratchOwned = ScratchOwned::new( + RGSWCt::encrypt_sk_scratch_space(&module, ct.size()) | RLWECtDft::decrypt_scratch_space(&module, ct.size()), + ); - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); - ct.encrypt_sk( - &module, - &pt_scalar, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); + ct.encrypt_sk( + &module, + &pt_scalar, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); - let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct); - let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); - let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); + let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct); + let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); + let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); - (0..ct.cols()).for_each(|col_j| { - (0..ct.rows()).for_each(|row_i| { - module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0); + (0..ct.cols()).for_each(|col_j| { + (0..ct.rows()).for_each(|row_i| { + module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0); - if col_j == 1 { - module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0); - module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); - module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); - } + if col_j == 1 { + module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0); + module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); + } - ct.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); + ct.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); - ct_rlwe_dft.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + ct_rlwe_dft.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - let std_pt: f64 = pt_have.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); - assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); + let std_pt: f64 = pt_have.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); + assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); - pt_want.data.zero(); - }); + pt_want.data.zero(); }); + }); - module.free(); - } + module.free(); } -#[allow(dead_code)] -pub(crate) fn noise_rgsw_rlwe_product( +#[test] +fn from_prod_by_grlwe() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rgsw_in: usize = 45; + let log_k_rgsw_out: usize = 45; + let rows: usize = (log_k_rgsw_in + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rgsw_in: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_rgsw_in, rows); + let mut ct_rgsw_out: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_rgsw_out, rows); + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_rgsw.fill_ternary_prob(0, 0.5, &mut source_xs); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | RLWECtDft::decrypt_scratch_space(&module, ct_rgsw_out.size()) + | RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw_in.size()) + | RGSWCt::prod_by_grlwe_scratch_space( + &module, + ct_rgsw_out.size(), + ct_rgsw_in.size(), + ct_grlwe.size(), + ), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + ct_grlwe.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rgsw_in.encrypt_sk( + &module, + &pt_rgsw, + &sk0_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rgsw_out.prod_by_grlwe(&module, &ct_rgsw_in, &ct_grlwe, scratch.borrow()); + + let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rgsw_out); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rgsw_out); + let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_out.size()); + let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_out.size()); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rgsw_out); + + (0..ct_rgsw_out.cols()).for_each(|col_j| { + (0..ct_rgsw_out.rows()).for_each(|row_i| { + module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw, 0); + + if col_j == 1 { + module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk0_dft, 0); + module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); + } + + ct_rgsw_out.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); + ct_rlwe_dft.decrypt(&module, &mut pt, &sk1_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); + + let noise_have: f64 = pt.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_grlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.2, + "have: {} want: {}", + noise_have, + noise_want + ); + + pt_want.data.zero(); + }); + }); + + module.free(); +} + +#[test] +fn from_prod_by_grlwe_inplace() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rgsw: usize = 45; + let rows: usize = (log_k_rgsw + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_rgsw, rows); + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_rgsw.fill_ternary_prob(0, 0.5, &mut source_xs); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | RLWECtDft::decrypt_scratch_space(&module, ct_rgsw.size()) + | RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | RGSWCt::prod_by_grlwe_inplace_scratch_space(&module, ct_rgsw.size(), ct_grlwe.size()), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + ct_grlwe.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk0_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rgsw.prod_by_grlwe_inplace(&module, &ct_grlwe, scratch.borrow()); + + let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rgsw); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rgsw); + let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw.size()); + let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw.size()); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rgsw); + + (0..ct_rgsw.cols()).for_each(|col_j| { + (0..ct_rgsw.rows()).for_each(|row_i| { + module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw, 0); + + if col_j == 1 { + module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk0_dft, 0); + module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); + } + + ct_rgsw.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); + ct_rlwe_dft.decrypt(&module, &mut pt, &sk1_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); + + let noise_have: f64 = pt.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_grlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.2, + "have: {} want: {}", + noise_have, + noise_want + ); + + pt_want.data.zero(); + }); + }); + + module.free(); +} + +#[test] +fn from_prod_by_rgsw() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_rgsw_rhs: usize = 60; + let log_k_rgsw_lhs_in: usize = 45; + let log_k_rgsw_lhs_out: usize = 45; + let rows: usize = (log_k_rgsw_lhs_in + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_rgsw_rhs: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_rgsw_rhs, rows); + let mut ct_rgsw_lhs_in: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_rgsw_lhs_in, rows); + let mut ct_rgsw_lhs_out: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_rgsw_lhs_out, rows); + let mut pt_rgsw_lhs: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_rgsw_rhs: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_rgsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs); + + let k: usize = 1; + + pt_rgsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_rgsw_rhs.size()) + | RLWECtDft::decrypt_scratch_space(&module, ct_rgsw_lhs_out.size()) + | RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw_lhs_in.size()) + | RGSWCt::prod_by_rgsw_scratch_space( + &module, + ct_rgsw_lhs_out.size(), + ct_rgsw_lhs_in.size(), + ct_rgsw_rhs.size(), + ), + ); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + ct_rgsw_rhs.encrypt_sk( + &module, + &pt_rgsw_rhs, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rgsw_lhs_in.encrypt_sk( + &module, + &pt_rgsw_lhs, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rgsw_lhs_out.prod_by_rgsw(&module, &ct_rgsw_lhs_in, &ct_rgsw_rhs, scratch.borrow()); + + let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rgsw_lhs_out); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rgsw_lhs_out); + let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_lhs_out.size()); + let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_lhs_out.size()); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rgsw_lhs_out); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_rgsw_lhs, 0); + + (0..ct_rgsw_lhs_out.cols()).for_each(|col_j| { + (0..ct_rgsw_lhs_out.rows()).for_each(|row_i| { + module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw_lhs, 0); + + if col_j == 1 { + module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0); + module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); + } + + ct_rgsw_lhs_out.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); + ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); + + let noise_have: f64 = pt.data.std(0, log_base2k).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_rgsw_product( + module.n() as f64, + log_base2k, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + log_k_rgsw_lhs_in, + log_k_rgsw_rhs, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "have: {} want: {}", + noise_have, + noise_want + ); + + pt_want.data.zero(); + }); + }); + + module.free(); +} + +#[test] +fn from_prod_by_rgsw_inplace() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_rgsw_rhs: usize = 60; + let log_k_rgsw_lhs: usize = 45; + let rows: usize = (log_k_rgsw_lhs + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_rgsw_rhs: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_rgsw_rhs, rows); + let mut ct_rgsw_lhs: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_rgsw_lhs, rows); + let mut pt_rgsw_lhs: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_rgsw_rhs: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_rgsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs); + + let k: usize = 1; + + pt_rgsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_rgsw_rhs.size()) + | RLWECtDft::decrypt_scratch_space(&module, ct_rgsw_lhs.size()) + | RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw_lhs.size()) + | RGSWCt::prod_by_rgsw_inplace_scratch_space(&module, ct_rgsw_lhs.size(), ct_rgsw_rhs.size()), + ); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + ct_rgsw_rhs.encrypt_sk( + &module, + &pt_rgsw_rhs, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rgsw_lhs.encrypt_sk( + &module, + &pt_rgsw_lhs, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rgsw_lhs.prod_by_rgsw_inplace(&module, &ct_rgsw_rhs, scratch.borrow()); + + let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rgsw_lhs); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rgsw_lhs); + let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_lhs.size()); + let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_lhs.size()); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rgsw_lhs); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_rgsw_lhs, 0); + + (0..ct_rgsw_lhs.cols()).for_each(|col_j| { + (0..ct_rgsw_lhs.rows()).for_each(|row_i| { + module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw_lhs, 0); + + if col_j == 1 { + module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0); + module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); + } + + ct_rgsw_lhs.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); + ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); + + let noise_have: f64 = pt.data.std(0, log_base2k).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_rgsw_product( + module.n() as f64, + log_base2k, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + log_k_rgsw_lhs, + log_k_rgsw_rhs, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "have: {} want: {}", + noise_have, + noise_want + ); + + pt_want.data.zero(); + }); + }); + + module.free(); +} + +pub(crate) fn noise_rgsw_product( n: f64, log_base2k: usize, var_xs: f64, diff --git a/core/src/test_fft64/rlwe.rs b/core/src/test_fft64/rlwe.rs index acc10a1..a2fabb9 100644 --- a/core/src/test_fft64/rlwe.rs +++ b/core/src/test_fft64/rlwe.rs @@ -1,621 +1,618 @@ -#[cfg(test)] -mod tests_rlwe { - use base2k::{ - Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, - ZnxViewMut, ZnxZero, - }; - use itertools::izip; - use sampling::source::Source; +use base2k::{ + Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, + ZnxViewMut, ZnxZero, +}; +use itertools::izip; +use sampling::source::Source; - use crate::{ - elem::{FromProdBy, FromProdByScratchSpace, Infos, ProdBy, ProdByScratchSpace}, - grlwe::GRLWECt, - keys::{PublicKey, SecretKey, SecretKeyDft}, - rgsw::RGSWCt, - rlwe::{RLWECt, RLWECtDft, RLWEPt}, - test_fft64::{grlwe::noise_grlwe_rlwe_product, rgsw::noise_rgsw_rlwe_product}, - }; +use crate::{ + elem::{Infos, ProdInplace, ProdInplaceScratchSpace, ProdScratchSpace, Product}, + grlwe::GRLWECt, + keys::{PublicKey, SecretKey, SecretKeyDft}, + rgsw::RGSWCt, + rlwe::{RLWECt, RLWECtDft, RLWEPt}, + test_fft64::{grlwe::noise_grlwe_rlwe_product, rgsw::noise_rgsw_product}, +}; - #[test] - fn encrypt_sk() { - let module: Module = Module::::new(32); - let log_base2k: usize = 8; - let log_k_ct: usize = 54; - let log_k_pt: usize = 30; +#[test] +fn encrypt_sk() { + let module: Module = Module::::new(32); + let log_base2k: usize = 8; + let log_k_ct: usize = 54; + let log_k_pt: usize = 30; - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; - let mut ct: RLWECt> = RLWECt::new(&module, log_base2k, log_k_ct); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_pt); + let mut ct: RLWECt> = RLWECt::new(&module, log_base2k, log_k_ct); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_pt); - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::new( - RLWECt::encrypt_sk_scratch_space(&module, ct.size()) | RLWECt::decrypt_scratch_space(&module, ct.size()), - ); + let mut scratch: ScratchOwned = ScratchOwned::new( + RLWECt::encrypt_sk_scratch_space(&module, ct.size()) | RLWECt::decrypt_scratch_space(&module, ct.size()), + ); - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); - let mut data_want: Vec = vec![0i64; module.n()]; + let mut data_want: Vec = vec![0i64; module.n()]; - data_want - .iter_mut() - .for_each(|x| *x = source_xa.next_i64() & 0xFF); + data_want + .iter_mut() + .for_each(|x| *x = source_xa.next_i64() & 0xFF); - pt.data - .encode_vec_i64(0, log_base2k, log_k_pt, &data_want, 10); + pt.data + .encode_vec_i64(0, log_base2k, log_k_pt, &data_want, 10); - ct.encrypt_sk( - &module, - Some(&pt), - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); + ct.encrypt_sk( + &module, + Some(&pt), + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); - pt.data.zero(); + pt.data.zero(); - ct.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + ct.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - let mut data_have: Vec = vec![0i64; module.n()]; + let mut data_have: Vec = vec![0i64; module.n()]; - pt.data - .decode_vec_i64(0, log_base2k, pt.size() * log_base2k, &mut data_have); - - // TODO: properly assert the decryption noise through std(dec(ct) - pt) - let scale: f64 = (1 << (pt.size() * log_base2k - log_k_pt)) as f64; - izip!(data_want.iter(), data_have.iter()).for_each(|(a, b)| { - let b_scaled = (*b as f64) / scale; - assert!( - (*a as f64 - b_scaled).abs() < 0.1, - "{} {}", - *a as f64, - b_scaled - ) - }); - - module.free(); - } - - #[test] - fn encrypt_zero_sk() { - let module: Module = Module::::new(1024); - let log_base2k: usize = 8; - let log_k_ct: usize = 55; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([1u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); - - let mut ct_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct); - - let mut scratch: ScratchOwned = ScratchOwned::new( - RLWECtDft::decrypt_scratch_space(&module, ct_dft.size()) - | RLWECtDft::encrypt_zero_sk_scratch_space(&module, ct_dft.size()), - ); - - ct_dft.encrypt_zero_sk( - &module, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - ct_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - - assert!((sigma - pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2()) <= 0.2); - module.free(); - } - - #[test] - fn encrypt_pk() { - let module: Module = Module::::new(32); - let log_base2k: usize = 8; - let log_k_ct: usize = 54; - let log_k_pk: usize = 64; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct: RLWECt> = RLWECt::new(&module, log_base2k, log_k_ct); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - let mut source_xu: Source = Source::new([0u8; 32]); - - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); - - let mut pk: PublicKey, FFT64> = PublicKey::new(&module, log_base2k, log_k_pk); - pk.generate( - &module, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - ); - - let mut scratch: ScratchOwned = ScratchOwned::new( - RLWECt::encrypt_sk_scratch_space(&module, ct.size()) - | RLWECt::decrypt_scratch_space(&module, ct.size()) - | RLWECt::encrypt_pk_scratch_space(&module, pk.size()), - ); - - let mut data_want: Vec = vec![0i64; module.n()]; - - data_want - .iter_mut() - .for_each(|x| *x = source_xa.next_i64() & 0); - - pt_want - .data - .encode_vec_i64(0, log_base2k, log_k_ct, &data_want, 10); - - ct.encrypt_pk( - &module, - Some(&pt_want), - &pk, - &mut source_xu, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); - - ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_want, 0, &pt_have, 0); - - assert!(((1.0f64 / 12.0).sqrt() - pt_want.data.std(0, log_base2k) * (log_k_ct as f64).exp2()).abs() < 0.2); - - module.free(); - } - - #[test] - fn from_prod_by_grlwe() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe_in: usize = 45; - let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) - | RLWECt::from_prod_by_grlwe_scratch_space( - &module, - ct_rlwe_out.size(), - ct_rlwe_in.size(), - ct_grlwe.size(), - ), - ); - - let mut sk0: SecretKey> = SecretKey::new(&module); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk0_dft.dft(&module, &sk0); - - let mut sk1: SecretKey> = SecretKey::new(&module); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk1_dft.dft(&module, &sk1); - - ct_grlwe.encrypt_sk( - &module, - &sk0.data, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe_in.encrypt_sk( - &module, - Some(&pt_want), - &sk0_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe_out.from_prod_by_grlwe(&module, &ct_rlwe_in, &ct_grlwe, scratch.borrow()); - - ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - log_k_rlwe_in, - log_k_grlwe, - ); + pt.data + .decode_vec_i64(0, log_base2k, pt.size() * log_base2k, &mut data_have); + // TODO: properly assert the decryption noise through std(dec(ct) - pt) + let scale: f64 = (1 << (pt.size() * log_base2k - log_k_pt)) as f64; + izip!(data_want.iter(), data_have.iter()).for_each(|(a, b)| { + let b_scaled = (*b as f64) / scale; assert!( - (noise_have - noise_want).abs() <= 0.1, + (*a as f64 - b_scaled).abs() < 0.1, "{} {}", - noise_have, - noise_want - ); + *a as f64, + b_scaled + ) + }); - module.free(); - } - - #[test] - fn prod_grlwe() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe: usize = 45; - let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) - | RLWECt::prod_by_grlwe_scratch_space(&module, ct_rlwe.size(), ct_grlwe.size()), - ); - - let mut sk0: SecretKey> = SecretKey::new(&module); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk0_dft.dft(&module, &sk0); - - let mut sk1: SecretKey> = SecretKey::new(&module); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk1_dft.dft(&module, &sk1); - - ct_grlwe.encrypt_sk( - &module, - &sk0.data, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe.encrypt_sk( - &module, - Some(&pt_want), - &sk0_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe.prod_by_grlwe(&module, &ct_grlwe, scratch.borrow()); - - ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - log_k_rlwe, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - - module.free(); - } - - #[test] - fn from_prod_by_rgsw() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe_in: usize = 45; - let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - pt_want.to_mut().at_mut(0, 0)[1] = 1; - - let k: usize = 1; - - pt_rgsw.raw_mut()[k] = 1; // X^{k} - - let mut scratch: ScratchOwned = ScratchOwned::new( - RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) - | RLWECt::from_prod_by_rgsw_scratch_space( - &module, - ct_rlwe_out.size(), - ct_rlwe_in.size(), - ct_rgsw.size(), - ), - ); - - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); - - ct_rgsw.encrypt_sk( - &module, - &pt_rgsw, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe_in.encrypt_sk( - &module, - Some(&pt_want), - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe_out.from_prod_by_rgsw(&module, &ct_rlwe_in, &ct_rgsw, scratch.borrow()); - - ct_rlwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_rgsw_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - log_k_rlwe_in, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - - module.free(); - } - - #[test] - fn prod_by_rgsw() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe_in: usize = 45; - let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - pt_want.to_mut().at_mut(0, 0)[1] = 1; - - let k: usize = 1; - - pt_rgsw.raw_mut()[k] = 1; // X^{k} - - let mut scratch: ScratchOwned = ScratchOwned::new( - RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) - | RLWECt::prod_by_rgsw_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), - ); - - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); - - ct_rgsw.encrypt_sk( - &module, - &pt_rgsw, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe.encrypt_sk( - &module, - Some(&pt_want), - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe.prod_by_rgsw(&module, &ct_rgsw, scratch.borrow()); - - ct_rlwe.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_rgsw_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - log_k_rlwe_in, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - - module.free(); - } + module.free(); +} + +#[test] +fn encrypt_zero_sk() { + let module: Module = Module::::new(1024); + let log_base2k: usize = 8; + let log_k_ct: usize = 55; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([1u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + let mut ct_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct); + + let mut scratch: ScratchOwned = ScratchOwned::new( + RLWECtDft::decrypt_scratch_space(&module, ct_dft.size()) + | RLWECtDft::encrypt_zero_sk_scratch_space(&module, ct_dft.size()), + ); + + ct_dft.encrypt_zero_sk( + &module, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + ct_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + + assert!((sigma - pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2()) <= 0.2); + module.free(); +} + +#[test] +fn encrypt_pk() { + let module: Module = Module::::new(32); + let log_base2k: usize = 8; + let log_k_ct: usize = 54; + let log_k_pk: usize = 64; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct: RLWECt> = RLWECt::new(&module, log_base2k, log_k_ct); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + let mut source_xu: Source = Source::new([0u8; 32]); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + let mut pk: PublicKey, FFT64> = PublicKey::new(&module, log_base2k, log_k_pk); + pk.generate( + &module, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + ); + + let mut scratch: ScratchOwned = ScratchOwned::new( + RLWECt::encrypt_sk_scratch_space(&module, ct.size()) + | RLWECt::decrypt_scratch_space(&module, ct.size()) + | RLWECt::encrypt_pk_scratch_space(&module, pk.size()), + ); + + let mut data_want: Vec = vec![0i64; module.n()]; + + data_want + .iter_mut() + .for_each(|x| *x = source_xa.next_i64() & 0); + + pt_want + .data + .encode_vec_i64(0, log_base2k, log_k_ct, &data_want, 10); + + ct.encrypt_pk( + &module, + Some(&pt_want), + &pk, + &mut source_xu, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + + ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_want, 0, &pt_have, 0); + + assert!(((1.0f64 / 12.0).sqrt() - pt_want.data.std(0, log_base2k) * (log_k_ct as f64).exp2()).abs() < 0.2); + + module.free(); +} + +#[test] +fn prod_by_grlwe() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe_in: usize = 45; + let log_k_rlwe_out: usize = 60; + let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | RLWECt::prod_by_grlwe_scratch_space( + &module, + ct_rlwe_out.size(), + ct_rlwe_in.size(), + ct_grlwe.size(), + ), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + ct_grlwe.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_in.encrypt_sk( + &module, + Some(&pt_want), + &sk0_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_out.prod_by_grlwe(&module, &ct_rlwe_in, &ct_grlwe, scratch.borrow()); + + ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_rlwe_in, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); +} + +#[test] +fn prod_by_grlwe_inplace() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe: usize = 45; + let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | RLWECt::prod_by_grlwe_inplace_scratch_space(&module, ct_rlwe.size(), ct_grlwe.size()), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + ct_grlwe.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.encrypt_sk( + &module, + Some(&pt_want), + &sk0_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.prod_by_grlwe_inplace(&module, &ct_grlwe, scratch.borrow()); + + ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_rlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); +} + +#[test] +fn prod_by_rgsw() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe_in: usize = 45; + let log_k_rlwe_out: usize = 60; + let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + pt_want.to_mut().at_mut(0, 0)[1] = 1; + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | RLWECt::prod_by_grlwe_scratch_space( + &module, + ct_rlwe_out.size(), + ct_rlwe_in.size(), + ct_rgsw.size(), + ), + ); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_in.encrypt_sk( + &module, + Some(&pt_want), + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_out.prod_by_rgsw(&module, &ct_rlwe_in, &ct_rgsw, scratch.borrow()); + + ct_rlwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_rgsw_product( + module.n() as f64, + log_base2k, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + log_k_rlwe_in, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); +} + +#[test] +fn prod_by_rgsw_inplace() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe_in: usize = 45; + let log_k_rlwe_out: usize = 60; + let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + pt_want.to_mut().at_mut(0, 0)[1] = 1; + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | RLWECt::prod_by_rgsw_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), + ); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.encrypt_sk( + &module, + Some(&pt_want), + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.prod_by_rgsw_inplace(&module, &ct_rgsw, scratch.borrow()); + + ct_rlwe.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_rgsw_product( + module.n() as f64, + log_base2k, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + log_k_rlwe_in, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); } diff --git a/core/src/test_fft64/rlwe_dft.rs b/core/src/test_fft64/rlwe_dft.rs index 448bdfb..fe71a09 100644 --- a/core/src/test_fft64/rlwe_dft.rs +++ b/core/src/test_fft64/rlwe_dft.rs @@ -1,448 +1,443 @@ -#[cfg(test)] -mod tests { - use crate::{ - elem::{FromProdBy, FromProdByScratchSpace, Infos, ProdBy, ProdByScratchSpace}, - grlwe::GRLWECt, - keys::{SecretKey, SecretKeyDft}, - rgsw::RGSWCt, - rlwe::{RLWECt, RLWECtDft, RLWEPt}, - test_fft64::{grlwe::noise_grlwe_rlwe_product, rgsw::noise_rgsw_rlwe_product}, - }; - use base2k::{ - FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, ZnxViewMut, - }; - use sampling::source::Source; - - #[test] - fn from_prod_by_grlwe() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe_in: usize = 45; - let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_in_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); - let mut ct_rlwe_out_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_out); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) - | RLWECtDft::from_prod_by_grlwe_scratch_space( - &module, - ct_rlwe_out.size(), - ct_rlwe_in.size(), - ct_grlwe.size(), - ), - ); - - let mut sk0: SecretKey> = SecretKey::new(&module); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk0_dft.dft(&module, &sk0); - - let mut sk1: SecretKey> = SecretKey::new(&module); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk1_dft.dft(&module, &sk1); - - ct_grlwe.encrypt_sk( - &module, - &sk0.data, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe_in.encrypt_sk( - &module, - Some(&pt_want), - &sk0_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe_in.dft(&module, &mut ct_rlwe_in_dft); - ct_rlwe_out_dft.from_prod_by_grlwe(&module, &ct_rlwe_in_dft, &ct_grlwe, scratch.borrow()); - ct_rlwe_out_dft.idft(&module, &mut ct_rlwe_out, scratch.borrow()); - - ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - log_k_rlwe_in, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - - module.free(); - } - - #[test] - fn prod_by_grlwe() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe: usize = 45; - let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe); - let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) - | RLWECtDft::prod_by_grlwe_scratch_space(&module, ct_rlwe_dft.size(), ct_grlwe.size()), - ); - - let mut sk0: SecretKey> = SecretKey::new(&module); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk0_dft.dft(&module, &sk0); - - let mut sk1: SecretKey> = SecretKey::new(&module); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk1_dft.dft(&module, &sk1); - - ct_grlwe.encrypt_sk( - &module, - &sk0.data, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe.encrypt_sk( - &module, - Some(&pt_want), - &sk0_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe.dft(&module, &mut ct_rlwe_dft); - ct_rlwe_dft.prod_by_grlwe(&module, &ct_grlwe, scratch.borrow()); - ct_rlwe_dft.idft(&module, &mut ct_rlwe, scratch.borrow()); - - ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - log_k_rlwe, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - - module.free(); - } - - #[test] - fn from_prod_by_rgsw() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe_in: usize = 45; - let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); - let mut ct_rlwe_dft_in: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_dft_out: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_out); - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - pt_want.to_mut().at_mut(0, 0)[1] = 1; - - let k: usize = 1; - - pt_rgsw.raw_mut()[k] = 1; // X^{k} - - let mut scratch: ScratchOwned = ScratchOwned::new( - RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) - | RLWECt::from_prod_by_rgsw_scratch_space( - &module, - ct_rlwe_out.size(), - ct_rlwe_in.size(), - ct_rgsw.size(), - ), - ); - - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); - - ct_rgsw.encrypt_sk( - &module, - &pt_rgsw, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe_in.encrypt_sk( - &module, - Some(&pt_want), - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe_in.dft(&module, &mut ct_rlwe_dft_in); - ct_rlwe_dft_out.from_prod_by_rgsw(&module, &ct_rlwe_dft_in, &ct_rgsw, scratch.borrow()); - ct_rlwe_dft_out.idft(&module, &mut ct_rlwe_out, scratch.borrow()); - - ct_rlwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_rgsw_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - log_k_rlwe_in, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - - module.free(); - } - - #[test] - fn prod_by_rgsw() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe_in: usize = 45; - let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - pt_want.to_mut().at_mut(0, 0)[1] = 1; - - let k: usize = 1; - - pt_rgsw.raw_mut()[k] = 1; // X^{k} - - let mut scratch: ScratchOwned = ScratchOwned::new( - RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) - | RLWECt::prod_by_rgsw_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), - ); - - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); - - ct_rgsw.encrypt_sk( - &module, - &pt_rgsw, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe.encrypt_sk( - &module, - Some(&pt_want), - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - ct_rlwe.dft(&module, &mut ct_rlwe_dft); - ct_rlwe_dft.prod_by_rgsw(&module, &ct_rgsw, scratch.borrow()); - ct_rlwe_dft.idft(&module, &mut ct_rlwe, scratch.borrow()); - - ct_rlwe.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_rgsw_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - log_k_rlwe_in, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - - module.free(); - } +use crate::{ + elem::{Infos, ProdInplace, ProdInplaceScratchSpace, ProdScratchSpace, Product}, + grlwe::GRLWECt, + keys::{SecretKey, SecretKeyDft}, + rgsw::RGSWCt, + rlwe::{RLWECt, RLWECtDft, RLWEPt}, + test_fft64::{grlwe::noise_grlwe_rlwe_product, rgsw::noise_rgsw_product}, +}; +use base2k::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, ZnxViewMut}; +use sampling::source::Source; + +#[test] +fn by_grlwe_inplace() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe_in: usize = 45; + let log_k_rlwe_out: usize = 60; + let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_in_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); + let mut ct_rlwe_out_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | RLWECtDft::prod_by_grlwe_scratch_space( + &module, + ct_rlwe_out.size(), + ct_rlwe_in.size(), + ct_grlwe.size(), + ), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + ct_grlwe.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_in.encrypt_sk( + &module, + Some(&pt_want), + &sk0_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_in.dft(&module, &mut ct_rlwe_in_dft); + ct_rlwe_out_dft.prod_by_grlwe(&module, &ct_rlwe_in_dft, &ct_grlwe, scratch.borrow()); + ct_rlwe_out_dft.idft(&module, &mut ct_rlwe_out, scratch.borrow()); + + ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_rlwe_in, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); +} + +#[test] +fn prod_by_grlwe_inplace() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe: usize = 45; + let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe); + let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | RLWECtDft::prod_by_grlwe_inplace_scratch_space(&module, ct_rlwe_dft.size(), ct_grlwe.size()), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + ct_grlwe.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.encrypt_sk( + &module, + Some(&pt_want), + &sk0_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.dft(&module, &mut ct_rlwe_dft); + ct_rlwe_dft.prod_by_grlwe_inplace(&module, &ct_grlwe, scratch.borrow()); + ct_rlwe_dft.idft(&module, &mut ct_rlwe, scratch.borrow()); + + ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_rlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); +} + +#[test] +fn prod_by_rgsw() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe_in: usize = 45; + let log_k_rlwe_out: usize = 60; + let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); + let mut ct_rlwe_dft_in: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_dft_out: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + pt_want.to_mut().at_mut(0, 0)[1] = 1; + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | RLWECt::prod_by_rgsw_scratch_space( + &module, + ct_rlwe_out.size(), + ct_rlwe_in.size(), + ct_rgsw.size(), + ), + ); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_in.encrypt_sk( + &module, + Some(&pt_want), + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_in.dft(&module, &mut ct_rlwe_dft_in); + ct_rlwe_dft_out.prod_by_rgsw(&module, &ct_rlwe_dft_in, &ct_rgsw, scratch.borrow()); + ct_rlwe_dft_out.idft(&module, &mut ct_rlwe_out, scratch.borrow()); + + ct_rlwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_rgsw_product( + module.n() as f64, + log_base2k, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + log_k_rlwe_in, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); +} + +#[test] +fn prod_by_rgsw_inplace() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe_in: usize = 45; + let log_k_rlwe_out: usize = 60; + let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + pt_want.to_mut().at_mut(0, 0)[1] = 1; + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | RLWECt::prod_by_rgsw_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), + ); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.encrypt_sk( + &module, + Some(&pt_want), + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.dft(&module, &mut ct_rlwe_dft); + ct_rlwe_dft.prod_by_rgsw_inplace(&module, &ct_rgsw, scratch.borrow()); + ct_rlwe_dft.idft(&module, &mut ct_rlwe, scratch.borrow()); + + ct_rlwe.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_rgsw_product( + module.n() as f64, + log_base2k, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + log_k_rlwe_in, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); }