diff --git a/base2k/src/module.rs b/base2k/src/module.rs index aab18b4..904d0ec 100644 --- a/base2k/src/module.rs +++ b/base2k/src/module.rs @@ -1,107 +1,107 @@ -use crate::GALOISGENERATOR; -use crate::ffi::module::{MODULE, delete_module_info, module_info_t, new_module_info}; -use std::marker::PhantomData; - -#[derive(Copy, Clone)] -#[repr(u8)] -pub enum BACKEND { - FFT64, - NTT120, -} - -pub trait Backend { - const KIND: BACKEND; - fn module_type() -> u32; -} - -pub struct FFT64; -pub struct NTT120; - -impl Backend for FFT64 { - const KIND: BACKEND = BACKEND::FFT64; - fn module_type() -> u32 { - 0 - } -} - -impl Backend for NTT120 { - const KIND: BACKEND = BACKEND::NTT120; - fn module_type() -> u32 { - 1 - } -} - -pub struct Module { - pub ptr: *mut MODULE, - n: usize, - _marker: PhantomData, -} - -impl Module { - // Instantiates a new module. - pub fn new(n: usize) -> Self { - unsafe { - let m: *mut module_info_t = new_module_info(n as u64, B::module_type()); - if m.is_null() { - panic!("Failed to create module."); - } - Self { - ptr: m, - n: n, - _marker: PhantomData, - } - } - } - - pub fn n(&self) -> usize { - self.n - } - - pub fn log_n(&self) -> usize { - (usize::BITS - (self.n() - 1).leading_zeros()) as _ - } - - pub fn cyclotomic_order(&self) -> u64 { - (self.n() << 1) as _ - } - - // Returns GALOISGENERATOR^|generator| * sign(generator) - pub fn galois_element(&self, generator: i64) -> i64 { - if generator == 0 { - return 1; - } - ((mod_exp_u64(GALOISGENERATOR, generator.abs() as usize) & (self.cyclotomic_order() - 1)) as i64) * generator.signum() - } - - // Returns gen^-1 - pub fn galois_element_inv(&self, generator: i64) -> i64 { - if generator == 0 { - panic!("cannot invert 0") - } - ((mod_exp_u64( - generator.abs() as u64, - (self.cyclotomic_order() - 1) as usize, - ) & (self.cyclotomic_order() - 1)) as i64) - * generator.signum() - } -} - -impl Drop for Module { - fn drop(&mut self) { - unsafe { delete_module_info(self.ptr) } - } -} - -fn mod_exp_u64(x: u64, e: usize) -> u64 { - let mut y: u64 = 1; - let mut x_pow: u64 = x; - let mut exp = e; - while exp > 0 { - if exp & 1 == 1 { - y = y.wrapping_mul(x_pow); - } - x_pow = x_pow.wrapping_mul(x_pow); - exp >>= 1; - } - y -} +use crate::GALOISGENERATOR; +use crate::ffi::module::{MODULE, delete_module_info, module_info_t, new_module_info}; +use std::marker::PhantomData; + +#[derive(Copy, Clone)] +#[repr(u8)] +pub enum BACKEND { + FFT64, + NTT120, +} + +pub trait Backend { + const KIND: BACKEND; + fn module_type() -> u32; +} + +pub struct FFT64; +pub struct NTT120; + +impl Backend for FFT64 { + const KIND: BACKEND = BACKEND::FFT64; + fn module_type() -> u32 { + 0 + } +} + +impl Backend for NTT120 { + const KIND: BACKEND = BACKEND::NTT120; + fn module_type() -> u32 { + 1 + } +} + +pub struct Module { + pub ptr: *mut MODULE, + n: usize, + _marker: PhantomData, +} + +impl Module { + // Instantiates a new module. + pub fn new(n: usize) -> Self { + unsafe { + let m: *mut module_info_t = new_module_info(n as u64, B::module_type()); + if m.is_null() { + panic!("Failed to create module."); + } + Self { + ptr: m, + n: n, + _marker: PhantomData, + } + } + } + + pub fn n(&self) -> usize { + self.n + } + + pub fn log_n(&self) -> usize { + (usize::BITS - (self.n() - 1).leading_zeros()) as _ + } + + pub fn cyclotomic_order(&self) -> u64 { + (self.n() << 1) as _ + } + + // Returns GALOISGENERATOR^|generator| * sign(generator) + pub fn galois_element(&self, generator: i64) -> i64 { + if generator == 0 { + return 1; + } + ((mod_exp_u64(GALOISGENERATOR, generator.abs() as usize) & (self.cyclotomic_order() - 1)) as i64) * generator.signum() + } + + // Returns gen^-1 + pub fn galois_element_inv(&self, generator: i64) -> i64 { + if generator == 0 { + panic!("cannot invert 0") + } + ((mod_exp_u64( + generator.abs() as u64, + (self.cyclotomic_order() - 1) as usize, + ) & (self.cyclotomic_order() - 1)) as i64) + * generator.signum() + } +} + +impl Drop for Module { + fn drop(&mut self) { + unsafe { delete_module_info(self.ptr) } + } +} + +fn mod_exp_u64(x: u64, e: usize) -> u64 { + let mut y: u64 = 1; + let mut x_pow: u64 = x; + let mut exp = e; + while exp > 0 { + if exp & 1 == 1 { + y = y.wrapping_mul(x_pow); + } + x_pow = x_pow.wrapping_mul(x_pow); + exp >>= 1; + } + y +} diff --git a/base2k/src/vec_znx_dft_ops.rs b/base2k/src/vec_znx_dft_ops.rs index 282ef4d..27e6f59 100644 --- a/base2k/src/vec_znx_dft_ops.rs +++ b/base2k/src/vec_znx_dft_ops.rs @@ -42,8 +42,13 @@ pub trait VecZnxDftOps { /// a new [VecZnxDft] through [VecZnxDft::from_bytes]. fn vec_znx_idft_tmp_bytes(&self) -> usize; + fn vec_znx_dft_copy(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef; + /// b <- IDFT(a), uses a as scratch space. - fn vec_znx_idft_tmp_a(&self, res: &mut R, res_col: usize, a: &mut A, a_cols: usize) + fn vec_znx_idft_tmp_a(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxDftToMut; @@ -79,13 +84,33 @@ impl VecZnxDftAlloc for Module { } impl VecZnxDftOps for Module { + fn vec_znx_dft_copy(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + { + let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); + + let min_size: usize = min(res_mut.size(), a_ref.size()); + + (0..min_size).for_each(|j| { + res_mut + .at_mut(res_col, j) + .copy_from_slice(a_ref.at(a_col, j)); + }); + (min_size..res_mut.size()).for_each(|j| { + res_mut.zero_at(res_col, j); + }) + } + fn vec_znx_idft_tmp_a(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxDftToMut, { - let mut res_mut = res.to_mut(); - let mut a_mut = a.to_mut(); + let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut(); let min_size: usize = min(res_mut.size(), a_mut.size()); @@ -136,14 +161,14 @@ impl VecZnxDftOps for Module { /// b <- DFT(a) /// /// # Panics - /// If b.cols < a_cols + /// If b.cols < a_col fn vec_znx_dft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, A: VecZnxToRef, { - let mut res_mut = res.to_mut(); - let a_ref = a.to_ref(); + let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a_ref: crate::VecZnx<&[u8]> = a.to_ref(); let min_size: usize = min(res_mut.size(), a_ref.size()); @@ -170,8 +195,8 @@ impl VecZnxDftOps for Module { R: VecZnxBigToMut, A: VecZnxDftToRef, { - let mut res_mut = res.to_mut(); - let a_ref = a.to_ref(); + let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); let (tmp_bytes, _) = scratch.tmp_slice(self.vec_znx_idft_tmp_bytes()); diff --git a/core/src/gglwe_ciphertext.rs b/core/src/gglwe_ciphertext.rs index d20072a..7deb225 100644 --- a/core/src/gglwe_ciphertext.rs +++ b/core/src/gglwe_ciphertext.rs @@ -1,8 +1,7 @@ use base2k::{ - Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, - ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigOps, VecZnxBigScratch, - VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos, - ZnxZero, + Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, + ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, + VecZnxOps, ZnxInfos, ZnxZero, }; use sampling::source::Source; @@ -13,7 +12,6 @@ use crate::{ glwe_plaintext::GLWEPlaintext, keys::SecretKeyFourier, utils::derive_size, - vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace}, }; pub struct GGLWECiphertext { @@ -212,81 +210,3 @@ where module.vmp_prepare_row(self, row_i, col_j, a); } } - -impl VecGLWEProductScratchSpace for GGLWECiphertext, FFT64> { - fn prod_with_glwe_scratch_space( - module: &Module, - res_size: usize, - a_size: usize, - grlwe_size: usize, - rank_in: usize, - rank_out: usize, - ) -> usize { - module.bytes_of_vec_znx_dft(rank_out + 1, grlwe_size) - + (module.vec_znx_big_normalize_tmp_bytes() - | (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, rank_in, rank_out + 1, grlwe_size) - + module.bytes_of_vec_znx_dft(rank_in, a_size))) - } -} - -impl VecGLWEProduct for GGLWECiphertext -where - MatZnxDft: MatZnxDftToRef + ZnxInfos, -{ - fn prod_with_glwe( - &self, - module: &Module, - res: &mut GLWECiphertext, - a: &GLWECiphertext, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef, - VecZnx: VecZnxToMut, - VecZnx: VecZnxToRef, - { - let basek: usize = self.basek(); - - #[cfg(debug_assertions)] - { - assert_eq!(a.rank(), self.rank_in()); - assert_eq!(res.rank(), self.rank_out()); - assert_eq!(res.basek(), basek); - assert_eq!(a.basek(), basek); - assert_eq!(self.n(), module.n()); - assert_eq!(res.n(), module.n()); - assert_eq!(a.n(), module.n()); - assert!( - scratch.available() - >= GGLWECiphertext::prod_with_glwe_scratch_space( - module, - res.size(), - a.size(), - self.size(), - self.rank_in(), - self.rank_out() - ) - ); - } - - let cols_in: usize = self.rank_in(); - let cols_out: usize = self.rank_out() + 1; - - let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, self.size()); // Todo optimise - - { - let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, a.size()); - (0..cols_in).for_each(|col_i| { - module.vec_znx_dft(&mut ai_dft, col_i, a, col_i + 1); - }); - module.vmp_apply(&mut res_dft, &ai_dft, self, scratch2); - } - - let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); - - module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0); - - (0..cols_out).for_each(|i| { - module.vec_znx_big_normalize(basek, res, i, &res_big, i, scratch1); - }); - } -} diff --git a/core/src/ggsw_ciphertext.rs b/core/src/ggsw_ciphertext.rs index fa3b365..67f4774 100644 --- a/core/src/ggsw_ciphertext.rs +++ b/core/src/ggsw_ciphertext.rs @@ -1,35 +1,31 @@ use base2k::{ - Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, - ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigOps, VecZnxBigScratch, - VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos, - ZnxZero, + Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, + ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, + VecZnxDftToRef, VecZnxOps, ZnxInfos, ZnxZero, }; use sampling::source::Source; use crate::{ elem::{GetRow, Infos, SetRow}, - gglwe_ciphertext::GGLWECiphertext, glwe_ciphertext::GLWECiphertext, glwe_ciphertext_fourier::GLWECiphertextFourier, glwe_plaintext::GLWEPlaintext, keys::SecretKeyFourier, - keyswitch_key::GLWESwitchingKey, utils::derive_size, - vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace}, }; pub struct GGSWCiphertext { pub data: MatZnxDft, - pub log_base2k: usize, - pub log_k: usize, + pub basek: usize, + pub k: usize, } impl GGSWCiphertext, B> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize, rows: usize, rank: usize) -> Self { + pub fn new(module: &Module, basek: usize, k: usize, rows: usize, rank: usize) -> Self { Self { - data: module.new_mat_znx_dft(rows, rank + 1, rank + 1, derive_size(log_base2k, log_k)), - log_base2k: log_base2k, - log_k: log_k, + data: module.new_mat_znx_dft(rows, rank + 1, rank + 1, derive_size(basek, k)), + basek: basek, + k: k, } } } @@ -42,11 +38,11 @@ impl Infos for GGSWCiphertext { } fn basek(&self) -> usize { - self.log_base2k + self.basek } fn k(&self) -> usize { - self.log_k + self.k } } @@ -82,35 +78,28 @@ impl GGSWCiphertext, FFT64> { + module.bytes_of_vec_znx_dft(rank + 1, size) } - pub fn keyswitch_scratch_space( + pub fn external_product_scratch_space( module: &Module, - res_size: usize, - lhs: usize, - rhs: usize, - rank_in: usize, - rank_out: usize, + out_size: usize, + in_size: usize, + ggsw_size: usize, + rank: usize, ) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( - module, res_size, lhs, rhs, rank_in, rank_out, - ) + let tmp_in: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size); + let tmp_out: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); + let ggsw: usize = GLWECiphertextFourier::external_product_scratch_space(module, out_size, in_size, ggsw_size, rank); + tmp_in + tmp_out + ggsw } - pub fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_inplace_scratch_space( - module, res_size, rhs, rank, - ) - } - - pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( - module, res_size, lhs, rhs, rank, rank, - ) - } - - pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( - module, res_size, rhs, rank, - ) + pub fn external_product_inplace_scratch_space( + module: &Module, + out_size: usize, + ggsw_size: usize, + rank: usize, + ) -> usize { + let tmp: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); + let ggsw: usize = GLWECiphertextFourier::external_product_inplace_scratch_space(module, out_size, ggsw_size, rank); + tmp + ggsw } } @@ -140,7 +129,7 @@ where } let size: usize = self.size(); - let log_base2k: usize = self.basek(); + let basek: usize = self.basek(); let k: usize = self.k(); let cols: usize = self.rank() + 1; @@ -149,20 +138,20 @@ where let mut vec_znx_pt: GLWEPlaintext<&mut [u8]> = GLWEPlaintext { data: tmp_znx_pt, - basek: log_base2k, + basek: basek, k: k, }; let mut vec_znx_ct: GLWECiphertext<&mut [u8]> = GLWECiphertext { data: tmp_znx_ct, - basek: log_base2k, + basek: basek, k, }; (0..self.rows()).for_each(|row_j| { // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_j, pt, 0); - module.vec_znx_normalize_inplace(log_base2k, &mut vec_znx_pt, 0, scrach_2); + module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt, 0, scrach_2); (0..cols).for_each(|col_i| { // rlwe encrypt of vec_znx_pt into vec_znx_ct @@ -193,30 +182,6 @@ where }); } - pub fn keyswitch( - &mut self, - module: &Module, - lhs: &GGSWCiphertext, - rhs: &GLWESwitchingKey, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef, - MatZnxDft: MatZnxDftToRef, - { - rhs.0.prod_with_vec_glwe(module, self, lhs, scratch); - } - - pub fn keyswitch_inplace( - &mut self, - module: &Module, - rhs: &GLWESwitchingKey, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToRef, - { - rhs.0.prod_with_vec_glwe_inplace(module, self, scratch); - } - pub fn external_product( &mut self, module: &Module, @@ -227,7 +192,55 @@ where MatZnxDft: MatZnxDftToRef, MatZnxDft: MatZnxDftToRef, { - rhs.prod_with_vec_glwe(module, self, lhs, scratch); + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank(), + lhs.rank(), + "ggsw_out rank: {} != ggsw_in rank: {}", + self.rank(), + lhs.rank() + ); + assert_eq!( + self.rank(), + rhs.rank(), + "ggsw_in rank: {} != ggsw_apply rank: {}", + self.rank(), + rhs.rank() + ); + } + + let (tmp_in_data, scratch1) = scratch.tmp_vec_znx_dft(module, lhs.rank() + 1, lhs.size()); + + let mut tmp_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_in_data, + basek: lhs.basek(), + k: lhs.k(), + }; + + let (tmp_out_data, scratch2) = scratch1.tmp_vec_znx_dft(module, self.rank() + 1, self.size()); + + let mut tmp_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_out_data, + basek: self.basek(), + k: self.k(), + }; + + (0..self.rank() + 1).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + lhs.get_row(module, row_j, col_i, &mut tmp_in); + tmp_out.external_product(module, &tmp_in, rhs, scratch2); + self.set_row(module, row_j, col_i, &tmp_out); + }); + }); + + tmp_out.data.zero(); + + (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { + (0..self.rank() + 1).for_each(|col_j| { + self.set_row(module, row_i, col_j, &tmp_out); + }); + }); } pub fn external_product_inplace( @@ -238,7 +251,32 @@ where ) where MatZnxDft: MatZnxDftToRef, { - rhs.prod_with_vec_glwe_inplace(module, self, scratch); + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank(), + rhs.rank(), + "ggsw_out rank: {} != ggsw_apply: {}", + self.rank(), + rhs.rank() + ); + } + + let (tmp_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank() + 1, self.size()); + + let mut tmp: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_data, + basek: self.basek(), + k: self.k(), + }; + + (0..self.rank() + 1).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + self.get_row(module, row_j, col_i, &mut tmp); + tmp.external_product_inplace(module, rhs, scratch1); + self.set_row(module, row_j, col_i, &tmp); + }); + }); } } @@ -270,73 +308,3 @@ where module.vmp_prepare_row(self, row_i, col_j, a); } } - -impl VecGLWEProductScratchSpace for GGSWCiphertext, FFT64> { - fn prod_with_glwe_scratch_space( - module: &Module, - res_size: usize, - a_size: usize, - rgsw_size: usize, - rank_in: usize, - rank_out: usize, - ) -> usize { - module.bytes_of_vec_znx_dft(rank_out + 1, rgsw_size) - + ((module.bytes_of_vec_znx_dft(rank_in + 1, a_size) - + module.vmp_apply_tmp_bytes( - res_size, - a_size, - a_size, - rank_in + 1, - rank_out + 1, - rgsw_size, - )) - | module.vec_znx_big_normalize_tmp_bytes()) - } -} - -impl VecGLWEProduct for GGSWCiphertext -where - MatZnxDft: MatZnxDftToRef + ZnxInfos, -{ - fn prod_with_glwe( - &self, - module: &Module, - res: &mut GLWECiphertext, - a: &GLWECiphertext, - scratch: &mut Scratch, - ) where - VecZnx: VecZnxToMut, - VecZnx: VecZnxToRef, - { - let log_base2k: usize = self.basek(); - - #[cfg(debug_assertions)] - { - assert_eq!(self.rank(), a.rank()); - assert_eq!(self.rank(), res.rank()); - assert_eq!(res.basek(), log_base2k); - assert_eq!(a.basek(), log_base2k); - assert_eq!(self.n(), module.n()); - assert_eq!(res.n(), module.n()); - assert_eq!(a.n(), module.n()); - } - - let cols: usize = self.rank() + 1; - - let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, self.size()); // Todo optimise - - { - let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, a.size()); - (0..cols).for_each(|col_i| { - module.vec_znx_dft(&mut a_dft, col_i, a, col_i); - }); - module.vmp_apply(&mut res_dft, &a_dft, self, scratch2); - } - - let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); - - (0..cols).for_each(|i| { - module.vec_znx_big_normalize(log_base2k, res, i, &res_big, i, scratch1); - }); - } -} diff --git a/core/src/glwe_ciphertext.rs b/core/src/glwe_ciphertext.rs index 82e44da..1875a54 100644 --- a/core/src/glwe_ciphertext.rs +++ b/core/src/glwe_ciphertext.rs @@ -1,22 +1,20 @@ use base2k::{ - AddNormal, Backend, FFT64, FillUniform, MatZnxDft, MatZnxDftToRef, Module, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, - ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, - VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos, - ZnxZero, + AddNormal, Backend, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToRef, Module, ScalarZnxAlloc, + ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, + VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, + VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero, }; use sampling::source::Source; use crate::{ SIX_SIGMA, elem::Infos, - gglwe_ciphertext::GGLWECiphertext, ggsw_ciphertext::GGSWCiphertext, glwe_ciphertext_fourier::GLWECiphertextFourier, glwe_plaintext::GLWEPlaintext, keys::{GLWEPublicKey, SecretDistribution, SecretKeyFourier}, keyswitch_key::GLWESwitchingKey, utils::derive_size, - vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace}, }; pub struct GLWECiphertext { @@ -115,33 +113,50 @@ impl GLWECiphertext> { pub fn keyswitch_scratch_space( module: &Module, - res_size: usize, - lhs: usize, - rhs: usize, - rank_in: usize, - rank_out: usize, + out_size: usize, + out_rank: usize, + in_size: usize, + in_rank: usize, + ksk_size: usize, ) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space( - module, res_size, lhs, rhs, rank_in, rank_out, - ) + module.bytes_of_vec_znx_dft(out_rank + 1, ksk_size) + + (module.vec_znx_big_normalize_tmp_bytes() + | (module.vmp_apply_tmp_bytes( + out_size, + in_size, + in_size, + in_rank + 1, + out_rank + 1, + ksk_size, + ) + module.bytes_of_vec_znx_dft(in_size, in_size))) } - pub fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( - module, res_size, rhs, rank, - ) + pub fn keyswitch_inplace_scratch_space(module: &Module, out_size: usize, out_rank: usize, ksk_size: usize) -> usize { + GLWECiphertext::keyswitch_scratch_space(module, out_size, out_rank, out_size, out_rank, ksk_size) } - pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space( - module, res_size, lhs, rhs, rank, rank, - ) + pub fn external_product_scratch_space( + module: &Module, + out_size: usize, + in_size: usize, + ggsw_size: usize, + rank: usize, + ) -> usize { + module.bytes_of_vec_znx_dft(rank + 1, ggsw_size) + + ((module.bytes_of_vec_znx_dft(rank + 1, in_size) + + module.vmp_apply_tmp_bytes( + out_size, + in_size, + in_size, // rows + rank + 1, // cols in + rank + 1, // cols out + ggsw_size, + )) + | module.vec_znx_big_normalize_tmp_bytes()) } pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( - module, res_size, rhs, rank, - ) + GLWECiphertext::external_product_scratch_space(module, res_size, res_size, rhs, rank) } } @@ -235,7 +250,50 @@ where VecZnx: VecZnxToRef, MatZnxDft: MatZnxDftToRef, { - rhs.0.prod_with_glwe(module, self, lhs, scratch); + let basek: usize = self.basek(); + + #[cfg(debug_assertions)] + { + assert_eq!(lhs.rank(), rhs.rank_in()); + assert_eq!(self.rank(), rhs.rank_out()); + assert_eq!(self.basek(), basek); + assert_eq!(lhs.basek(), basek); + assert_eq!(rhs.n(), module.n()); + assert_eq!(self.n(), module.n()); + assert_eq!(lhs.n(), module.n()); + assert!( + scratch.available() + >= GLWECiphertext::keyswitch_scratch_space( + module, + self.size(), + self.rank(), + lhs.size(), + lhs.rank(), + rhs.size(), + ) + ); + } + + let cols_in: usize = rhs.rank_in(); + let cols_out: usize = rhs.rank_out() + 1; + + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise + + { + let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size()); + (0..cols_in).for_each(|col_i| { + module.vec_znx_dft(&mut ai_dft, col_i, lhs, col_i + 1); + }); + module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2); + } + + let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); + + module.vec_znx_big_add_small_inplace(&mut res_big, 0, lhs, 0); + + (0..cols_out).for_each(|i| { + module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1); + }); } pub fn keyswitch_inplace( @@ -246,7 +304,10 @@ where ) where MatZnxDft: MatZnxDftToRef, { - rhs.0.prod_with_glwe_inplace(module, self, scratch); + unsafe { + let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; + self.keyswitch(&module, &*self_ptr, rhs, scratch); + } } pub fn external_product( @@ -259,7 +320,36 @@ where VecZnx: VecZnxToRef, MatZnxDft: MatZnxDftToRef, { - rhs.prod_with_glwe(module, self, lhs, scratch); + let basek: usize = self.basek(); + + #[cfg(debug_assertions)] + { + assert_eq!(rhs.rank(), lhs.rank()); + assert_eq!(rhs.rank(), self.rank()); + assert_eq!(self.basek(), basek); + assert_eq!(lhs.basek(), basek); + assert_eq!(rhs.n(), module.n()); + assert_eq!(self.n(), module.n()); + assert_eq!(lhs.n(), module.n()); + } + + let cols: usize = rhs.rank() + 1; + + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); // Todo optimise + + { + let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, lhs.size()); + (0..cols).for_each(|col_i| { + module.vec_znx_dft(&mut a_dft, col_i, lhs, col_i); + }); + module.vmp_apply(&mut res_dft, &a_dft, rhs, scratch2); + } + + let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); + + (0..cols).for_each(|i| { + module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1); + }); } pub fn external_product_inplace( @@ -270,7 +360,10 @@ where ) where MatZnxDft: MatZnxDftToRef, { - rhs.prod_with_glwe_inplace(module, self, scratch); + unsafe { + let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; + self.external_product(&module, &*self_ptr, rhs, scratch); + } } pub(crate) fn encrypt_sk_private( diff --git a/core/src/glwe_ciphertext_fourier.rs b/core/src/glwe_ciphertext_fourier.rs index fe2a50d..ebbe9cf 100644 --- a/core/src/glwe_ciphertext_fourier.rs +++ b/core/src/glwe_ciphertext_fourier.rs @@ -1,20 +1,13 @@ use base2k::{ - Backend, FFT64, MatZnxDft, MatZnxDftToRef, Module, ScalarZnxDft, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, - VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, - VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxZero, + Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToRef, Module, ScalarZnxDft, ScalarZnxDftOps, + ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, + VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxZero, }; use sampling::source::Source; use crate::{ - elem::Infos, - gglwe_ciphertext::GGLWECiphertext, - ggsw_ciphertext::GGSWCiphertext, - glwe_ciphertext::GLWECiphertext, - glwe_plaintext::GLWEPlaintext, - keys::SecretKeyFourier, - keyswitch_key::GLWESwitchingKey, - utils::derive_size, - vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace}, + elem::Infos, ggsw_ciphertext::GGSWCiphertext, glwe_ciphertext::GLWECiphertext, glwe_plaintext::GLWEPlaintext, + keys::SecretKeyFourier, keyswitch_key::GLWESwitchingKey, utils::derive_size, }; pub struct GLWECiphertextFourier { @@ -24,11 +17,11 @@ pub struct GLWECiphertextFourier { } impl GLWECiphertextFourier, B> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize, rank: usize) -> Self { + pub fn new(module: &Module, basek: usize, k: usize, rank: usize) -> Self { Self { - data: module.new_vec_znx_dft(rank + 1, derive_size(log_base2k, log_k)), - basek: log_base2k, - k: log_k, + data: module.new_vec_znx_dft(rank + 1, derive_size(basek, k)), + basek: basek, + k: k, } } } @@ -92,33 +85,56 @@ impl GLWECiphertextFourier, FFT64> { pub fn keyswitch_scratch_space( module: &Module, - res_size: usize, - lhs: usize, - rhs: usize, - rank_in: usize, - rank_out: usize, + out_size: usize, + out_rank: usize, + in_size: usize, + in_rank: usize, + ksk_size: usize, ) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_fourier_scratch_space( - module, res_size, lhs, rhs, rank_in, rank_out, - ) + let res_dft: usize = module.bytes_of_vec_znx_dft(out_rank + 1, out_size); + + let vmp = module.bytes_of_vec_znx_dft(in_rank, in_size) + + module.vmp_apply_tmp_bytes( + out_size, + in_size, + in_size, + in_rank + 1, + out_rank + 1, + ksk_size, + ); + let res_small: usize = module.bytes_of_vec_znx(out_rank + 1, out_size); + let add_a0: usize = module.bytes_of_vec_znx_big(1, in_size) + module.vec_znx_idft_tmp_bytes(); + let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); + + res_dft + (vmp | add_a0 | (res_small + normalize)) } - pub fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_fourier_inplace_scratch_space( - module, res_size, rhs, rank, - ) + pub fn keyswitch_inplace_scratch_space(module: &Module, out_size: usize, out_rank: usize, ksk_size: usize) -> usize { + Self::keyswitch_scratch_space(module, out_size, out_rank, out_size, out_rank, ksk_size) } - pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_fourier_scratch_space( - module, res_size, lhs, rhs, rank, rank, - ) + pub fn external_product_scratch_space( + module: &Module, + out_size: usize, + in_size: usize, + ggsw_size: usize, + rank: usize, + ) -> usize { + let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); + let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, rank + 1, rank + 1, ggsw_size); + let res_small: usize = module.bytes_of_vec_znx(rank + 1, out_size); + let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); + + res_dft + (vmp | (res_small + normalize)) } - pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_fourier_inplace_scratch_space( - module, res_size, rhs, rank, - ) + pub fn external_product_inplace_scratch_space( + module: &Module, + out_size: usize, + ggsw_size: usize, + rank: usize, + ) -> usize { + Self::external_product_scratch_space(module, out_size, out_size, ggsw_size, rank) } } @@ -158,7 +174,61 @@ where VecZnxDft: VecZnxDftToRef, MatZnxDft: MatZnxDftToRef, { - rhs.0.prod_with_glwe_fourier(module, self, lhs, scratch); + let basek: usize = self.basek(); + + #[cfg(debug_assertions)] + { + assert_eq!(lhs.rank(), rhs.rank_in()); + assert_eq!(self.rank(), rhs.rank_out()); + assert_eq!(self.basek(), basek); + assert_eq!(lhs.basek(), basek); + assert_eq!(rhs.n(), module.n()); + assert_eq!(self.n(), module.n()); + assert_eq!(lhs.n(), module.n()); + assert!( + scratch.available() + >= GLWECiphertextFourier::keyswitch_scratch_space( + module, + self.size(), + self.rank(), + lhs.size(), + lhs.rank(), + rhs.size(), + ) + ); + } + + let cols_in: usize = rhs.rank_in(); + let cols_out: usize = rhs.rank_out() + 1; + + // Buffer of the result of VMP in DFT + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise + + { + // Applies VMP + let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size()); + (0..cols_in).for_each(|col_i| { + module.vec_znx_dft_copy(&mut ai_dft, col_i, lhs, col_i + 1); + }); + module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2); + } + + // Switches result of VMP outside of DFT + let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft); + + { + // Switches lhs 0-th outside of DFT domain and adds on + let (mut a0_big, scratch2) = scratch1.tmp_vec_znx_big(module, 1, lhs.size()); + module.vec_znx_idft(&mut a0_big, 0, lhs, 0, scratch2); + module.vec_znx_big_add_inplace(&mut res_big, 0, &a0_big, 0); + } + + // Space fr normalized VMP result outside of DFT domain + let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols_out, lhs.size()); + (0..cols_out).for_each(|i| { + module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2); + module.vec_znx_dft(self, i, &res_small, i); + }); } pub fn keyswitch_inplace( @@ -169,7 +239,10 @@ where ) where MatZnxDft: MatZnxDftToRef, { - rhs.0.prod_with_glwe_fourier_inplace(module, self, scratch); + unsafe { + let self_ptr: *mut GLWECiphertextFourier = self as *mut GLWECiphertextFourier; + self.keyswitch(&module, &*self_ptr, rhs, scratch); + } } pub fn external_product( @@ -182,7 +255,37 @@ where VecZnxDft: VecZnxDftToRef, MatZnxDft: MatZnxDftToRef, { - rhs.prod_with_glwe_fourier(module, self, lhs, scratch); + let basek: usize = self.basek(); + + #[cfg(debug_assertions)] + { + assert_eq!(rhs.rank(), lhs.rank()); + assert_eq!(rhs.rank(), self.rank()); + assert_eq!(self.basek(), basek); + assert_eq!(lhs.basek(), basek); + assert_eq!(rhs.n(), module.n()); + assert_eq!(self.n(), module.n()); + assert_eq!(lhs.n(), module.n()); + } + + let cols: usize = rhs.rank() + 1; + + // Space for VMP result in DFT domain and high precision + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); + + { + module.vmp_apply(&mut res_dft, lhs, rhs, scratch1); + } + + // VMP result in high precision + let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft); + + // Space for VMP result normalized + let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols, rhs.size()); + (0..cols).for_each(|i| { + module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2); + module.vec_znx_dft(self, i, &res_small, i); + }); } pub fn external_product_inplace( @@ -193,7 +296,10 @@ where ) where MatZnxDft: MatZnxDftToRef, { - rhs.prod_with_glwe_fourier_inplace(module, self, scratch); + unsafe { + let self_ptr: *mut GLWECiphertextFourier = self as *mut GLWECiphertextFourier; + self.external_product(&module, &*self_ptr, rhs, scratch); + } } } @@ -247,6 +353,7 @@ where pt.k = pt.k().min(self.k()); } + #[allow(dead_code)] pub(crate) fn idft(&self, module: &Module, res: &mut GLWECiphertext, scratch: &mut Scratch) where GLWECiphertext: VecZnxToMut, diff --git a/core/src/glwe_plaintext.rs b/core/src/glwe_plaintext.rs index 75088d1..4900fa0 100644 --- a/core/src/glwe_plaintext.rs +++ b/core/src/glwe_plaintext.rs @@ -43,10 +43,10 @@ where } impl GLWEPlaintext> { - pub fn new(module: &Module, base2k: usize, k: usize) -> Self { + pub fn new(module: &Module, basek: usize, k: usize) -> Self { Self { - data: module.new_vec_znx(1, derive_size(base2k, k)), - basek: base2k, + data: module.new_vec_znx(1, derive_size(basek, k)), + basek: basek, k, } } diff --git a/core/src/keyswitch_key.rs b/core/src/keyswitch_key.rs index 37774eb..8b9f13d 100644 --- a/core/src/keyswitch_key.rs +++ b/core/src/keyswitch_key.rs @@ -1,6 +1,6 @@ use base2k::{ Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, ScalarZnxDftToRef, - ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, + ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, ZnxZero, }; use sampling::source::Source; @@ -10,7 +10,6 @@ use crate::{ ggsw_ciphertext::GGSWCiphertext, glwe_ciphertext_fourier::GLWECiphertextFourier, keys::{SecretKey, SecretKeyFourier}, - vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace}, }; pub struct GLWESwitchingKey(pub(crate) GGLWECiphertext); @@ -39,6 +38,20 @@ impl Infos for GLWESwitchingKey { } } +impl GLWESwitchingKey { + pub fn rank(&self) -> usize { + self.0.data.cols_out() - 1 + } + + pub fn rank_in(&self) -> usize { + self.0.data.cols_in() + } + + pub fn rank_out(&self) -> usize { + self.0.data.cols_out() - 1 + } +} + impl MatZnxDftToMut for GLWESwitchingKey where MatZnxDft: MatZnxDftToMut, @@ -131,33 +144,46 @@ where impl GLWESwitchingKey, FFT64> { pub fn keyswitch_scratch_space( module: &Module, - res_size: usize, - lhs: usize, - rhs: usize, - rank_in: usize, - rank_out: usize, + out_size: usize, + out_rank: usize, + in_size: usize, + in_rank: usize, + ksk_size: usize, ) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( - module, res_size, lhs, rhs, rank_in, rank_out, - ) + let tmp_in: usize = module.bytes_of_vec_znx_dft(in_rank + 1, in_size); + let tmp_out: usize = module.bytes_of_vec_znx_dft(out_rank + 1, out_size); + let ksk: usize = GLWECiphertextFourier::keyswitch_scratch_space(module, out_size, out_rank, in_size, in_rank, ksk_size); + tmp_in + tmp_out + ksk } - pub fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_inplace_scratch_space( - module, res_size, rhs, rank, - ) + pub fn keyswitch_inplace_scratch_space(module: &Module, out_size: usize, out_rank: usize, ksk_size: usize) -> usize { + let tmp: usize = module.bytes_of_vec_znx_dft(out_rank + 1, out_size); + let ksk: usize = GLWECiphertextFourier::keyswitch_inplace_scratch_space(module, out_size, out_rank, ksk_size); + tmp + ksk } - pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( - module, res_size, lhs, rhs, rank, rank, - ) + pub fn external_product_scratch_space( + module: &Module, + out_size: usize, + in_size: usize, + ggsw_size: usize, + rank: usize, + ) -> usize { + let tmp_in: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size); + let tmp_out: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); + let ggsw: usize = GLWECiphertextFourier::external_product_scratch_space(module, out_size, in_size, ggsw_size, rank); + tmp_in + tmp_out + ggsw } - pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( - module, res_size, rhs, rank, - ) + pub fn external_product_inplace_scratch_space( + module: &Module, + out_size: usize, + ggsw_size: usize, + rank: usize, + ) -> usize { + let tmp: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); + let ggsw: usize = GLWECiphertextFourier::external_product_inplace_scratch_space(module, out_size, ggsw_size, rank); + tmp + ggsw } } @@ -175,8 +201,62 @@ where MatZnxDft: MatZnxDftToRef, MatZnxDft: MatZnxDftToRef, { - rhs.0 - .prod_with_vec_glwe(module, &mut self.0, &lhs.0, scratch); + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank_in(), + lhs.rank_in(), + "ksk_out input rank: {} != ksk_in input rank: {}", + self.rank_in(), + lhs.rank_in() + ); + assert_eq!( + lhs.rank_out(), + rhs.rank_in(), + "ksk_in output rank: {} != ksk_apply input rank: {}", + self.rank_out(), + rhs.rank_in() + ); + assert_eq!( + self.rank_out(), + rhs.rank_out(), + "ksk_out output rank: {} != ksk_apply output rank: {}", + self.rank_out(), + rhs.rank_out() + ); + } + + let (tmp_in_data, scratch1) = scratch.tmp_vec_znx_dft(module, lhs.rank_out() + 1, lhs.size()); + + let mut tmp_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_in_data, + basek: lhs.basek(), + k: lhs.k(), + }; + + let (tmp_out_data, scratch2) = scratch1.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size()); + + let mut tmp_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_out_data, + basek: self.basek(), + k: self.k(), + }; + + (0..self.rank_in()).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + lhs.get_row(module, row_j, col_i, &mut tmp_in); + tmp_out.keyswitch(module, &tmp_in, rhs, scratch2); + self.set_row(module, row_j, col_i, &tmp_out); + }); + }); + + tmp_out.data.zero(); + + (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { + (0..self.rank_in()).for_each(|col_j| { + self.set_row(module, row_i, col_j, &tmp_out); + }); + }); } pub fn keyswitch_inplace( @@ -187,8 +267,32 @@ where ) where MatZnxDft: MatZnxDftToRef, { - rhs.0 - .prod_with_vec_glwe_inplace(module, &mut self.0, scratch); + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank_out(), + rhs.rank_out(), + "ksk_out output rank: {} != ksk_apply output rank: {}", + self.rank_out(), + rhs.rank_out() + ); + } + + let (tmp_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size()); + + let mut tmp: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_data, + basek: self.basek(), + k: self.k(), + }; + + (0..self.rank_in()).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + self.get_row(module, row_j, col_i, &mut tmp); + tmp.keyswitch_inplace(module, rhs, scratch1); + self.set_row(module, row_j, col_i, &tmp); + }); + }); } pub fn external_product( @@ -201,7 +305,62 @@ where MatZnxDft: MatZnxDftToRef, MatZnxDft: MatZnxDftToRef, { - rhs.prod_with_vec_glwe(module, &mut self.0, &lhs.0, scratch); + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank_in(), + lhs.rank_in(), + "ksk_out input rank: {} != ksk_in input rank: {}", + self.rank_in(), + lhs.rank_in() + ); + assert_eq!( + lhs.rank_out(), + rhs.rank(), + "ksk_in output rank: {} != ggsw rank: {}", + self.rank_out(), + rhs.rank() + ); + assert_eq!( + self.rank_out(), + rhs.rank(), + "ksk_out output rank: {} != ggsw rank: {}", + self.rank_out(), + rhs.rank() + ); + } + + let (tmp_in_data, scratch1) = scratch.tmp_vec_znx_dft(module, lhs.rank_out() + 1, lhs.size()); + + let mut tmp_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_in_data, + basek: lhs.basek(), + k: lhs.k(), + }; + + let (tmp_out_data, scratch2) = scratch1.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size()); + + let mut tmp_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_out_data, + basek: self.basek(), + k: self.k(), + }; + + (0..self.rank_in()).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + lhs.get_row(module, row_j, col_i, &mut tmp_in); + tmp_out.external_product(module, &tmp_in, rhs, scratch2); + self.set_row(module, row_j, col_i, &tmp_out); + }); + }); + + tmp_out.data.zero(); + + (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { + (0..self.rank_in()).for_each(|col_j| { + self.set_row(module, row_i, col_j, &tmp_out); + }); + }); } pub fn external_product_inplace( @@ -212,6 +371,31 @@ where ) where MatZnxDft: MatZnxDftToRef, { - rhs.prod_with_vec_glwe_inplace(module, &mut self.0, scratch); + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank_out(), + rhs.rank(), + "ksk_out output rank: {} != ggsw rank: {}", + self.rank_out(), + rhs.rank() + ); + } + + let (tmp_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size()); + + let mut tmp: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_data, + basek: self.basek(), + k: self.k(), + }; + + (0..self.rank_in()).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + self.get_row(module, row_j, col_i, &mut tmp); + tmp.external_product_inplace(module, rhs, scratch1); + self.set_row(module, row_j, col_i, &tmp); + }); + }); } } diff --git a/core/src/lib.rs b/core/src/lib.rs index 14392df..60d57c2 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -9,6 +9,5 @@ pub mod keyswitch_key; #[cfg(test)] mod test_fft64; mod utils; -pub mod vec_glwe_product; pub(crate) const SIX_SIGMA: f64 = 6.0; diff --git a/core/src/test_fft64/gglwe.rs b/core/src/test_fft64/gglwe.rs index 8327325..3ba02a0 100644 --- a/core/src/test_fft64/gglwe.rs +++ b/core/src/test_fft64/gglwe.rs @@ -1,505 +1,510 @@ -// use base2k::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, ZnxViewMut}; -// use sampling::source::Source; -// -// use crate::{ -// elem::{GetRow, Infos}, -// ggsw_ciphertext::GGSWCiphertext, -// glwe_ciphertext_fourier::GLWECiphertextFourier, -// glwe_plaintext::GLWEPlaintext, -// keys::{SecretKey, SecretKeyFourier}, -// keyswitch_key::GLWESwitchingKey, -// test_fft64::ggsw::noise_rgsw_product, -// }; -// -// #[test] -// fn encrypt_sk() { -// let module: Module = Module::::new(2048); -// let log_base2k: usize = 8; -// let log_k_ct: usize = 54; -// let rows: usize = 4; -// let rank: usize = 1; -// let rank_out: usize = 1; -// -// let sigma: f64 = 3.2; -// -// let mut ct: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, log_base2k, log_k_ct, rows, rank, rank_out); -// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); -// let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); -// -// let mut source_xs: Source = Source::new([0u8; 32]); -// let mut source_xe: Source = Source::new([0u8; 32]); -// let mut source_xa: Source = Source::new([0u8; 32]); -// -// pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); -// -// let mut scratch: ScratchOwned = ScratchOwned::new( -// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct.size()) -// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()), -// ); -// -// let mut sk: SecretKey> = SecretKey::new(&module, rank); -// sk.fill_ternary_prob(0.5, &mut source_xs); -// sk.fill_zero(); -// -// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk_dft.dft(&module, &sk); -// -// ct.encrypt_sk( -// &module, -// &pt_scalar, -// &sk_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_ct, rank); -// -// (0..ct.rows()).for_each(|row_i| { -// ct.get_row(&module, row_i, 0, &mut ct_rlwe_dft); -// ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); -// module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_scalar, 0); -// let std_pt: f64 = pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); -// assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); -// }); -// } -// -// #[test] -// fn keyswitch() { -// let module: Module = Module::::new(2048); -// let log_base2k: usize = 12; -// let log_k_grlwe: usize = 60; -// let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; -// -// let rank: usize = 1; -// -// let sigma: f64 = 3.2; -// -// let mut ct_grlwe_s0s1: GLWESwitchingKey, FFT64> = -// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); -// let mut ct_grlwe_s1s2: GLWESwitchingKey, FFT64> = -// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); -// let mut ct_grlwe_s0s2: GLWESwitchingKey, FFT64> = -// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); -// -// let mut source_xs: Source = Source::new([0u8; 32]); -// let mut source_xe: Source = Source::new([0u8; 32]); -// let mut source_xa: Source = Source::new([0u8; 32]); -// -// let mut scratch: ScratchOwned = ScratchOwned::new( -// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_s0s1.size()) -// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_s0s2.size()) -// | GLWESwitchingKey::keyswitch_scratch_space( -// &module, -// ct_grlwe_s0s2.size(), -// ct_grlwe_s0s1.size(), -// ct_grlwe_s1s2.size(), -// ), -// ); -// -// let mut sk0: SecretKey> = SecretKey::new(&module, rank); -// sk0.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk0_dft.dft(&module, &sk0); -// -// let mut sk1: SecretKey> = SecretKey::new(&module, rank); -// sk1.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk1_dft.dft(&module, &sk1); -// -// let mut sk2: SecretKey> = SecretKey::new(&module, rank); -// sk2.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk2_dft.dft(&module, &sk2); -// -// GRLWE_{s1}(s0) = s0 -> s1 -// ct_grlwe_s0s1.encrypt_sk( -// &module, -// &sk0.data, -// &sk1_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// GRLWE_{s2}(s1) -> s1 -> s2 -// ct_grlwe_s1s2.encrypt_sk( -// &module, -// &sk1.data, -// &sk2_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) -// ct_grlwe_s0s2.keyswitch(&module, &ct_grlwe_s0s1, &ct_grlwe_s1s2, scratch.borrow()); -// -// let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = -// GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe, rank); -// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); -// -// (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { -// ct_grlwe_s0s2.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); -// ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); -// module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, 0); -// -// let noise_have: f64 = pt.data.std(0, log_base2k).log2(); -// let noise_want: f64 = noise_grlwe_rlwe_product( -// module.n() as f64, -// log_base2k, -// 0.5, -// 0.5, -// 0f64, -// sigma * sigma, -// 0f64, -// log_k_grlwe, -// log_k_grlwe, -// ); -// -// assert!( -// (noise_have - noise_want).abs() <= 0.1, -// "{} {}", -// noise_have, -// noise_want -// ); -// }); -// } -// -// #[test] -// fn keyswitch_inplace() { -// let module: Module = Module::::new(2048); -// let log_base2k: usize = 12; -// let log_k_grlwe: usize = 60; -// let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; -// -// let rank: usize = 1; -// let rank_out: usize = 1; -// -// let sigma: f64 = 3.2; -// -// let mut ct_grlwe_s0s1: GLWESwitchingKey, FFT64> = -// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); -// let mut ct_grlwe_s1s2: GLWESwitchingKey, FFT64> = -// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); -// -// let mut source_xs: Source = Source::new([0u8; 32]); -// let mut source_xe: Source = Source::new([0u8; 32]); -// let mut source_xa: Source = Source::new([0u8; 32]); -// -// let mut scratch: ScratchOwned = ScratchOwned::new( -// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_s0s1.size()) -// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_s0s1.size()) -// | GLWESwitchingKey::keyswitch_inplace_scratch_space(&module, ct_grlwe_s0s1.size(), ct_grlwe_s1s2.size()), -// ); -// -// let mut sk0: SecretKey> = SecretKey::new(&module, rank); -// sk0.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk0_dft.dft(&module, &sk0); -// -// let mut sk1: SecretKey> = SecretKey::new(&module, rank); -// sk1.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk1_dft.dft(&module, &sk1); -// -// let mut sk2: SecretKey> = SecretKey::new(&module, rank); -// sk2.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk2_dft.dft(&module, &sk2); -// -// GRLWE_{s1}(s0) = s0 -> s1 -// ct_grlwe_s0s1.encrypt_sk( -// &module, -// &sk0.data, -// &sk1_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// GRLWE_{s2}(s1) -> s1 -> s2 -// ct_grlwe_s1s2.encrypt_sk( -// &module, -// &sk1.data, -// &sk2_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) -// ct_grlwe_s0s1.keyswitch_inplace(&module, &ct_grlwe_s1s2, scratch.borrow()); -// -// let ct_grlwe_s0s2: GLWESwitchingKey, FFT64> = ct_grlwe_s0s1; -// -// let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = -// GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe, rank); -// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); -// -// (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { -// ct_grlwe_s0s2.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); -// ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); -// module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, 0); -// -// let noise_have: f64 = pt.data.std(0, log_base2k).log2(); -// let noise_want: f64 = noise_grlwe_rlwe_product( -// module.n() as f64, -// log_base2k, -// 0.5, -// 0.5, -// 0f64, -// sigma * sigma, -// 0f64, -// log_k_grlwe, -// log_k_grlwe, -// ); -// -// assert!( -// (noise_have - noise_want).abs() <= 0.1, -// "{} {}", -// noise_have, -// noise_want -// ); -// }); -// } -// -// #[test] -// fn external_product() { -// let module: Module = Module::::new(2048); -// let log_base2k: usize = 12; -// let log_k_grlwe: usize = 60; -// let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; -// -// let rank: usize = 1; -// let rank_out: usize = 1; -// -// let sigma: f64 = 3.2; -// -// let mut ct_grlwe_in: GLWESwitchingKey, FFT64> = -// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); -// let mut ct_grlwe_out: GLWESwitchingKey, FFT64> = -// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); -// let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); -// -// let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); -// let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); -// -// let mut source_xs: Source = Source::new([0u8; 32]); -// let mut source_xe: Source = Source::new([0u8; 32]); -// let mut source_xa: Source = Source::new([0u8; 32]); -// -// let mut scratch: ScratchOwned = ScratchOwned::new( -// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_in.size()) -// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_out.size()) -// | GLWESwitchingKey::external_product_scratch_space( -// &module, -// ct_grlwe_out.size(), -// ct_grlwe_in.size(), -// ct_rgsw.size(), -// ) -// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()), -// ); -// -// let k: usize = 1; -// -// pt_rgsw.raw_mut()[k] = 1; // X^{k} -// -// pt_grlwe.fill_ternary_prob(0, 0.5, &mut source_xs); -// -// let mut sk: SecretKey> = SecretKey::new(&module, rank); -// sk.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk_dft.dft(&module, &sk); -// -// GRLWE_{s1}(s0) = s0 -> s1 -// ct_grlwe_in.encrypt_sk( -// &module, -// &pt_grlwe, -// &sk_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// ct_rgsw.encrypt_sk( -// &module, -// &pt_rgsw, -// &sk_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// GRLWE_(m) (x) RGSW_(X^k) = GRLWE_(m * X^k) -// ct_grlwe_out.external_product(&module, &ct_grlwe_in, &ct_rgsw, scratch.borrow()); -// -// let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = -// GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe, rank); -// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); -// -// module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); -// -// (0..ct_grlwe_out.rows()).for_each(|row_i| { -// ct_grlwe_out.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); -// ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); -// module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_grlwe, 0); -// -// let noise_have: f64 = pt.data.std(0, log_base2k).log2(); -// -// let var_gct_err_lhs: f64 = sigma * sigma; -// let var_gct_err_rhs: f64 = 0f64; -// -// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} -// let var_a0_err: f64 = sigma * sigma; -// let var_a1_err: f64 = 1f64 / 12f64; -// -// let noise_want: f64 = noise_rgsw_product( -// module.n() as f64, -// log_base2k, -// 0.5, -// var_msg, -// var_a0_err, -// var_a1_err, -// var_gct_err_lhs, -// var_gct_err_rhs, -// log_k_grlwe, -// log_k_grlwe, -// ); -// -// assert!( -// (noise_have - noise_want).abs() <= 0.1, -// "{} {}", -// noise_have, -// noise_want -// ); -// }); -// } -// -// #[test] -// fn external_product_inplace() { -// let module: Module = Module::::new(2048); -// let log_base2k: usize = 12; -// let log_k_grlwe: usize = 60; -// let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; -// -// let rank: usize = 1; -// let rank_out: usize = 1; -// -// let sigma: f64 = 3.2; -// -// let mut ct_grlwe: GLWESwitchingKey, FFT64> = -// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); -// let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); -// -// let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); -// let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); -// -// let mut source_xs: Source = Source::new([0u8; 32]); -// let mut source_xe: Source = Source::new([0u8; 32]); -// let mut source_xa: Source = Source::new([0u8; 32]); -// -// let mut scratch: ScratchOwned = ScratchOwned::new( -// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) -// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe.size()) -// | GLWESwitchingKey::external_product_inplace_scratch_space(&module, ct_grlwe.size(), ct_rgsw.size()) -// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()), -// ); -// -// let k: usize = 1; -// -// pt_rgsw.raw_mut()[k] = 1; // X^{k} -// -// pt_grlwe.fill_ternary_prob(0, 0.5, &mut source_xs); -// -// let mut sk: SecretKey> = SecretKey::new(&module, rank); -// sk.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk_dft.dft(&module, &sk); -// -// GRLWE_{s1}(s0) = s0 -> s1 -// ct_grlwe.encrypt_sk( -// &module, -// &pt_grlwe, -// &sk_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// ct_rgsw.encrypt_sk( -// &module, -// &pt_rgsw, -// &sk_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// GRLWE_(m) (x) RGSW_(X^k) = GRLWE_(m * X^k) -// ct_grlwe.external_product_inplace(&module, &ct_rgsw, scratch.borrow()); -// -// let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = -// GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe, rank); -// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); -// -// module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); -// -// (0..ct_grlwe.rows()).for_each(|row_i| { -// ct_grlwe.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); -// ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); -// module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_grlwe, 0); -// -// let noise_have: f64 = pt.data.std(0, log_base2k).log2(); -// -// let var_gct_err_lhs: f64 = sigma * sigma; -// let var_gct_err_rhs: f64 = 0f64; -// -// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} -// let var_a0_err: f64 = sigma * sigma; -// let var_a1_err: f64 = 1f64 / 12f64; -// -// let noise_want: f64 = noise_rgsw_product( -// module.n() as f64, -// log_base2k, -// 0.5, -// var_msg, -// var_a0_err, -// var_a1_err, -// var_gct_err_lhs, -// var_gct_err_rhs, -// log_k_grlwe, -// log_k_grlwe, -// ); -// -// assert!( -// (noise_have - noise_want).abs() <= 0.1, -// "{} {}", -// noise_have, -// noise_want -// ); -// }); -// } +use base2k::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, ZnxViewMut}; +use sampling::source::Source; + +use crate::{ + elem::{GetRow, Infos}, + ggsw_ciphertext::GGSWCiphertext, + glwe_ciphertext_fourier::GLWECiphertextFourier, + glwe_plaintext::GLWEPlaintext, + keys::{SecretKey, SecretKeyFourier}, + keyswitch_key::GLWESwitchingKey, +}; + +#[test] +fn encrypt_sk() { + (1..4).for_each(|rank_in| { + (1..4).for_each(|rank_out| { + println!("test encrypt_sk rank_in rank_out: {} {}", rank_in, rank_out); + test_encrypt_sk(11, 8, 54, 3.2, rank_in, rank_out); + }); + }); +} + +fn test_encrypt_sk(log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank_in: usize, rank_out: usize) { + let module: Module = Module::::new(1 << log_n); + let rows = (k_ksk + basek - 1) / basek; + + let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank_in, rank_out); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ksk); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_out, ksk.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ksk.size()), + ); + + let mut sk_in: SecretKey> = SecretKey::new(&module, rank_in); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_in_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_in); + sk_in_dft.dft(&module, &sk_in); + + let mut sk_out: SecretKey> = SecretKey::new(&module, rank_out); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_out_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_out); + sk_out_dft.dft(&module, &sk_out); + + ksk.encrypt_sk( + &module, + &sk_in, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut ct_gglwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ksk, rank_out); + + (0..ksk.rank_in()).for_each(|col_i| { + (0..ksk.rows()).for_each(|row_i| { + ksk.get_row(&module, row_i, 0, &mut ct_gglwe_fourier); + ct_gglwe_fourier.decrypt(&module, &mut pt, &sk_out_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk_in, col_i); + let std_pt: f64 = pt.data.std(0, basek) * (k_ksk as f64).exp2(); + assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); + }); + }); +} + +#[test] +fn keyswitch() { + let module: Module = Module::::new(2048); + let basek: usize = 12; + let log_k_grlwe: usize = 60; + let rows: usize = (log_k_grlwe + basek - 1) / basek; + + let rank: usize = 1; + + let sigma: f64 = 3.2; + + let mut ct_grlwe_s0s1: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank); + let mut ct_grlwe_s1s2: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank); + let mut ct_grlwe_s0s2: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_s0s1.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_s0s2.size()) + | GLWESwitchingKey::keyswitch_scratch_space( + &module, + ct_grlwe_s0s2.size(), + ct_grlwe_s0s1.size(), + ct_grlwe_s1s2.size(), + ), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module, rank); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module, rank); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk1_dft.dft(&module, &sk1); + + let mut sk2: SecretKey> = SecretKey::new(&module, rank); + sk2.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk2_dft.dft(&module, &sk2); + + // GRLWE_{s1}(s0) = s0 -> s1 + ct_grlwe_s0s1.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + // GRLWE_{s2}(s1) -> s1 -> s2 + ct_grlwe_s1s2.encrypt_sk( + &module, + &sk1.data, + &sk2_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + // GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) + ct_grlwe_s0s2.keyswitch(&module, &ct_grlwe_s0s1, &ct_grlwe_s1s2, scratch.borrow()); + + let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, basek, log_k_grlwe, rank); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_grlwe); + + (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { + ct_grlwe_s0s2.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); + ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, 0); + + let noise_have: f64 = pt.data.std(0, basek).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + basek, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_grlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); +} + +#[test] +fn keyswitch_inplace() { + let module: Module = Module::::new(2048); + let basek: usize = 12; + let log_k_grlwe: usize = 60; + let rows: usize = (log_k_grlwe + basek - 1) / basek; + + let rank: usize = 1; + let rank_out: usize = 1; + + let sigma: f64 = 3.2; + + let mut ct_grlwe_s0s1: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank_out); + let mut ct_grlwe_s1s2: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_s0s1.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_s0s1.size()) + | GLWESwitchingKey::keyswitch_inplace_scratch_space(&module, ct_grlwe_s0s1.size(), ct_grlwe_s1s2.size()), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module, rank); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module, rank); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk1_dft.dft(&module, &sk1); + + let mut sk2: SecretKey> = SecretKey::new(&module, rank); + sk2.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk2_dft.dft(&module, &sk2); + + // GRLWE_{s1}(s0) = s0 -> s1 + ct_grlwe_s0s1.encrypt_sk( + &module, + &sk0.data, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + // GRLWE_{s2}(s1) -> s1 -> s2 + ct_grlwe_s1s2.encrypt_sk( + &module, + &sk1.data, + &sk2_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + // GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) + ct_grlwe_s0s1.keyswitch_inplace(&module, &ct_grlwe_s1s2, scratch.borrow()); + + let ct_grlwe_s0s2: GLWESwitchingKey, FFT64> = ct_grlwe_s0s1; + + let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, basek, log_k_grlwe, rank); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_grlwe); + + (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { + ct_grlwe_s0s2.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); + ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, 0); + + let noise_have: f64 = pt.data.std(0, basek).log2(); + let noise_want: f64 = noise_grlwe_rlwe_product( + module.n() as f64, + basek, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + log_k_grlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); +} + +#[test] +fn external_product() { + let module: Module = Module::::new(2048); + let basek: usize = 12; + let log_k_grlwe: usize = 60; + let rows: usize = (log_k_grlwe + basek - 1) / basek; + + let rank: usize = 1; + let rank_out: usize = 1; + + let sigma: f64 = 3.2; + + let mut ct_grlwe_in: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank_out); + let mut ct_grlwe_out: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank_out); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, log_k_grlwe, rows, rank); + + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_in.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_out.size()) + | GLWESwitchingKey::external_product_scratch_space( + &module, + ct_grlwe_out.size(), + ct_grlwe_in.size(), + ct_rgsw.size(), + ) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()), + ); + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + pt_grlwe.fill_ternary_prob(0, 0.5, &mut source_xs); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + // GRLWE_{s1}(s0) = s0 -> s1 + ct_grlwe_in.encrypt_sk( + &module, + &pt_grlwe, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + // GRLWE_(m) (x) RGSW_(X^k) = GRLWE_(m * X^k) + ct_grlwe_out.external_product(&module, &ct_grlwe_in, &ct_rgsw, scratch.borrow()); + + let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, basek, log_k_grlwe, rank); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_grlwe); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); + + (0..ct_grlwe_out.rows()).for_each(|row_i| { + ct_grlwe_out.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); + ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_grlwe, 0); + + let noise_have: f64 = pt.data.std(0, basek).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_rgsw_product( + module.n() as f64, + basek, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + log_k_grlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); +} + +#[test] +fn external_product_inplace() { + let module: Module = Module::::new(2048); + let basek: usize = 12; + let log_k_grlwe: usize = 60; + let rows: usize = (log_k_grlwe + basek - 1) / basek; + + let rank: usize = 1; + let rank_out: usize = 1; + + let sigma: f64 = 3.2; + + let mut ct_grlwe: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank_out); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, log_k_grlwe, rows, rank); + + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe.size()) + | GLWESwitchingKey::external_product_inplace_scratch_space(&module, ct_grlwe.size(), ct_rgsw.size()) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()), + ); + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + pt_grlwe.fill_ternary_prob(0, 0.5, &mut source_xs); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + // GRLWE_{s1}(s0) = s0 -> s1 + ct_grlwe.encrypt_sk( + &module, + &pt_grlwe, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + // GRLWE_(m) (x) RGSW_(X^k) = GRLWE_(m * X^k) + ct_grlwe.external_product_inplace(&module, &ct_rgsw, scratch.borrow()); + + let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, basek, log_k_grlwe, rank); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_grlwe); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); + + (0..ct_grlwe.rows()).for_each(|row_i| { + ct_grlwe.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); + ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_grlwe, 0); + + let noise_have: f64 = pt.data.std(0, basek).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_rgsw_product( + module.n() as f64, + basek, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + log_k_grlwe, + log_k_grlwe, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); +} pub(crate) fn noise_gglwe_product( n: f64, - log_base2k: usize, + basek: usize, var_xs: f64, var_msg: f64, var_a_err: f64, @@ -510,12 +515,12 @@ pub(crate) fn noise_gglwe_product( b_logq: usize, ) -> f64 { let a_logq: usize = a_logq.min(b_logq); - let a_cols: usize = (a_logq + log_base2k - 1) / log_base2k; + let a_cols: usize = (a_logq + basek - 1) / basek; let b_scale = 2.0f64.powi(b_logq as i32); let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32); - let base: f64 = (1 << (log_base2k)) as f64; + let base: f64 = (1 << (basek)) as f64; let var_base: f64 = base * base / 12f64; // lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2) diff --git a/core/src/test_fft64/ggsw.rs b/core/src/test_fft64/ggsw.rs index eb8c532..cf34dda 100644 --- a/core/src/test_fft64/ggsw.rs +++ b/core/src/test_fft64/ggsw.rs @@ -1,575 +1,345 @@ -// use base2k::{ -// FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, -// VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, ZnxViewMut, ZnxZero, -// }; -// use sampling::source::Source; -// -// use crate::{ -// elem::{GetRow, Infos}, -// ggsw_ciphertext::GGSWCiphertext, -// glwe_ciphertext_fourier::GLWECiphertextFourier, -// glwe_plaintext::GLWEPlaintext, -// keys::{SecretKey, SecretKeyFourier}, -// keyswitch_key::GLWESwitchingKey, -// test_fft64::gglwe::noise_grlwe_rlwe_product, -// }; -// -// #[test] -// fn encrypt_sk() { -// let module: Module = Module::::new(2048); -// let log_base2k: usize = 8; -// let log_k_ct: usize = 54; -// let rows: usize = 4; -// let rank: usize = 1; -// -// let sigma: f64 = 3.2; -// -// let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_ct, rows, rank); -// let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); -// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); -// let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); -// -// let mut source_xs: Source = Source::new([0u8; 32]); -// let mut source_xe: Source = Source::new([0u8; 32]); -// let mut source_xa: Source = Source::new([0u8; 32]); -// -// pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); -// -// let mut scratch: ScratchOwned = ScratchOwned::new( -// GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct.size()) -// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()), -// ); -// -// let mut sk: SecretKey> = SecretKey::new(&module, rank); -// sk.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk_dft.dft(&module, &sk); -// -// ct.encrypt_sk( -// &module, -// &pt_scalar, -// &sk_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_ct, rank); -// let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); -// let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); -// -// (0..ct.rank()).for_each(|col_j| { -// (0..ct.rows()).for_each(|row_i| { -// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0); -// -// if col_j == 1 { -// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); -// module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0); -// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); -// module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); -// } -// -// ct.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); -// -// ct_rlwe_dft.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); -// -// module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); -// -// let std_pt: f64 = pt_have.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); -// assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); -// -// pt_want.data.zero(); -// }); -// }); -// } -// -// #[test] -// fn keyswitch() { -// let module: Module = Module::::new(2048); -// let log_base2k: usize = 12; -// let log_k_grlwe: usize = 60; -// let log_k_rgsw_in: usize = 45; -// let log_k_rgsw_out: usize = 45; -// let rows: usize = (log_k_rgsw_in + log_base2k - 1) / log_base2k; -// -// let rank: usize = 1; -// -// let sigma: f64 = 3.2; -// -// let mut ct_grlwe: GLWESwitchingKey, FFT64> = -// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); -// let mut ct_rgsw_in: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_in, rows, rank); -// let mut ct_rgsw_out: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_out, rows, rank); -// let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); -// -// let mut source_xs: Source = Source::new([0u8; 32]); -// let mut source_xe: Source = Source::new([0u8; 32]); -// let mut source_xa: Source = Source::new([0u8; 32]); -// -// Random input plaintext -// pt_rgsw.fill_ternary_prob(0, 0.5, &mut source_xs); -// -// let mut scratch: ScratchOwned = ScratchOwned::new( -// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) -// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_out.size()) -// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw_in.size()) -// | GGSWCiphertext::keyswitch_scratch_space( -// &module, -// ct_rgsw_out.size(), -// ct_rgsw_in.size(), -// ct_grlwe.size(), -// ), -// ); -// -// let mut sk0: SecretKey> = SecretKey::new(&module, rank); -// sk0.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk0_dft.dft(&module, &sk0); -// -// let mut sk1: SecretKey> = SecretKey::new(&module, rank); -// sk1.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk1_dft.dft(&module, &sk1); -// -// ct_grlwe.encrypt_sk( -// &module, -// &sk0.data, -// &sk1_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// ct_rgsw_in.encrypt_sk( -// &module, -// &pt_rgsw, -// &sk0_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// ct_rgsw_out.keyswitch(&module, &ct_rgsw_in, &ct_grlwe, scratch.borrow()); -// -// let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = -// GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_out, rank); -// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_out); -// let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_out.size()); -// let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_out.size()); -// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_out); -// -// (0..ct_rgsw_out.rank()).for_each(|col_j| { -// (0..ct_rgsw_out.rows()).for_each(|row_i| { -// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw, 0); -// -// if col_j == 1 { -// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); -// module.svp_apply_inplace(&mut pt_dft, 0, &sk0_dft, 0); -// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); -// module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); -// } -// -// ct_rgsw_out.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); -// ct_rlwe_dft.decrypt(&module, &mut pt, &sk1_dft, scratch.borrow()); -// -// module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); -// -// let noise_have: f64 = pt.data.std(0, log_base2k).log2(); -// let noise_want: f64 = noise_grlwe_rlwe_product( -// module.n() as f64, -// log_base2k, -// 0.5, -// 0.5, -// 0f64, -// sigma * sigma, -// 0f64, -// log_k_grlwe, -// log_k_grlwe, -// ); -// -// assert!( -// (noise_have - noise_want).abs() <= 0.2, -// "have: {} want: {}", -// noise_have, -// noise_want -// ); -// -// pt_want.data.zero(); -// }); -// }); -// } -// -// #[test] -// fn keyswitch_inplace() { -// let module: Module = Module::::new(2048); -// let log_base2k: usize = 12; -// let log_k_grlwe: usize = 60; -// let log_k_rgsw: usize = 45; -// let rows: usize = (log_k_rgsw + log_base2k - 1) / log_base2k; -// let rank: usize = 1; -// -// let sigma: f64 = 3.2; -// -// let mut ct_grlwe: GLWESwitchingKey, FFT64> = -// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); -// let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw, rows, rank); -// let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); -// -// let mut source_xs: Source = Source::new([0u8; 32]); -// let mut source_xe: Source = Source::new([0u8; 32]); -// let mut source_xa: Source = Source::new([0u8; 32]); -// -// Random input plaintext -// pt_rgsw.fill_ternary_prob(0, 0.5, &mut source_xs); -// -// let mut scratch: ScratchOwned = ScratchOwned::new( -// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) -// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw.size()) -// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) -// | GGSWCiphertext::keyswitch_inplace_scratch_space(&module, ct_rgsw.size(), ct_grlwe.size()), -// ); -// -// let mut sk0: SecretKey> = SecretKey::new(&module, rank); -// sk0.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk0_dft.dft(&module, &sk0); -// -// let mut sk1: SecretKey> = SecretKey::new(&module, rank); -// sk1.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk1_dft.dft(&module, &sk1); -// -// ct_grlwe.encrypt_sk( -// &module, -// &sk0.data, -// &sk1_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// ct_rgsw.encrypt_sk( -// &module, -// &pt_rgsw, -// &sk0_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// ct_rgsw.keyswitch_inplace(&module, &ct_grlwe, scratch.borrow()); -// -// let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = -// GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw, rank); -// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw); -// let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw.size()); -// let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw.size()); -// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw); -// -// (0..ct_rgsw.rank()).for_each(|col_j| { -// (0..ct_rgsw.rows()).for_each(|row_i| { -// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw, 0); -// -// if col_j == 1 { -// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); -// module.svp_apply_inplace(&mut pt_dft, 0, &sk0_dft, 0); -// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); -// module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); -// } -// -// ct_rgsw.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); -// ct_rlwe_dft.decrypt(&module, &mut pt, &sk1_dft, scratch.borrow()); -// -// module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); -// -// let noise_have: f64 = pt.data.std(0, log_base2k).log2(); -// let noise_want: f64 = noise_grlwe_rlwe_product( -// module.n() as f64, -// log_base2k, -// 0.5, -// 0.5, -// 0f64, -// sigma * sigma, -// 0f64, -// log_k_grlwe, -// log_k_grlwe, -// ); -// -// assert!( -// (noise_have - noise_want).abs() <= 0.2, -// "have: {} want: {}", -// noise_have, -// noise_want -// ); -// -// pt_want.data.zero(); -// }); -// }); -// } -// -// #[test] -// fn external_product() { -// let module: Module = Module::::new(2048); -// let log_base2k: usize = 12; -// let log_k_rgsw_rhs: usize = 60; -// let log_k_rgsw_lhs_in: usize = 45; -// let log_k_rgsw_lhs_out: usize = 45; -// let rows: usize = (log_k_rgsw_lhs_in + log_base2k - 1) / log_base2k; -// let rank: usize = 1; -// -// let sigma: f64 = 3.2; -// -// let mut ct_rgsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_rhs, rows, rank); -// let mut ct_rgsw_lhs_in: GGSWCiphertext, FFT64> = -// GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs_in, rows, rank); -// let mut ct_rgsw_lhs_out: GGSWCiphertext, FFT64> = -// GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs_out, rows, rank); -// let mut pt_rgsw_lhs: ScalarZnx> = module.new_scalar_znx(1); -// let mut pt_rgsw_rhs: ScalarZnx> = module.new_scalar_znx(1); -// -// let mut source_xs: Source = Source::new([0u8; 32]); -// let mut source_xe: Source = Source::new([0u8; 32]); -// let mut source_xa: Source = Source::new([0u8; 32]); -// -// Random input plaintext -// pt_rgsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs); -// -// let k: usize = 1; -// -// pt_rgsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} -// -// let mut scratch: ScratchOwned = ScratchOwned::new( -// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_rgsw_rhs.size()) -// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_lhs_out.size()) -// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw_lhs_in.size()) -// | GGSWCiphertext::external_product_scratch_space( -// &module, -// ct_rgsw_lhs_out.size(), -// ct_rgsw_lhs_in.size(), -// ct_rgsw_rhs.size(), -// ), -// ); -// -// let mut sk: SecretKey> = SecretKey::new(&module, rank); -// sk.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk_dft.dft(&module, &sk); -// -// ct_rgsw_rhs.encrypt_sk( -// &module, -// &pt_rgsw_rhs, -// &sk_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// ct_rgsw_lhs_in.encrypt_sk( -// &module, -// &pt_rgsw_lhs, -// &sk_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// ct_rgsw_lhs_out.external_product(&module, &ct_rgsw_lhs_in, &ct_rgsw_rhs, scratch.borrow()); -// -// let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = -// GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_lhs_out, rank); -// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs_out); -// let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_lhs_out.size()); -// let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_lhs_out.size()); -// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs_out); -// -// module.vec_znx_rotate_inplace(k as i64, &mut pt_rgsw_lhs, 0); -// -// (0..ct_rgsw_lhs_out.rank()).for_each(|col_j| { -// (0..ct_rgsw_lhs_out.rows()).for_each(|row_i| { -// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw_lhs, 0); -// -// if col_j == 1 { -// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); -// module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0); -// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); -// module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); -// } -// -// ct_rgsw_lhs_out.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); -// ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); -// -// module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); -// -// let noise_have: f64 = pt.data.std(0, log_base2k).log2(); -// -// let var_gct_err_lhs: f64 = sigma * sigma; -// let var_gct_err_rhs: f64 = 0f64; -// -// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} -// let var_a0_err: f64 = sigma * sigma; -// let var_a1_err: f64 = 1f64 / 12f64; -// -// let noise_want: f64 = noise_rgsw_product( -// module.n() as f64, -// log_base2k, -// 0.5, -// var_msg, -// var_a0_err, -// var_a1_err, -// var_gct_err_lhs, -// var_gct_err_rhs, -// log_k_rgsw_lhs_in, -// log_k_rgsw_rhs, -// ); -// -// assert!( -// (noise_have - noise_want).abs() <= 0.1, -// "have: {} want: {}", -// noise_have, -// noise_want -// ); -// -// pt_want.data.zero(); -// }); -// }); -// } -// -// #[test] -// fn external_product_inplace() { -// let module: Module = Module::::new(2048); -// let log_base2k: usize = 12; -// let log_k_rgsw_rhs: usize = 60; -// let log_k_rgsw_lhs: usize = 45; -// let rows: usize = (log_k_rgsw_lhs + log_base2k - 1) / log_base2k; -// let rank: usize = 1; -// -// let sigma: f64 = 3.2; -// -// let mut ct_rgsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_rhs, rows, rank); -// let mut ct_rgsw_lhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs, rows, rank); -// let mut pt_rgsw_lhs: ScalarZnx> = module.new_scalar_znx(1); -// let mut pt_rgsw_rhs: ScalarZnx> = module.new_scalar_znx(1); -// -// let mut source_xs: Source = Source::new([0u8; 32]); -// let mut source_xe: Source = Source::new([0u8; 32]); -// let mut source_xa: Source = Source::new([0u8; 32]); -// -// Random input plaintext -// pt_rgsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs); -// -// let k: usize = 1; -// -// pt_rgsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} -// -// let mut scratch: ScratchOwned = ScratchOwned::new( -// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_rgsw_rhs.size()) -// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_lhs.size()) -// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw_lhs.size()) -// | GGSWCiphertext::external_product_inplace_scratch_space(&module, ct_rgsw_lhs.size(), ct_rgsw_rhs.size()), -// ); -// -// let mut sk: SecretKey> = SecretKey::new(&module, rank); -// sk.fill_ternary_prob(0.5, &mut source_xs); -// -// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); -// sk_dft.dft(&module, &sk); -// -// ct_rgsw_rhs.encrypt_sk( -// &module, -// &pt_rgsw_rhs, -// &sk_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// ct_rgsw_lhs.encrypt_sk( -// &module, -// &pt_rgsw_lhs, -// &sk_dft, -// &mut source_xa, -// &mut source_xe, -// sigma, -// scratch.borrow(), -// ); -// -// ct_rgsw_lhs.external_product_inplace(&module, &ct_rgsw_rhs, scratch.borrow()); -// -// let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = -// GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_lhs, rank); -// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs); -// let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_lhs.size()); -// let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_lhs.size()); -// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs); -// -// module.vec_znx_rotate_inplace(k as i64, &mut pt_rgsw_lhs, 0); -// -// (0..ct_rgsw_lhs.rank()).for_each(|col_j| { -// (0..ct_rgsw_lhs.rows()).for_each(|row_i| { -// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw_lhs, 0); -// -// if col_j == 1 { -// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); -// module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0); -// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); -// module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); -// } -// -// ct_rgsw_lhs.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); -// ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); -// -// module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); -// -// let noise_have: f64 = pt.data.std(0, log_base2k).log2(); -// -// let var_gct_err_lhs: f64 = sigma * sigma; -// let var_gct_err_rhs: f64 = 0f64; -// -// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} -// let var_a0_err: f64 = sigma * sigma; -// let var_a1_err: f64 = 1f64 / 12f64; -// -// let noise_want: f64 = noise_rgsw_product( -// module.n() as f64, -// log_base2k, -// 0.5, -// var_msg, -// var_a0_err, -// var_a1_err, -// var_gct_err_lhs, -// var_gct_err_rhs, -// log_k_rgsw_lhs, -// log_k_rgsw_rhs, -// ); -// -// assert!( -// (noise_have - noise_want).abs() <= 0.1, -// "have: {} want: {}", -// noise_have, -// noise_want -// ); -// -// pt_want.data.zero(); -// }); -// }); -// } -pub(crate) fn noise_ggsw_gglwe_product( +use base2k::{ + FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, + VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, ZnxViewMut, ZnxZero, +}; +use sampling::source::Source; + +use crate::{ + elem::{GetRow, Infos}, + ggsw_ciphertext::GGSWCiphertext, + glwe_ciphertext_fourier::GLWECiphertextFourier, + glwe_plaintext::GLWEPlaintext, + keys::{SecretKey, SecretKeyFourier}, + keyswitch_key::GLWESwitchingKey, +}; + +#[test] +fn encrypt_sk() { + (1..4).for_each(|rank| { + println!("test encrypt_sk rank: {}", rank); + test_encrypt_sk(11, 8, 54, 3.2, rank); + }); +} + +#[test] +fn external_product() { + (1..4).for_each(|rank| { + println!("test external_product rank: {}", rank); + test_external_product(12, 12, 60, rank, 3.2); + }); +} + +#[test] +fn external_product_inplace() { + (1..4).for_each(|rank| { + println!("test external_product rank: {}", rank); + test_external_product_inplace(12, 15, 60, rank, 3.2); + }); +} + +fn test_encrypt_sk(log_n: usize, basek: usize, k_ggsw: usize, sigma: f64, rank: usize) { + let module: Module = Module::::new(1 << log_n); + + let rows: usize = (k_ggsw + basek - 1) / basek; + + let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ggsw); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ggsw); + let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()), + ); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + ct.encrypt_sk( + &module, + &pt_scalar, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ggsw, rank); + let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); + let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); + + (0..ct.rank() + 1).for_each(|col_j| { + (0..ct.rows()).for_each(|row_i| { + module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0); + + // mul with sk[col_j-1] + if col_j > 0 { + module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1); + module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); + } + + ct.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); + + ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let std_pt: f64 = pt_have.data.std(0, basek) * (k_ggsw as f64).exp2(); + assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); + + pt_want.data.zero(); + }); + }); +} + +fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + + let rows: usize = (k_ggsw + basek - 1) / basek; + + let mut ct_ggsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank); + let mut ct_ggsw_lhs_in: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank); + let mut ct_ggsw_lhs_out: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank); + let mut pt_ggsw_lhs: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_ggsw_rhs: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + pt_ggsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs); + + let k: usize = 1; + + pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_ggsw_rhs.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_lhs_out.size()) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw_lhs_in.size()) + | GGSWCiphertext::external_product_scratch_space( + &module, + ct_ggsw_lhs_out.size(), + ct_ggsw_lhs_in.size(), + ct_ggsw_rhs.size(), + rank, + ), + ); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + ct_ggsw_rhs.encrypt_sk( + &module, + &pt_ggsw_rhs, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_ggsw_lhs_in.encrypt_sk( + &module, + &pt_ggsw_lhs, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_ggsw_lhs_out.external_product(&module, &ct_ggsw_lhs_in, &ct_ggsw_rhs, scratch.borrow()); + + let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ggsw, rank); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ggsw); + let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_ggsw_lhs_out.size()); + let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_ggsw_lhs_out.size()); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ggsw); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_ggsw_lhs, 0); + + (0..ct_ggsw_lhs_out.rank() + 1).for_each(|col_j| { + (0..ct_ggsw_lhs_out.rows()).for_each(|row_i| { + module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_ggsw_lhs, 0); + + if col_j > 0 { + module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1); + module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); + } + + ct_ggsw_lhs_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); + ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); + + let noise_have: f64 = pt.data.std(0, basek).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_ggsw_product( + module.n() as f64, + basek, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank as f64, + k_ggsw, + k_ggsw, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "have: {} want: {}", + noise_have, + noise_want + ); + + pt_want.data.zero(); + }); + }); +} + +fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = (k_ggsw + basek - 1) / basek; + + let mut ct_ggsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank); + let mut ct_ggsw_lhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank); + let mut pt_ggsw_lhs: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_ggsw_rhs: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + pt_ggsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs); + + let k: usize = 1; + + pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_ggsw_rhs.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_lhs.size()) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw_lhs.size()) + | GGSWCiphertext::external_product_inplace_scratch_space(&module, ct_ggsw_lhs.size(), ct_ggsw_rhs.size(), rank), + ); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + ct_ggsw_rhs.encrypt_sk( + &module, + &pt_ggsw_rhs, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_ggsw_lhs.encrypt_sk( + &module, + &pt_ggsw_lhs, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_ggsw_lhs.external_product_inplace(&module, &ct_ggsw_rhs, scratch.borrow()); + + let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ggsw, rank); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ggsw); + let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_ggsw_lhs.size()); + let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_ggsw_lhs.size()); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ggsw); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_ggsw_lhs, 0); + + (0..ct_ggsw_lhs.rank() + 1).for_each(|col_j| { + (0..ct_ggsw_lhs.rows()).for_each(|row_i| { + module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_ggsw_lhs, 0); + + if col_j > 0 { + module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1); + module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); + } + + ct_ggsw_lhs.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); + ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); + + let noise_have: f64 = pt.data.std(0, basek).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_ggsw_product( + module.n() as f64, + basek, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank as f64, + k_ggsw, + k_ggsw, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "have: {} want: {}", + noise_have, + noise_want + ); + + pt_want.data.zero(); + }); + }); +} +pub(crate) fn noise_ggsw_product( n: f64, - log_base2k: usize, + basek: usize, var_xs: f64, var_msg: f64, var_a0_err: f64, @@ -581,12 +351,12 @@ pub(crate) fn noise_ggsw_gglwe_product( b_logq: usize, ) -> f64 { let a_logq: usize = a_logq.min(b_logq); - let a_cols: usize = (a_logq + log_base2k - 1) / log_base2k; + let a_cols: usize = (a_logq + basek - 1) / basek; let b_scale = 2.0f64.powi(b_logq as i32); let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32); - let base: f64 = (1 << (log_base2k)) as f64; + let base: f64 = (1 << (basek)) as f64; let var_base: f64 = base * base / 12f64; // lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2) diff --git a/core/src/test_fft64/glwe.rs b/core/src/test_fft64/glwe.rs index 21bae6d..2d83791 100644 --- a/core/src/test_fft64/glwe.rs +++ b/core/src/test_fft64/glwe.rs @@ -13,7 +13,7 @@ use crate::{ glwe_plaintext::GLWEPlaintext, keys::{GLWEPublicKey, SecretKey, SecretKeyFourier}, keyswitch_key::GLWESwitchingKey, - test_fft64::{gglwe::noise_gglwe_product, ggsw::noise_ggsw_gglwe_product}, + test_fft64::{gglwe::noise_gglwe_product, ggsw::noise_ggsw_product}, }; #[test] @@ -498,7 +498,7 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usi let var_a0_err: f64 = sigma * sigma; let var_a1_err: f64 = 1f64 / 12f64; - let noise_want: f64 = noise_ggsw_gglwe_product( + let noise_want: f64 = noise_ggsw_product( module.n() as f64, basek, 0.5, @@ -595,7 +595,7 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, k_ct let var_a0_err: f64 = sigma * sigma; let var_a1_err: f64 = 1f64 / 12f64; - let noise_want: f64 = noise_ggsw_gglwe_product( + let noise_want: f64 = noise_ggsw_product( module.n() as f64, basek, 0.5, diff --git a/core/src/test_fft64/glwe_fourier.rs b/core/src/test_fft64/glwe_fourier.rs index d5ed622..c737c55 100644 --- a/core/src/test_fft64/glwe_fourier.rs +++ b/core/src/test_fft64/glwe_fourier.rs @@ -6,7 +6,7 @@ use crate::{ glwe_plaintext::GLWEPlaintext, keys::{SecretKey, SecretKeyFourier}, keyswitch_key::GLWESwitchingKey, - test_fft64::{gglwe::noise_gglwe_product, ggsw::noise_ggsw_gglwe_product}, + test_fft64::{gglwe::noise_gglwe_product, ggsw::noise_ggsw_product}, }; use base2k::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, ZnxViewMut}; use sampling::source::Source; @@ -322,7 +322,7 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usi let var_a0_err: f64 = sigma * sigma; let var_a1_err: f64 = 1f64 / 12f64; - let noise_want: f64 = noise_ggsw_gglwe_product( + let noise_want: f64 = noise_ggsw_product( module.n() as f64, basek, 0.5, @@ -422,7 +422,7 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, k_ct let var_a0_err: f64 = sigma * sigma; let var_a1_err: f64 = 1f64 / 12f64; - let noise_want: f64 = noise_ggsw_gglwe_product( + let noise_want: f64 = noise_ggsw_product( module.n() as f64, basek, 0.5, diff --git a/core/src/utils.rs b/core/src/utils.rs index 0bb0b45..c3bc5d5 100644 --- a/core/src/utils.rs +++ b/core/src/utils.rs @@ -1,3 +1,3 @@ -pub(crate) fn derive_size(log_base2k: usize, log_k: usize) -> usize { - (log_k + log_base2k - 1) / log_base2k +pub(crate) fn derive_size(basek: usize, k: usize) -> usize { + (k + basek - 1) / basek } diff --git a/core/src/vec_glwe_product.rs b/core/src/vec_glwe_product.rs deleted file mode 100644 index 08afa1e..0000000 --- a/core/src/vec_glwe_product.rs +++ /dev/null @@ -1,218 +0,0 @@ -use base2k::{ - FFT64, Module, Scratch, VecZnx, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, - VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero, -}; - -use crate::{ - elem::{GetRow, Infos, SetRow}, - glwe_ciphertext::GLWECiphertext, - glwe_ciphertext_fourier::GLWECiphertextFourier, -}; - -pub(crate) trait VecGLWEProductScratchSpace { - fn prod_with_glwe_scratch_space( - module: &Module, - res_size: usize, - lhs: usize, - rhs: usize, - rank_in: usize, - rank_out: usize, - ) -> usize; - - fn prod_with_glwe_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { - Self::prod_with_glwe_scratch_space(module, res_size, res_size, rhs, rank, rank) - } - - fn prod_with_glwe_fourier_scratch_space( - module: &Module, - res_size: usize, - lhs: usize, - rhs: usize, - rank_in: usize, - rank_out: usize, - ) -> usize { - (Self::prod_with_glwe_scratch_space(module, res_size, lhs, rhs, rank_in, rank_out) | module.vec_znx_idft_tmp_bytes()) - + module.bytes_of_vec_znx(rank_in + 1, lhs) - + module.bytes_of_vec_znx(rank_out + 1, res_size) - } - - fn prod_with_glwe_fourier_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { - (Self::prod_with_glwe_inplace_scratch_space(module, res_size, rhs, rank) | module.vec_znx_idft_tmp_bytes()) - + module.bytes_of_vec_znx(rank + 1, res_size) - } - - fn prod_with_vec_glwe_scratch_space( - module: &Module, - res_size: usize, - lhs: usize, - rhs: usize, - rank_in: usize, - rank_out: usize, - ) -> usize { - Self::prod_with_glwe_fourier_scratch_space(module, res_size, lhs, rhs, rank_in, rank_out) - + module.bytes_of_vec_znx_dft(rank_in + 1, lhs) - + module.bytes_of_vec_znx_dft(rank_out + 1, res_size) - } - - fn prod_with_vec_glwe_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { - Self::prod_with_glwe_fourier_inplace_scratch_space(module, res_size, rhs, rank) - + module.bytes_of_vec_znx_dft(rank + 1, res_size) - } -} - -pub(crate) trait VecGLWEProduct: Infos { - fn prod_with_glwe( - &self, - module: &Module, - res: &mut GLWECiphertext, - a: &GLWECiphertext, - scratch: &mut Scratch, - ) where - VecZnx: VecZnxToMut, - VecZnx: VecZnxToRef; - - fn prod_with_glwe_inplace(&self, module: &Module, res: &mut GLWECiphertext, scratch: &mut Scratch) - where - VecZnx: VecZnxToMut + VecZnxToRef, - { - unsafe { - let res_ptr: *mut GLWECiphertext = res as *mut GLWECiphertext; // This is ok because [Self::mul_rlwe] only updates res at the end. - self.prod_with_glwe(&module, &mut *res_ptr, &*res_ptr, scratch); - } - } - - fn prod_with_glwe_fourier( - &self, - module: &Module, - res: &mut GLWECiphertextFourier, - a: &GLWECiphertextFourier, - scratch: &mut Scratch, - ) where - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, - VecZnxDft: VecZnxDftToRef + ZnxInfos, - { - let log_base2k: usize = self.basek(); - - #[cfg(debug_assertions)] - { - assert_eq!(res.basek(), log_base2k); - assert_eq!(self.n(), module.n()); - assert_eq!(res.n(), module.n()); - } - - let (a_data, scratch_1) = scratch.tmp_vec_znx(module, a.rank() + 1, a.size()); - - let mut a_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { - data: a_data, - basek: a.basek(), - k: a.k(), - }; - - a.idft(module, &mut a_idft, scratch_1); - - let (res_data, scratch_2) = scratch_1.tmp_vec_znx(module, res.rank() + 1, res.size()); - - let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { - data: res_data, - basek: res.basek(), - k: res.k(), - }; - - self.prod_with_glwe(module, &mut res_idft, &a_idft, scratch_2); - - res_idft.dft(module, res); - } - - fn prod_with_glwe_fourier_inplace( - &self, - module: &Module, - res: &mut GLWECiphertextFourier, - scratch: &mut Scratch, - ) where - VecZnxDft: VecZnxDftToRef + VecZnxDftToMut, - { - let log_base2k: usize = self.basek(); - - #[cfg(debug_assertions)] - { - assert_eq!(res.basek(), log_base2k); - assert_eq!(self.n(), module.n()); - assert_eq!(res.n(), module.n()); - } - - let (res_data, scratch_1) = scratch.tmp_vec_znx(module, res.rank() + 1, res.size()); - - let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { - data: res_data, - basek: res.basek(), - k: res.k(), - }; - - res.idft(module, &mut res_idft, scratch_1); - - self.prod_with_glwe_inplace(module, &mut res_idft, scratch_1); - - res_idft.dft(module, res); - } - - fn prod_with_vec_glwe(&self, module: &Module, res: &mut RES, a: &LHS, scratch: &mut Scratch) - where - LHS: GetRow + Infos, - RES: SetRow + Infos, - { - let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, a.cols(), a.size()); - - let mut tmp_a_row: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { - data: tmp_row_data, - basek: a.basek(), - k: a.k(), - }; - - let (tmp_res_data, scratch2) = scratch1.tmp_vec_znx_dft(module, res.cols(), res.size()); - - let mut tmp_res_row: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { - data: tmp_res_data, - basek: res.basek(), - k: res.k(), - }; - - let min_rows: usize = res.rows().min(a.rows()); - - (0..res.rows()).for_each(|row_i| { - (0..res.cols()).for_each(|col_j| { - a.get_row(module, row_i, col_j, &mut tmp_a_row); - self.prod_with_glwe_fourier(module, &mut tmp_res_row, &tmp_a_row, scratch2); - res.set_row(module, row_i, col_j, &tmp_res_row); - }); - }); - - tmp_res_row.data.zero(); - - (min_rows..res.rows()).for_each(|row_i| { - (0..self.cols()).for_each(|col_j| { - res.set_row(module, row_i, col_j, &tmp_res_row); - }); - }); - } - - fn prod_with_vec_glwe_inplace(&self, module: &Module, res: &mut RES, scratch: &mut Scratch) - where - RES: GetRow + SetRow + Infos, - { - let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, res.cols(), res.size()); - - let mut tmp_row: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { - data: tmp_row_data, - basek: res.basek(), - k: res.k(), - }; - - (0..res.rows()).for_each(|row_i| { - (0..res.cols()).for_each(|col_j| { - res.get_row(module, row_i, col_j, &mut tmp_row); - self.prod_with_glwe_fourier_inplace(module, &mut tmp_row, scratch1); - res.set_row(module, row_i, col_j, &tmp_row); - }); - }); - } -}