From 9913040aa1b2b3a92041fa72c426816b330a9b32 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Fri, 9 May 2025 10:39:00 +0200 Subject: [PATCH] Added grlwe ops + tests --- base2k/src/scalar_znx.rs | 8 +- base2k/src/vec_znx.rs | 34 ++ base2k/src/vec_znx_big.rs | 35 +- base2k/src/vec_znx_big_ops.rs | 9 +- base2k/src/vec_znx_dft.rs | 37 +- base2k/src/vec_znx_dft_ops.rs | 27 +- base2k/src/znx_base.rs | 8 +- rlwe/src/elem_grlwe.rs | 265 ++++++++--- rlwe/src/elem_rgsw.rs | 94 +--- rlwe/src/elem_rlwe.rs | 282 +++--------- rlwe/src/keys.rs | 11 +- rlwe/src/lib.rs | 1 + rlwe/src/test_fft64/elem_grlwe.rs | 722 ++++++++++++++++++++++++++++++ rlwe/src/test_fft64/elem_rgsw.rs | 88 ++++ rlwe/src/test_fft64/elem_rlwe.rs | 196 ++++++++ rlwe/src/test_fft64/mod.rs | 3 + 16 files changed, 1435 insertions(+), 385 deletions(-) create mode 100644 rlwe/src/test_fft64/elem_grlwe.rs create mode 100644 rlwe/src/test_fft64/elem_rgsw.rs create mode 100644 rlwe/src/test_fft64/elem_rlwe.rs create mode 100644 rlwe/src/test_fft64/mod.rs diff --git a/base2k/src/scalar_znx.rs b/base2k/src/scalar_znx.rs index 28ee38a..108ba3f 100644 --- a/base2k/src/scalar_znx.rs +++ b/base2k/src/scalar_znx.rs @@ -1,5 +1,7 @@ use crate::znx_base::ZnxInfos; -use crate::{alloc_aligned, Backend, DataView, DataViewMut, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxSliceSize, ZnxView, ZnxViewMut}; +use crate::{ + Backend, DataView, DataViewMut, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxSliceSize, ZnxView, ZnxViewMut, alloc_aligned, +}; use rand::seq::SliceRandom; use rand_core::RngCore; use rand_distr::{Distribution, weighted::WeightedIndex}; @@ -144,7 +146,7 @@ impl ScalarZnxToMut for ScalarZnx> { } } -impl VecZnxToMut for ScalarZnx>{ +impl VecZnxToMut for ScalarZnx> { fn to_mut(&mut self) -> VecZnx<&mut [u8]> { VecZnx { data: self.data.as_mut_slice(), @@ -165,7 +167,7 @@ impl ScalarZnxToRef for ScalarZnx> { } } -impl VecZnxToRef for ScalarZnx>{ +impl VecZnxToRef for ScalarZnx> { fn to_ref(&self) -> VecZnx<&[u8]> { VecZnx { data: self.data.as_slice(), diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index 31459d4..b945b2c 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -1,6 +1,7 @@ use crate::DataView; use crate::DataViewMut; use crate::ZnxSliceSize; +use crate::ZnxZero; use crate::alloc_aligned; use crate::assert_alignement; use crate::cast_mut; @@ -182,6 +183,39 @@ fn normalize + AsRef<[u8]>>(log_base2k: usize, a: &mut VecZnx, } } +impl VecZnx +where + VecZnx: VecZnxToMut + ZnxInfos, +{ + /// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self]. + pub fn extract_column(&mut self, self_col: usize, a: &VecZnx, a_col: usize) + where + VecZnx: VecZnxToRef + ZnxInfos, + { + #[cfg(debug_assertions)] + { + assert!(self_col < self.cols()); + assert!(a_col < a.cols()); + } + + let min_size: usize = self.size.min(a.size()); + let max_size: usize = self.size; + + let mut self_mut: VecZnx<&mut [u8]> = self.to_mut(); + let a_ref: VecZnx<&[u8]> = a.to_ref(); + + (0..min_size).for_each(|i: usize| { + self_mut + .at_mut(self_col, i) + .copy_from_slice(a_ref.at(a_col, i)); + }); + + (min_size..max_size).for_each(|i| { + self_mut.zero_at(self_col, i); + }); + } +} + impl> fmt::Display for VecZnx { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!( diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index d8c1bdd..8b3223b 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,6 +1,6 @@ use crate::ffi::vec_znx_big; use crate::znx_base::{ZnxInfos, ZnxView}; -use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, alloc_aligned}; +use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxViewMut, ZnxZero, alloc_aligned}; use std::fmt; use std::marker::PhantomData; @@ -94,6 +94,39 @@ impl VecZnxBig { } } +impl VecZnxBig +where + VecZnxBig: VecZnxBigToMut + ZnxInfos, +{ + /// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self]. + pub fn extract_column(&mut self, self_col: usize, a: &VecZnxBig, a_col: usize) + where + VecZnxBig: VecZnxBigToRef + ZnxInfos, + { + #[cfg(debug_assertions)] + { + assert!(self_col < self.cols()); + assert!(a_col < a.cols()); + } + + let min_size: usize = self.size.min(a.size()); + let max_size: usize = self.size; + + let mut self_mut: VecZnxBig<&mut [u8], FFT64> = self.to_mut(); + let a_ref: VecZnxBig<&[u8], FFT64> = a.to_ref(); + + (0..min_size).for_each(|i: usize| { + self_mut + .at_mut(self_col, i) + .copy_from_slice(a_ref.at(a_col, i)); + }); + + (min_size..max_size).for_each(|i| { + self_mut.zero_at(self_col, i); + }); + } +} + pub type VecZnxBigOwned = VecZnxBig, B>; pub trait VecZnxBigToRef { diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs index 809a1eb..8208c97 100644 --- a/base2k/src/vec_znx_big_ops.rs +++ b/base2k/src/vec_znx_big_ops.rs @@ -115,7 +115,9 @@ pub trait VecZnxBigOps { A: VecZnxToRef; /// Negates `a` inplace. - fn vec_znx_big_negate_inplace(&self, a: &mut A, a_col: usize) where A: VecZnxBigToMut; + fn vec_znx_big_negate_inplace(&self, a: &mut A, a_col: usize) + where + A: VecZnxBigToMut; /// Normalizes `a` and stores the result on `b`. /// @@ -506,7 +508,10 @@ impl VecZnxBigOps for Module { } } - fn vec_znx_big_negate_inplace(&self, a: &mut A, res_col: usize) where A: VecZnxBigToMut { + fn vec_znx_big_negate_inplace(&self, a: &mut A, res_col: usize) + where + A: VecZnxBigToMut, + { let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut(); #[cfg(debug_assertions)] { diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index 61e1be5..b4bc973 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -2,7 +2,9 @@ use std::marker::PhantomData; use crate::ffi::vec_znx_dft; use crate::znx_base::ZnxInfos; -use crate::{Backend, DataView, DataViewMut, FFT64, Module, VecZnxBig, ZnxSliceSize, ZnxView, alloc_aligned}; +use crate::{ + Backend, DataView, DataViewMut, FFT64, Module, VecZnxBig, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, alloc_aligned, +}; use std::fmt; pub struct VecZnxDft { @@ -89,6 +91,39 @@ impl>, B: Backend> VecZnxDft { } } +impl VecZnxDft +where + VecZnxDft: VecZnxDftToMut + ZnxInfos, +{ + /// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self]. + pub fn extract_column(&mut self, self_col: usize, a: &VecZnxDft, a_col: usize) + where + VecZnxDft: VecZnxDftToRef + ZnxInfos, + { + #[cfg(debug_assertions)] + { + assert!(self_col < self.cols()); + assert!(a_col < a.cols()); + } + + let min_size: usize = self.size.min(a.size()); + let max_size: usize = self.size; + + let mut self_mut: VecZnxDft<&mut [u8], FFT64> = self.to_mut(); + let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); + + (0..min_size).for_each(|i: usize| { + self_mut + .at_mut(self_col, i) + .copy_from_slice(a_ref.at(a_col, i)); + }); + + (min_size..max_size).for_each(|i| { + self_mut.zero_at(self_col, i); + }); + } +} + pub type VecZnxDftOwned = VecZnxDft, B>; impl VecZnxDft { diff --git a/base2k/src/vec_znx_dft_ops.rs b/base2k/src/vec_znx_dft_ops.rs index cf06cc2..282ef4d 100644 --- a/base2k/src/vec_znx_dft_ops.rs +++ b/base2k/src/vec_znx_dft_ops.rs @@ -47,7 +47,9 @@ pub trait VecZnxDftOps { where R: VecZnxBigToMut, A: VecZnxDftToMut; - fn vec_znx_idft_consume(&self, a: VecZnxDft, a_cols: usize) -> VecZnxBig + + /// Consumes a to return IDFT(a) in big coeff space. + fn vec_znx_idft_consume(&self, a: VecZnxDft) -> VecZnxBig where VecZnxDft: VecZnxDftToMut; @@ -103,25 +105,28 @@ impl VecZnxDftOps for Module { } } - fn vec_znx_idft_consume(&self, mut a: VecZnxDft, a_col: usize) -> VecZnxBig + fn vec_znx_idft_consume(&self, mut a: VecZnxDft) -> VecZnxBig where VecZnxDft: VecZnxDftToMut, { let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut(); unsafe { + // Rev col and rows because ZnxDft.sl() >= ZnxBig.sl() (0..a_mut.size()).for_each(|j| { - vec_znx_dft::vec_znx_idft_tmp_a( - self.ptr, - a_mut.at_mut_ptr(a_col, j) as *mut vec_znx_big::vec_znx_big_t, - 1 as u64, - a_mut.at_mut_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t, - 1 as u64, - ) + (0..a_mut.cols()).for_each(|i| { + vec_znx_dft::vec_znx_idft_tmp_a( + self.ptr, + a_mut.at_mut_ptr(i, j) as *mut vec_znx_big::vec_znx_big_t, + 1 as u64, + a_mut.at_mut_ptr(i, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1 as u64, + ) + }); }); - - a.into_big() } + + a.into_big() } fn vec_znx_idft_tmp_bytes(&self) -> usize { diff --git a/base2k/src/znx_base.rs b/base2k/src/znx_base.rs index 94da450..a168e18 100644 --- a/base2k/src/znx_base.rs +++ b/base2k/src/znx_base.rs @@ -101,25 +101,25 @@ pub trait ZnxViewMut: ZnxView + DataViewMut> { //(Jay)Note: Can't provide blanket impl. of ZnxView because Scalar is not known impl ZnxViewMut for T where T: ZnxView + DataViewMut> {} -pub trait ZnxZero: ZnxViewMut +pub trait ZnxZero: ZnxViewMut + ZnxSliceSize where Self: Sized, { fn zero(&mut self) { unsafe { - std::ptr::write_bytes(self.as_mut_ptr(), 0, self.n() * self.poly_count()); + std::ptr::write_bytes(self.as_mut_ptr(), 0, self.sl() * self.poly_count()); } } fn zero_at(&mut self, i: usize, j: usize) { unsafe { - std::ptr::write_bytes(self.at_mut_ptr(i, j), 0, self.n()); + std::ptr::write_bytes(self.at_mut_ptr(i, j), 0, self.sl()); } } } // Blanket implementations -impl ZnxZero for T where T: ZnxViewMut {} +impl ZnxZero for T where T: ZnxViewMut + ZnxSliceSize {} // WARNING should not work for mat_znx_dft but it does use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub}; diff --git a/rlwe/src/elem_grlwe.rs b/rlwe/src/elem_grlwe.rs index a460ec4..b865c1e 100644 --- a/rlwe/src/elem_grlwe.rs +++ b/rlwe/src/elem_grlwe.rs @@ -1,6 +1,7 @@ use base2k::{ - Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, - ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxOps, + Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, + ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigOps, VecZnxBigScratch, + VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero, }; use sampling::source::Source; @@ -32,11 +33,23 @@ impl GRLWECt where MatZnxDft: MatZnxDftToRef, { - pub fn get_row(&self, module: &Module, i: usize, res: &mut RLWECtDft) + pub fn get_row(&self, module: &Module, row_i: usize, res: &mut RLWECtDft) where - VecZnxDft: VecZnxDftToMut, + VecZnxDft: VecZnxDftToMut, { - module.vmp_extract_row(res, self, i, 0); + module.vmp_extract_row(res, self, row_i, 0); + } +} + +impl GRLWECt +where + MatZnxDft: MatZnxDftToMut, +{ + pub fn set_row(&mut self, module: &Module, row_i: usize, a: &RLWECtDft) + where + VecZnxDft: VecZnxDftToRef, + { + module.vmp_prepare_row(self, row_i, 0, a); } } @@ -75,16 +88,42 @@ where } impl GRLWECt, FFT64> { - pub fn encrypt_sk_scratch_bytes(module: &Module, size: usize) -> usize { - RLWECt::encrypt_sk_scratch_bytes(module, size) + pub fn encrypt_sk_scratch_space(module: &Module, size: usize) -> usize { + RLWECt::encrypt_sk_scratch_space(module, size) + module.bytes_of_vec_znx(2, size) + module.bytes_of_vec_znx(1, size) + module.bytes_of_vec_znx_dft(2, size) } - // pub fn encrypt_pk_scratch_bytes(module: &Module, pk_size: usize) -> usize { - // RLWECt::encrypt_pk_scratch_bytes(module, pk_size) - // } + pub fn mul_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { + module.bytes_of_vec_znx_dft(2, grlwe_size) + + (module.vec_znx_big_normalize_tmp_bytes() + | (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 1, 2, grlwe_size) + + module.bytes_of_vec_znx_dft(1, a_size))) + } + + pub fn mul_rlwe_inplace_scratch_space(module: &Module, res_size: usize, grlwe_size: usize) -> usize { + Self::mul_rlwe_scratch_space(module, res_size, res_size, grlwe_size) + } + + pub fn mul_rlwe_dft_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { + (Self::mul_rlwe_scratch_space(module, res_size, a_size, grlwe_size) | module.vec_znx_idft_tmp_bytes()) + + module.bytes_of_vec_znx(2, a_size) + + module.bytes_of_vec_znx(2, res_size) + } + + pub fn mul_rlwe_dft_inplace_scratch_space(module: &Module, res_size: usize, grlwe_size: usize) -> usize { + (Self::mul_rlwe_inplace_scratch_space(module, res_size, grlwe_size) | module.vec_znx_idft_tmp_bytes()) + + module.bytes_of_vec_znx(2, res_size) + } + + pub fn mul_grlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { + Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, grlwe_size) + module.bytes_of_vec_znx_dft(2, a_size) + } + + pub fn mul_grlwe_inplace_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { + Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, grlwe_size) + module.bytes_of_vec_znx_dft(2, a_size) + } } pub fn encrypt_grlwe_sk( @@ -170,72 +209,176 @@ impl GRLWECt { module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch, ) } -} -#[cfg(test)] -mod tests { - use base2k::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps}; - use sampling::source::Source; + pub fn mul_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + VecZnx: VecZnxToMut, + VecZnx: VecZnxToRef, + { + let log_base2k: usize = self.log_base2k(); - use crate::{ - elem::Infos, - elem_rlwe::{RLWECtDft, RLWEPt}, - keys::{SecretKey, SecretKeyDft}, - }; + #[cfg(debug_assertions)] + { + assert_eq!(res.log_base2k(), log_base2k); + assert_eq!(a.log_base2k(), log_base2k); + assert_eq!(self.n(), module.n()); + assert_eq!(res.n(), module.n()); + assert_eq!(a.n(), module.n()); + } - use super::GRLWECt; + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, 2, self.size()); // Todo optimise - #[test] - fn encrypt_sk_fft64() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 8; - let log_k_ct: usize = 54; - let rows: usize = 4; + { + let (mut a1_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, a.size()); + module.vec_znx_dft(&mut a1_dft, 0, a, 1); + module.vmp_apply(&mut res_dft, &a1_dft, self, scratch2); + } - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; + let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); - let mut ct: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_ct, rows); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); - let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); + module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0); - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); + module.vec_znx_big_normalize(log_base2k, res, 0, &res_big, 0, scratch1); + module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1); + } - pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + pub fn mul_rlwe_inplace(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + VecZnx: VecZnxToMut + VecZnxToRef, + { + unsafe { + let res_ptr: *mut RLWECt = res as *mut RLWECt; // This is ok because [Self::mul_rlwe] only updates res at the end. + self.mul_rlwe(&module, &mut *res_ptr, &*res_ptr, scratch); + } + } - let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_bytes(&module, ct.size()) | RLWECtDft::decrypt_scratch_bytes(&module, ct.size()), - ); + pub fn mul_rlwe_dft( + &self, + module: &Module, + res: &mut RLWECtDft, + a: &RLWECtDft, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToRef + ZnxInfos, + { + let log_base2k: usize = self.log_base2k(); - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); + #[cfg(debug_assertions)] + { + assert_eq!(res.log_base2k(), log_base2k); + assert_eq!(self.n(), module.n()); + assert_eq!(res.n(), module.n()); + } - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); + let (a_data, scratch_1) = scratch.tmp_vec_znx(module, 2, a.size()); - ct.encrypt_sk( - &module, - &pt_scalar, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); + let mut a_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { + data: a_data, + log_base2k: a.log_base2k(), + log_k: a.log_k(), + }; - let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct, 2); + a.idft(module, &mut a_idft, scratch_1); - (0..ct.rows()).for_each(|row_i| { - ct.get_row(&module, row_i, &mut ct_rlwe_dft); - ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_scalar, 0); - let std_pt: f64 = pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); - assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); + let (res_data, scratch_2) = scratch_1.tmp_vec_znx(module, 2, res.size()); + + let mut res_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { + data: res_data, + log_base2k: res.log_base2k(), + log_k: res.log_k(), + }; + + self.mul_rlwe(module, &mut res_idft, &a_idft, scratch_2); + + module.vec_znx_dft(res, 0, &res_idft, 0); + module.vec_znx_dft(res, 1, &res_idft, 1); + } + + pub fn mul_rlwe_dft_inplace(&self, module: &Module, res: &mut RLWECtDft, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, + { + let log_base2k: usize = self.log_base2k(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.log_base2k(), log_base2k); + assert_eq!(self.n(), module.n()); + assert_eq!(res.n(), module.n()); + } + + let (res_data, scratch_1) = scratch.tmp_vec_znx(module, 2, res.size()); + + let mut res_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { + data: res_data, + log_base2k: res.log_base2k(), + log_k: res.log_k(), + }; + + res.idft(module, &mut res_idft, scratch_1); + + self.mul_rlwe_inplace(module, &mut res_idft, scratch_1); + + module.vec_znx_dft(res, 0, &res_idft, 0); + module.vec_znx_dft(res, 1, &res_idft, 1); + } + + pub fn mul_grlwe( + &self, + module: &Module, + res: &mut GRLWECt, + a: &GRLWECt, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, + MatZnxDft: MatZnxDftToRef + ZnxInfos, + { + let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, a.size()); + + let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { + data: tmp_row_data, + log_base2k: a.log_base2k(), + log_k: a.log_k(), + }; + + let min_rows: usize = res.rows().min(a.rows()); + + (0..min_rows).for_each(|row_i| { + a.get_row(module, row_i, &mut tmp_row); + self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); + res.set_row(module, row_i, &tmp_row); }); - module.free(); + tmp_row.data.zero(); + + (min_rows..res.rows()).for_each(|row_i| { + res.set_row(module, row_i, &tmp_row); + }) + } + + pub fn mul_grlwe_inplace(&self, module: &Module, res: &mut GRLWECt, scratch: &mut Scratch) + where + MatZnxDft: MatZnxDftToRef + ZnxInfos, + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef + ZnxInfos, + { + let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, res.size()); + + let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { + data: tmp_row_data, + log_base2k: res.log_base2k(), + log_k: res.log_k(), + }; + + (0..res.rows()).for_each(|row_i| { + res.get_row(module, row_i, &mut tmp_row); + self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1); + res.set_row(module, row_i, &tmp_row); + }); } } diff --git a/rlwe/src/elem_rgsw.rs b/rlwe/src/elem_rgsw.rs index 75d6583..eab378a 100644 --- a/rlwe/src/elem_rgsw.rs +++ b/rlwe/src/elem_rgsw.rs @@ -63,8 +63,8 @@ where } impl RGSWCt, FFT64> { - pub fn encrypt_sk_scratch_bytes(module: &Module, size: usize) -> usize { - RLWECt::encrypt_sk_scratch_bytes(module, size) + pub fn encrypt_sk_scratch_space(module: &Module, size: usize) -> usize { + RLWECt::encrypt_sk_scratch_space(module, size) + module.bytes_of_vec_znx(2, size) + module.bytes_of_vec_znx(1, size) + module.bytes_of_vec_znx_dft(2, size) @@ -169,93 +169,3 @@ impl RGSWCt { ) } } - -#[cfg(test)] -mod tests { - use base2k::{ - FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, - VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxZero, - }; - use sampling::source::Source; - - use crate::{ - elem::Infos, - elem_rlwe::{RLWECtDft, RLWEPt}, - keys::{SecretKey, SecretKeyDft}, - }; - - use super::RGSWCt; - - #[test] - fn encrypt_rgsw_sk_fft64() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 8; - let log_k_ct: usize = 54; - let rows: usize = 4; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_ct, rows); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); - let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); - - let mut scratch: ScratchOwned = ScratchOwned::new( - RGSWCt::encrypt_sk_scratch_bytes(&module, ct.size()) | RLWECtDft::decrypt_scratch_bytes(&module, ct.size()), - ); - - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); - - ct.encrypt_sk( - &module, - &pt_scalar, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct, 2); - let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); - let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); - - (0..ct.cols()).for_each(|col_j| { - (0..ct.rows()).for_each(|row_i| { - module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0); - - if col_j == 1 { - module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0); - module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); - module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); - } - - ct.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); - - ct_rlwe_dft.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let std_pt: f64 = pt_have.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); - assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); - - pt_want.data.zero(); - }); - }); - - module.free(); - } -} diff --git a/rlwe/src/elem_rlwe.rs b/rlwe/src/elem_rlwe.rs index 938b3c5..54cb4f9 100644 --- a/rlwe/src/elem_rlwe.rs +++ b/rlwe/src/elem_rlwe.rs @@ -1,12 +1,13 @@ use base2k::{ - AddNormal, Backend, FFT64, FillUniform, Module, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, - ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, - VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, + AddNormal, Backend, FFT64, FillUniform, MatZnxDft, MatZnxDftToRef, Module, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, + ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, + VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos, }; use sampling::source::Source; use crate::{ elem::Infos, + elem_grlwe::GRLWECt, keys::{PublicKey, SecretDistribution, SecretKeyDft}, utils::derive_size, }; @@ -18,9 +19,9 @@ pub struct RLWECt { } impl RLWECt> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize, cols: usize) -> Self { + pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { Self { - data: module.new_vec_znx(cols, derive_size(log_base2k, log_k)), + data: module.new_vec_znx(2, derive_size(log_base2k, log_k)), log_base2k: log_base2k, log_k: log_k, } @@ -61,6 +62,27 @@ where } } +impl RLWECt +where + VecZnx: VecZnxToRef, +{ + #[allow(dead_code)] + pub(crate) fn dft(&self, module: &Module, res: &mut RLWECtDft) + where + VecZnxDft: VecZnxDftToMut + ZnxInfos, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.cols(), 2); + assert_eq!(res.cols(), 2); + assert_eq!(self.log_base2k(), res.log_base2k()) + } + + module.vec_znx_dft(res, 0, self, 0); + module.vec_znx_dft(res, 1, self, 1); + } +} + pub struct RLWEPt { pub data: VecZnx, pub log_base2k: usize, @@ -118,9 +140,9 @@ pub struct RLWECtDft { } impl RLWECtDft, B> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize, cols: usize) -> Self { + pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { Self { - data: module.new_vec_znx_dft(cols, derive_size(log_base2k, log_k)), + data: module.new_vec_znx_dft(2, derive_size(log_base2k, log_k)), log_base2k: log_base2k, log_k: log_k, } @@ -161,18 +183,49 @@ where } } +impl RLWECtDft +where + VecZnxDft: VecZnxDftToRef, +{ + #[allow(dead_code)] + pub(crate) fn idft_scratch_space(module: &Module, size: usize) -> usize { + module.bytes_of_vec_znx(2, size) + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()) + } + + pub(crate) fn idft(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) + where + VecZnx: VecZnxToMut + ZnxInfos, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.cols(), 2); + assert_eq!(res.cols(), 2); + assert_eq!(self.log_base2k(), res.log_base2k()) + } + + let min_size: usize = self.size().min(res.size()); + + let (mut res_big, scratch1) = scratch.tmp_vec_znx_big(module, 2, min_size); + + module.vec_znx_idft(&mut res_big, 0, &self.data, 0, scratch1); + module.vec_znx_idft(&mut res_big, 1, &self.data, 1, scratch1); + module.vec_znx_big_normalize(self.log_base2k(), res, 0, &res_big, 0, scratch1); + module.vec_znx_big_normalize(self.log_base2k(), res, 1, &res_big, 1, scratch1); + } +} + impl RLWECt> { - pub fn encrypt_sk_scratch_bytes(module: &Module, size: usize) -> usize { + pub fn encrypt_sk_scratch_space(module: &Module, size: usize) -> usize { (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) } - pub fn encrypt_pk_scratch_bytes(module: &Module, pk_size: usize) -> usize { + pub fn encrypt_pk_scratch_space(module: &Module, pk_size: usize) -> usize { ((module.bytes_of_vec_znx_dft(1, pk_size) + module.bytes_of_vec_znx_big(1, pk_size)) | module.bytes_of_scalar_znx(1)) + module.bytes_of_scalar_znx_dft(1) + module.vec_znx_big_normalize_tmp_bytes() } - pub fn decrypt_scratch_bytes(module: &Module, size: usize) -> usize { + pub fn decrypt_scratch_space(module: &Module, size: usize) -> usize { (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) } } @@ -393,14 +446,14 @@ pub(crate) fn encrypt_zero_rlwe_dft_sk( } impl RLWECtDft, FFT64> { - pub fn encrypt_zero_sk_scratch_bytes(module: &Module, size: usize) -> usize { + pub fn encrypt_zero_sk_scratch_space(module: &Module, size: usize) -> usize { (module.bytes_of_vec_znx(1, size) | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) + module.bytes_of_vec_znx(1, size) + module.vec_znx_big_normalize_tmp_bytes() } - pub fn decrypt_scratch_bytes(module: &Module, size: usize) -> usize { + pub fn decrypt_scratch_space(module: &Module, size: usize) -> usize { (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size) | (module.bytes_of_vec_znx_big(1, size) + module.vec_znx_idft_tmp_bytes())) @@ -475,6 +528,14 @@ impl RLWECtDft { { decrypt_rlwe_dft(module, pt, self, sk_dft, scratch); } + + pub fn mul_grlwe_assign(&mut self, module: &Module, a: &GRLWECt, scratch: &mut Scratch) + where + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + a.mul_rlwe_dft_inplace(module, self, scratch); + } } pub(crate) fn encrypt_rlwe_pk( @@ -517,6 +578,7 @@ pub(crate) fn encrypt_rlwe_pk( ), SecretDistribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu), SecretDistribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu), + SecretDistribution::ZERO => {} } module.svp_prepare(&mut u_dft, 0, &u, 0); @@ -542,199 +604,3 @@ pub(crate) fn encrypt_rlwe_pk( tmp_big.add_normal(log_base2k, 0, pk.log_k(), source_xe, sigma, bound); module.vec_znx_big_normalize(log_base2k, ct, 1, &tmp_big, 0, scratch_3); } - -#[cfg(test)] -mod tests { - use base2k::{Decoding, Encoding, FFT64, Module, ScratchOwned, Stats, VecZnxOps, ZnxZero}; - use itertools::izip; - use sampling::source::Source; - - use crate::{ - elem_rlwe::{Infos, RLWECt, RLWECtDft, RLWEPt}, - keys::{PublicKey, SecretKey, SecretKeyDft}, - }; - - #[test] - fn encrypt_sk_fft64() { - let module: Module = Module::::new(32); - let log_base2k: usize = 8; - let log_k_ct: usize = 54; - let log_k_pt: usize = 30; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct: RLWECt> = RLWECt::new(&module, log_base2k, log_k_ct, 2); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_pt); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::new( - RLWECt::encrypt_sk_scratch_bytes(&module, ct.size()) | RLWECt::decrypt_scratch_bytes(&module, ct.size()), - ); - - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); - - let mut data_want: Vec = vec![0i64; module.n()]; - - data_want - .iter_mut() - .for_each(|x| *x = source_xa.next_i64() & 0xFF); - - pt.data - .encode_vec_i64(0, log_base2k, log_k_pt, &data_want, 10); - - ct.encrypt_sk( - &module, - Some(&pt), - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - pt.data.zero(); - - ct.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - - let mut data_have: Vec = vec![0i64; module.n()]; - - pt.data - .decode_vec_i64(0, log_base2k, pt.size() * log_base2k, &mut data_have); - - // TODO: properly assert the decryption noise through std(dec(ct) - pt) - let scale: f64 = (1 << (pt.size() * log_base2k - log_k_pt)) as f64; - izip!(data_want.iter(), data_have.iter()).for_each(|(a, b)| { - let b_scaled = (*b as f64) / scale; - assert!( - (*a as f64 - b_scaled).abs() < 0.1, - "{} {}", - *a as f64, - b_scaled - ) - }); - - module.free(); - } - - #[test] - fn encrypt_zero_sk_fft64() { - let module: Module = Module::::new(1024); - let log_base2k: usize = 8; - let log_k_ct: usize = 55; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([1u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); - - let mut ct_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct, 2); - - let mut scratch: ScratchOwned = ScratchOwned::new( - RLWECtDft::decrypt_scratch_bytes(&module, ct_dft.size()) - | RLWECtDft::encrypt_zero_sk_scratch_bytes(&module, ct_dft.size()), - ); - - ct_dft.encrypt_zero_sk( - &module, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - ct_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - - assert!((sigma - pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2()) <= 0.2); - module.free(); - } - - #[test] - fn encrypt_pk_fft64() { - let module: Module = Module::::new(32); - let log_base2k: usize = 8; - let log_k_ct: usize = 54; - let log_k_pk: usize = 64; - - let sigma: f64 = 3.2; - let bound: f64 = sigma * 6.0; - - let mut ct: RLWECt> = RLWECt::new(&module, log_base2k, log_k_ct, 2); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - let mut source_xu: Source = Source::new([0u8; 32]); - - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); - sk_dft.dft(&module, &sk); - - let mut pk: PublicKey, FFT64> = PublicKey::new(&module, log_base2k, log_k_pk); - pk.generate( - &module, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - bound, - ); - - let mut scratch: ScratchOwned = ScratchOwned::new( - RLWECt::encrypt_sk_scratch_bytes(&module, ct.size()) - | RLWECt::decrypt_scratch_bytes(&module, ct.size()) - | RLWECt::encrypt_pk_scratch_bytes(&module, pk.size()), - ); - - let mut data_want: Vec = vec![0i64; module.n()]; - - data_want - .iter_mut() - .for_each(|x| *x = source_xa.next_i64() & 0); - - pt_want - .data - .encode_vec_i64(0, log_base2k, log_k_ct, &data_want, 10); - - ct.encrypt_pk( - &module, - Some(&pt_want), - &pk, - &mut source_xu, - &mut source_xe, - sigma, - bound, - scratch.borrow(), - ); - - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); - - ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_want, 0, &pt_have, 0); - - assert!(((1.0f64 / 12.0).sqrt() - pt_want.data.std(0, log_base2k) * (log_k_ct as f64).exp2()).abs() < 0.2); - - module.free(); - } -} diff --git a/rlwe/src/keys.rs b/rlwe/src/keys.rs index 2f7b2c7..19fda01 100644 --- a/rlwe/src/keys.rs +++ b/rlwe/src/keys.rs @@ -1,6 +1,7 @@ use base2k::{ Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToMut, ScalarZnxDftToRef, ScalarZnxToMut, ScalarZnxToRef, ScratchOwned, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxInfos, + ZnxZero, }; use sampling::source::Source; @@ -10,6 +11,7 @@ use crate::{elem::Infos, elem_rlwe::RLWECtDft}; pub enum SecretDistribution { TernaryFixed(usize), // Ternary with fixed Hamming weight TernaryProb(f64), // Ternary with probabilistic Hamming weight + ZERO, // Debug mod NONE, } @@ -40,6 +42,11 @@ where self.data.fill_ternary_hw(0, hw, source); self.dist = SecretDistribution::TernaryFixed(hw); } + + pub fn fill_zero(&mut self) { + self.data.zero(); + self.dist = SecretDistribution::ZERO; + } } impl ScalarZnxToMut for SecretKey @@ -117,7 +124,7 @@ pub struct PublicKey { impl PublicKey, B> { pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { Self { - data: RLWECtDft::new(module, log_base2k, log_k, 2), + data: RLWECtDft::new(module, log_base2k, log_k), dist: SecretDistribution::NONE, } } @@ -179,7 +186,7 @@ impl PublicKey { } // Its ok to allocate scratch space here since pk is usually generated only once. - let mut scratch: ScratchOwned = ScratchOwned::new(RLWECtDft::encrypt_zero_sk_scratch_bytes( + let mut scratch: ScratchOwned = ScratchOwned::new(RLWECtDft::encrypt_zero_sk_scratch_space( module, self.size(), )); diff --git a/rlwe/src/lib.rs b/rlwe/src/lib.rs index 9eea116..cad8dbc 100644 --- a/rlwe/src/lib.rs +++ b/rlwe/src/lib.rs @@ -3,4 +3,5 @@ pub mod elem_grlwe; pub mod elem_rgsw; pub mod elem_rlwe; pub mod keys; +mod test_fft64; mod utils; diff --git a/rlwe/src/test_fft64/elem_grlwe.rs b/rlwe/src/test_fft64/elem_grlwe.rs new file mode 100644 index 0000000..aa871f3 --- /dev/null +++ b/rlwe/src/test_fft64/elem_grlwe.rs @@ -0,0 +1,722 @@ +#[cfg(test)] + +mod test { + use base2k::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps}; + use sampling::source::Source; + + use crate::{ + elem::Infos, + elem_grlwe::GRLWECt, + elem_rlwe::{RLWECt, RLWECtDft, RLWEPt}, + keys::{SecretKey, SecretKeyDft}, + test_fft64::elem_grlwe::noise_grlwe_rlwe_product, + }; + + #[test] + fn encrypt_sk() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 8; + let log_k_ct: usize = 54; + let rows: usize = 4; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_ct, rows); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct.size()) | RLWECtDft::decrypt_scratch_space(&module, ct.size()), + ); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + ct.encrypt_sk( + &module, + &pt_scalar, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct); + + (0..ct.rows()).for_each(|row_i| { + ct.get_row(&module, row_i, &mut ct_rlwe_dft); + ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_scalar, 0); + let std_pt: f64 = pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); + assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); + }); + + module.free(); + } + + #[test] + fn mul_rlwe() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe_in: usize = 45; + let log_k_rlwe_out: usize = 60; + let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | GRLWECt::mul_rlwe_scratch_space( + &module, + ct_rlwe_out.size(), + ct_rlwe_in.size(), + ct_grlwe.size(), + ), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + ct_grlwe.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_in.encrypt_sk( + &module, + Some(&pt_want), + &sk0_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_grlwe.mul_rlwe(&module, &mut ct_rlwe_out, &ct_rlwe_in, scratch.borrow()); + + ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_rlwe_in, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); + } + + #[test] + fn mul_rlwe_inplace() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe: usize = 45; + let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | GRLWECt::mul_rlwe_scratch_space(&module, ct_rlwe.size(), ct_rlwe.size(), ct_grlwe.size()), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + ct_grlwe.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.encrypt_sk( + &module, + Some(&pt_want), + &sk0_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_grlwe.mul_rlwe_inplace(&module, &mut ct_rlwe, scratch.borrow()); + + ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_rlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); + } + + #[test] + fn mul_rlwe_dft() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe_in: usize = 45; + let log_k_rlwe_out: usize = 60; + let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_in_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); + let mut ct_rlwe_out_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | GRLWECt::mul_rlwe_scratch_space( + &module, + ct_rlwe_out.size(), + ct_rlwe_in.size(), + ct_grlwe.size(), + ), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + ct_grlwe.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_in.encrypt_sk( + &module, + Some(&pt_want), + &sk0_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_in.dft(&module, &mut ct_rlwe_in_dft); + ct_grlwe.mul_rlwe_dft( + &module, + &mut ct_rlwe_out_dft, + &ct_rlwe_in_dft, + scratch.borrow(), + ); + ct_rlwe_out_dft.idft(&module, &mut ct_rlwe_out, scratch.borrow()); + + ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_rlwe_in, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); + } + + #[test] + fn mul_rlwe_dft_inplace() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let log_k_rlwe: usize = 45; + let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe); + let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) + | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | GRLWECt::mul_rlwe_scratch_space(&module, ct_rlwe.size(), ct_rlwe.size(), ct_grlwe.size()), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + ct_grlwe.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.encrypt_sk( + &module, + Some(&pt_want), + &sk0_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.dft(&module, &mut ct_rlwe_dft); + ct_grlwe.mul_rlwe_dft_inplace(&module, &mut ct_rlwe_dft, scratch.borrow()); + ct_rlwe_dft.idft(&module, &mut ct_rlwe, scratch.borrow()); + + ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_rlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + module.free(); + } + + #[test] + fn mul_grlwe() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe_s0s1: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_s1s2: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_s0s2: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe_s0s1.size()) + | RLWECtDft::decrypt_scratch_space(&module, ct_grlwe_s0s2.size()) + | GRLWECt::mul_grlwe_scratch_space( + &module, + ct_grlwe_s0s2.size(), + ct_grlwe_s0s1.size(), + ct_grlwe_s1s2.size(), + ), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + let mut sk2: SecretKey> = SecretKey::new(&module); + sk2.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk2_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk2_dft.dft(&module, &sk2); + + // GRLWE_{s1}(s0) = s0 -> s1 + ct_grlwe_s0s1.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + // GRLWE_{s2}(s1) -> s1 -> s2 + ct_grlwe_s1s2.encrypt_sk( + &module, + &sk1.data, + &sk2_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + // GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) + ct_grlwe_s1s2.mul_grlwe( + &module, + &mut ct_grlwe_s0s2, + &ct_grlwe_s0s1, + scratch.borrow(), + ); + + let mut ct_rlwe_dft_s0s2: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_grlwe); + + (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { + ct_grlwe_s0s2.get_row(&module, row_i, &mut ct_rlwe_dft_s0s2); + ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, 0); + + let noise_have: f64 = pt.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_grlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); + + module.free(); + } + + #[test] + fn mul_grlwe_inplace() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 12; + let log_k_grlwe: usize = 60; + let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe_s0s1: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_s1s2: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe_s0s1.size()) + | RLWECtDft::decrypt_scratch_space(&module, ct_grlwe_s0s1.size()) + | GRLWECt::mul_grlwe_scratch_space( + &module, + ct_grlwe_s0s1.size(), + ct_grlwe_s0s1.size(), + ct_grlwe_s1s2.size(), + ), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk1_dft.dft(&module, &sk1); + + let mut sk2: SecretKey> = SecretKey::new(&module); + sk2.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk2_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk2_dft.dft(&module, &sk2); + + // GRLWE_{s1}(s0) = s0 -> s1 + ct_grlwe_s0s1.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + // GRLWE_{s2}(s1) -> s1 -> s2 + ct_grlwe_s1s2.encrypt_sk( + &module, + &sk1.data, + &sk2_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + // GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) + ct_grlwe_s1s2.mul_grlwe_inplace(&module, &mut ct_grlwe_s0s1, scratch.borrow()); + + let ct_grlwe_s0s2: GRLWECt, FFT64> = ct_grlwe_s0s1; + + let mut ct_rlwe_dft_s0s2: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_grlwe); + + (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { + ct_grlwe_s0s2.get_row(&module, row_i, &mut ct_rlwe_dft_s0s2); + ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, 0); + + let noise_have: f64 = pt.data.std(0, log_base2k).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + log_base2k, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_grlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); + + module.free(); + } +} + +#[allow(dead_code)] +pub(crate) fn noise_grlwe_rlwe_product( + n: f64, + log_base2k: usize, + var_xs: f64, + var_msg: f64, + var_a_err: f64, + var_gct_err_lhs: f64, + var_gct_err_rhs: f64, + a_logq: usize, + b_logq: usize, +) -> f64 { + let a_logq: usize = a_logq.min(b_logq); + let a_cols: usize = (a_logq + log_base2k - 1) / log_base2k; + + let b_scale = 2.0f64.powi(b_logq as i32); + let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32); + + let base: f64 = (1 << (log_base2k)) as f64; + let var_base: f64 = base * base / 12f64; + + // lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2) + // rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs + let mut noise: f64 = (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs); + noise += var_msg * var_a_err * a_scale * a_scale * n; + noise = noise.sqrt(); + noise /= b_scale; + noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}] +} diff --git a/rlwe/src/test_fft64/elem_rgsw.rs b/rlwe/src/test_fft64/elem_rgsw.rs new file mode 100644 index 0000000..b7af5ca --- /dev/null +++ b/rlwe/src/test_fft64/elem_rgsw.rs @@ -0,0 +1,88 @@ +#[cfg(test)] +mod tests { + use base2k::{ + FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, + VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxZero, + }; + use sampling::source::Source; + + use crate::{ + elem::Infos, + elem_rgsw::RGSWCt, + elem_rlwe::{RLWECtDft, RLWEPt}, + keys::{SecretKey, SecretKeyDft}, + }; + + #[test] + fn encrypt_rgsw_sk() { + let module: Module = Module::::new(2048); + let log_base2k: usize = 8; + let log_k_ct: usize = 54; + let rows: usize = 4; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_ct, rows); + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + + let mut scratch: ScratchOwned = ScratchOwned::new( + RGSWCt::encrypt_sk_scratch_space(&module, ct.size()) | RLWECtDft::decrypt_scratch_space(&module, ct.size()), + ); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + ct.encrypt_sk( + &module, + &pt_scalar, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct); + let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); + let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); + + (0..ct.cols()).for_each(|col_j| { + (0..ct.rows()).for_each(|row_i| { + module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0); + + if col_j == 1 { + module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0); + module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); + } + + ct.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); + + ct_rlwe_dft.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let std_pt: f64 = pt_have.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); + assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); + + pt_want.data.zero(); + }); + }); + + module.free(); + } +} diff --git a/rlwe/src/test_fft64/elem_rlwe.rs b/rlwe/src/test_fft64/elem_rlwe.rs new file mode 100644 index 0000000..d6f812b --- /dev/null +++ b/rlwe/src/test_fft64/elem_rlwe.rs @@ -0,0 +1,196 @@ +#[cfg(test)] +mod tests { + use base2k::{Decoding, Encoding, FFT64, Module, ScratchOwned, Stats, VecZnxOps, ZnxZero}; + use itertools::izip; + use sampling::source::Source; + + use crate::{ + elem::Infos, + elem_rlwe::{RLWECt, RLWECtDft, RLWEPt}, + keys::{PublicKey, SecretKey, SecretKeyDft}, + }; + + #[test] + fn encrypt_sk() { + let module: Module = Module::::new(32); + let log_base2k: usize = 8; + let log_k_ct: usize = 54; + let log_k_pt: usize = 30; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct: RLWECt> = RLWECt::new(&module, log_base2k, log_k_ct); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_pt); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + RLWECt::encrypt_sk_scratch_space(&module, ct.size()) | RLWECt::decrypt_scratch_space(&module, ct.size()), + ); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + let mut data_want: Vec = vec![0i64; module.n()]; + + data_want + .iter_mut() + .for_each(|x| *x = source_xa.next_i64() & 0xFF); + + pt.data + .encode_vec_i64(0, log_base2k, log_k_pt, &data_want, 10); + + ct.encrypt_sk( + &module, + Some(&pt), + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + pt.data.zero(); + + ct.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + + let mut data_have: Vec = vec![0i64; module.n()]; + + pt.data + .decode_vec_i64(0, log_base2k, pt.size() * log_base2k, &mut data_have); + + // TODO: properly assert the decryption noise through std(dec(ct) - pt) + let scale: f64 = (1 << (pt.size() * log_base2k - log_k_pt)) as f64; + izip!(data_want.iter(), data_have.iter()).for_each(|(a, b)| { + let b_scaled = (*b as f64) / scale; + assert!( + (*a as f64 - b_scaled).abs() < 0.1, + "{} {}", + *a as f64, + b_scaled + ) + }); + + module.free(); + } + + #[test] + fn encrypt_zero_sk() { + let module: Module = Module::::new(1024); + let log_base2k: usize = 8; + let log_k_ct: usize = 55; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([1u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + let mut ct_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct); + + let mut scratch: ScratchOwned = ScratchOwned::new( + RLWECtDft::decrypt_scratch_space(&module, ct_dft.size()) + | RLWECtDft::encrypt_zero_sk_scratch_space(&module, ct_dft.size()), + ); + + ct_dft.encrypt_zero_sk( + &module, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + ct_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + + assert!((sigma - pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2()) <= 0.2); + module.free(); + } + + #[test] + fn encrypt_pk() { + let module: Module = Module::::new(32); + let log_base2k: usize = 8; + let log_k_ct: usize = 54; + let log_k_pk: usize = 64; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct: RLWECt> = RLWECt::new(&module, log_base2k, log_k_ct); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + let mut source_xu: Source = Source::new([0u8; 32]); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + let mut pk: PublicKey, FFT64> = PublicKey::new(&module, log_base2k, log_k_pk); + pk.generate( + &module, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + ); + + let mut scratch: ScratchOwned = ScratchOwned::new( + RLWECt::encrypt_sk_scratch_space(&module, ct.size()) + | RLWECt::decrypt_scratch_space(&module, ct.size()) + | RLWECt::encrypt_pk_scratch_space(&module, pk.size()), + ); + + let mut data_want: Vec = vec![0i64; module.n()]; + + data_want + .iter_mut() + .for_each(|x| *x = source_xa.next_i64() & 0); + + pt_want + .data + .encode_vec_i64(0, log_base2k, log_k_ct, &data_want, 10); + + ct.encrypt_pk( + &module, + Some(&pt_want), + &pk, + &mut source_xu, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + + ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_want, 0, &pt_have, 0); + + assert!(((1.0f64 / 12.0).sqrt() - pt_want.data.std(0, log_base2k) * (log_k_ct as f64).exp2()).abs() < 0.2); + + module.free(); + } +} diff --git a/rlwe/src/test_fft64/mod.rs b/rlwe/src/test_fft64/mod.rs new file mode 100644 index 0000000..edac310 --- /dev/null +++ b/rlwe/src/test_fft64/mod.rs @@ -0,0 +1,3 @@ +mod elem_grlwe; +mod elem_rgsw; +mod elem_rlwe;