From f517a730a3fba5c2d14f2fcd5088f1156065814c Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 14 May 2025 16:34:52 +0200 Subject: [PATCH] updated key-switch for rank switching & updated glwe key-switching test --- core/src/gglwe_ciphertext.rs | 118 +-- core/src/ggsw_ciphertext.rs | 62 +- core/src/glwe_ciphertext.rs | 33 +- core/src/glwe_ciphertext_fourier.rs | 33 +- core/src/keyswitch_key.rs | 46 +- core/src/test_fft64/gglwe.rs | 1003 +++++++++++------------ core/src/test_fft64/ggsw.rs | 1139 +++++++++++++-------------- core/src/test_fft64/glwe.rs | 100 ++- core/src/test_fft64/glwe_fourier.rs | 876 ++++++++++---------- core/src/vec_glwe_product.rs | 72 +- 10 files changed, 1806 insertions(+), 1676 deletions(-) diff --git a/core/src/gglwe_ciphertext.rs b/core/src/gglwe_ciphertext.rs index 9d2c79a..2a86c63 100644 --- a/core/src/gglwe_ciphertext.rs +++ b/core/src/gglwe_ciphertext.rs @@ -52,6 +52,14 @@ impl GGLWECiphertext { pub fn rank(&self) -> usize { self.data.cols_out() - 1 } + + pub fn rank_in(&self) -> usize { + self.data.cols_in() + } + + pub fn rank_out(&self) -> usize { + self.data.cols_out() - 1 + } } impl MatZnxDftToMut for GGLWECiphertext @@ -104,7 +112,8 @@ where { #[cfg(debug_assertions)] { - assert_eq!(self.rank(), sk_dft.rank()); + assert_eq!(self.rank_in(), pt.cols()); + assert_eq!(self.rank_out(), sk_dft.rank()); assert_eq!(self.n(), module.n()); assert_eq!(sk_dft.n(), module.n()); assert_eq!(pt.n(), module.n()); @@ -115,11 +124,12 @@ where let basek: usize = self.basek(); let k: usize = self.k(); - let cols: usize = self.rank() + 1; + let cols_in: usize = self.rank_in(); + let cols_out: usize = self.rank_out() + 1; let (tmp_znx_pt, scrach_1) = scratch.tmp_vec_znx(module, 1, size); - let (tmp_znx_ct, scrach_2) = scrach_1.tmp_vec_znx(module, cols, size); - let (tmp_znx_dft_ct, scratch_3) = scrach_2.tmp_vec_znx_dft(module, cols, size); + let (tmp_znx_ct, scrach_2) = scrach_1.tmp_vec_znx(module, cols_out, size); + let (tmp_znx_dft_ct, scratch_3) = scrach_2.tmp_vec_znx_dft(module, cols_out, size); let mut vec_znx_pt: GLWEPlaintext<&mut [u8]> = GLWEPlaintext { data: tmp_znx_pt, @@ -139,29 +149,42 @@ where k, }; - (0..rows).for_each(|row_i| { - // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt - module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_i, pt, 0); - module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt, 0, scratch_3); + // For each input column (i.e. rank) produces a GGLWE ciphertext of rank_out+1 columns + // + // Example for ksk rank 2 to rank 3: + // + // (-(a0*s0 + a1*s1 + a2*s2) + s0', a0, a1, a2) + // (-(b0*s0 + b1*s1 + b2*s2) + s0', b0, b1, b2) + // + // Example ksk rank 2 to rank 1 + // + // (-(a*s) + s0, a) + // (-(b*s) + s1, b) + (0..cols_in).for_each(|col_i| { + (0..rows).for_each(|row_i| { + // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt + module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_i, pt, col_i); // Selects the i-th + module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt, 0, scratch_3); - // rlwe encrypt of vec_znx_pt into vec_znx_ct - vec_znx_ct.encrypt_sk( - module, - &vec_znx_pt, - sk_dft, - source_xa, - source_xe, - sigma, - scratch_3, - ); + // rlwe encrypt of vec_znx_pt into vec_znx_ct + vec_znx_ct.encrypt_sk( + module, + &vec_znx_pt, + sk_dft, + source_xa, + source_xe, + sigma, + scratch_3, + ); - vec_znx_pt.data.zero(); // zeroes for next iteration + vec_znx_pt.data.zero(); // zeroes for next iteration - // Switch vec_znx_ct into DFT domain - vec_znx_ct.dft(module, &mut vec_znx_ct_dft); + // Switch vec_znx_ct into DFT domain + vec_znx_ct.dft(module, &mut vec_znx_ct_dft); - // Stores vec_znx_dft_ct into thw i-th row of the MatZnxDft - module.vmp_prepare_row(self, row_i, 0, &vec_znx_ct_dft); + // Stores vec_znx_dft_ct into thw i-th row of the MatZnxDft + module.vmp_prepare_row(self, row_i, col_i, &vec_znx_ct_dft); + }); }); } } @@ -174,10 +197,6 @@ where where VecZnxDft: VecZnxDftToMut, { - #[cfg(debug_assertions)] - { - assert_eq!(col_j, 0); - } module.vmp_extract_row(res, self, row_i, col_j); } } @@ -190,20 +209,23 @@ where where VecZnxDft: VecZnxDftToRef, { - #[cfg(debug_assertions)] - { - assert_eq!(col_j, 0); - } module.vmp_prepare_row(self, row_i, col_j, a); } } impl VecGLWEProductScratchSpace for GGLWECiphertext, FFT64> { - fn prod_with_glwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { - module.bytes_of_vec_znx_dft(2, grlwe_size) + fn prod_with_glwe_scratch_space( + module: &Module, + res_size: usize, + a_size: usize, + grlwe_size: usize, + rank_in: usize, + rank_out: usize, + ) -> usize { + module.bytes_of_vec_znx_dft(rank_out + 1, grlwe_size) + (module.vec_znx_big_normalize_tmp_bytes() - | (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 1, 2, grlwe_size) - + module.bytes_of_vec_znx_dft(1, a_size))) + | (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, rank_in, rank_out + 1, grlwe_size) + + module.bytes_of_vec_znx_dft(rank_in, a_size))) } } @@ -222,30 +244,38 @@ where VecZnx: VecZnxToMut, VecZnx: VecZnxToRef, { - let log_base2k: usize = self.basek(); + let basek: usize = self.basek(); #[cfg(debug_assertions)] { - assert_eq!(res.basek(), log_base2k); - assert_eq!(a.basek(), log_base2k); + assert_eq!(a.rank(), self.rank_in()); + assert_eq!(res.rank(), self.rank_out()); + assert_eq!(res.basek(), basek); + assert_eq!(a.basek(), basek); assert_eq!(self.n(), module.n()); assert_eq!(res.n(), module.n()); assert_eq!(a.n(), module.n()); } - let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, 2, self.size()); // Todo optimise + let cols_in: usize = self.rank_in(); + let cols_out: usize = self.rank_out() + 1; + + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, self.size()); // Todo optimise { - let (mut a1_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, a.size()); - module.vec_znx_dft(&mut a1_dft, 0, a, 1); - module.vmp_apply(&mut res_dft, &a1_dft, self, scratch2); + let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, a.size()); + (0..cols_in).for_each(|col_i| { + module.vec_znx_dft(&mut ai_dft, col_i, a, col_i + 1); + }); + module.vmp_apply(&mut res_dft, &ai_dft, self, scratch2); } let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0); - module.vec_znx_big_normalize(log_base2k, res, 0, &res_big, 0, scratch1); - module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1); + (0..cols_out).for_each(|i| { + module.vec_znx_big_normalize(basek, res, i, &res_big, i, scratch1); + }); } } diff --git a/core/src/ggsw_ciphertext.rs b/core/src/ggsw_ciphertext.rs index d277d78..625e09b 100644 --- a/core/src/ggsw_ciphertext.rs +++ b/core/src/ggsw_ciphertext.rs @@ -82,27 +82,34 @@ impl GGSWCiphertext, FFT64> { + module.bytes_of_vec_znx_dft(rank + 1, size) } - pub fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + pub fn keyswitch_scratch_space( + module: &Module, + res_size: usize, + lhs: usize, + rhs: usize, + rank_in: usize, + rank_out: usize, + ) -> usize { , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( - module, res_size, lhs, rhs, + module, res_size, lhs, rhs, rank_in, rank_out, ) } - pub fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + pub fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_inplace_scratch_space( - module, res_size, rhs, + module, res_size, rhs, rank, ) } - pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize { , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( - module, res_size, lhs, rhs, + module, res_size, lhs, rhs, rank, rank, ) } - pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( - module, res_size, rhs, + module, res_size, rhs, rank, ) } } @@ -265,9 +272,24 @@ where } impl VecGLWEProductScratchSpace for GGSWCiphertext, FFT64> { - fn prod_with_glwe_scratch_space(module: &Module, res_size: usize, a_size: usize, rgsw_size: usize) -> usize { - module.bytes_of_vec_znx_dft(2, rgsw_size) - + ((module.bytes_of_vec_znx_dft(2, a_size) + module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 2, 2, rgsw_size)) + fn prod_with_glwe_scratch_space( + module: &Module, + res_size: usize, + a_size: usize, + rgsw_size: usize, + rank_in: usize, + rank_out: usize, + ) -> usize { + module.bytes_of_vec_znx_dft(rank_out + 1, rgsw_size) + + ((module.bytes_of_vec_znx_dft(rank_in + 1, a_size) + + module.vmp_apply_tmp_bytes( + res_size, + a_size, + a_size, + rank_in + 1, + rank_out + 1, + rgsw_size, + )) | module.vec_znx_big_normalize_tmp_bytes()) } } @@ -290,6 +312,8 @@ where #[cfg(debug_assertions)] { + assert_eq!(self.rank(), a.rank()); + assert_eq!(self.rank(), res.rank()); assert_eq!(res.basek(), log_base2k); assert_eq!(a.basek(), log_base2k); assert_eq!(self.n(), module.n()); @@ -297,18 +321,22 @@ where assert_eq!(a.n(), module.n()); } - let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, 2, self.size()); // Todo optimise + let cols: usize = self.rank() + 1; + + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, self.size()); // Todo optimise { - let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, 2, a.size()); - module.vec_znx_dft(&mut a_dft, 0, a, 0); - module.vec_znx_dft(&mut a_dft, 1, a, 1); + let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, a.size()); + (0..cols).for_each(|col_i| { + module.vec_znx_dft(&mut a_dft, col_i, a, col_i); + }); module.vmp_apply(&mut res_dft, &a_dft, self, scratch2); } let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); - module.vec_znx_big_normalize(log_base2k, res, 0, &res_big, 0, scratch1); - module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1); + (0..cols).for_each(|i| { + module.vec_znx_big_normalize(log_base2k, res, i, &res_big, i, scratch1); + }); } } diff --git a/core/src/glwe_ciphertext.rs b/core/src/glwe_ciphertext.rs index 5f8d086..063ceb2 100644 --- a/core/src/glwe_ciphertext.rs +++ b/core/src/glwe_ciphertext.rs @@ -113,23 +113,34 @@ impl GLWECiphertext> { + module.bytes_of_vec_znx_big(1, ct_size) } - pub fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) - } - - pub fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( - module, res_size, rhs, + pub fn keyswitch_scratch_space( + module: &Module, + res_size: usize, + lhs: usize, + rhs: usize, + rank_in: usize, + rank_out: usize, + ) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space( + module, res_size, lhs, rhs, rank_in, rank_out, ) } - pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) + pub fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( + module, res_size, rhs, rank, + ) } - pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space( + module, res_size, lhs, rhs, rank, rank, + ) + } + + pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( - module, res_size, rhs, + module, res_size, rhs, rank, ) } } diff --git a/core/src/glwe_ciphertext_fourier.rs b/core/src/glwe_ciphertext_fourier.rs index e31d0dc..a16aba8 100644 --- a/core/src/glwe_ciphertext_fourier.rs +++ b/core/src/glwe_ciphertext_fourier.rs @@ -90,23 +90,34 @@ impl GLWECiphertextFourier, FFT64> { + module.bytes_of_vec_znx_big(1, ct_size) } - pub fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) - } - - pub fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( - module, res_size, rhs, + pub fn keyswitch_scratch_space( + module: &Module, + res_size: usize, + lhs: usize, + rhs: usize, + rank_in: usize, + rank_out: usize, + ) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space( + module, res_size, lhs, rhs, rank_in, rank_out, ) } - pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) + pub fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( + module, res_size, rhs, rank, + ) } - pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space( + module, res_size, lhs, rhs, rank, rank, + ) + } + + pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( - module, res_size, rhs, + module, res_size, rhs, rank, ) } } diff --git a/core/src/keyswitch_key.rs b/core/src/keyswitch_key.rs index f9500a7..37774eb 100644 --- a/core/src/keyswitch_key.rs +++ b/core/src/keyswitch_key.rs @@ -9,7 +9,7 @@ use crate::{ gglwe_ciphertext::GGLWECiphertext, ggsw_ciphertext::GGSWCiphertext, glwe_ciphertext_fourier::GLWECiphertextFourier, - keys::SecretKeyFourier, + keys::{SecretKey, SecretKeyFourier}, vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace}, }; @@ -103,46 +103,60 @@ impl GLWESwitchingKey where MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, { - pub fn encrypt_sk( + pub fn encrypt_sk( &mut self, module: &Module, - pt: &ScalarZnx, - sk_dft: &SecretKeyFourier, + sk_in: &SecretKey, + sk_out_dft: &SecretKeyFourier, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, scratch: &mut Scratch, ) where - ScalarZnx: ScalarZnxToRef, - ScalarZnxDft: ScalarZnxDftToRef, + ScalarZnx: ScalarZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, { - self.0 - .encrypt_sk(module, pt, sk_dft, source_xa, source_xe, sigma, scratch); + self.0.encrypt_sk( + module, + &sk_in.data, + sk_out_dft, + source_xa, + source_xe, + sigma, + scratch, + ); } } impl GLWESwitchingKey, FFT64> { - pub fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + pub fn keyswitch_scratch_space( + module: &Module, + res_size: usize, + lhs: usize, + rhs: usize, + rank_in: usize, + rank_out: usize, + ) -> usize { , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( - module, res_size, lhs, rhs, + module, res_size, lhs, rhs, rank_in, rank_out, ) } - pub fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + pub fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_inplace_scratch_space( - module, res_size, rhs, + module, res_size, rhs, rank, ) } - pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize { , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( - module, res_size, lhs, rhs, + module, res_size, lhs, rhs, rank, rank, ) } - pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( - module, res_size, rhs, + module, res_size, rhs, rank, ) } } diff --git a/core/src/test_fft64/gglwe.rs b/core/src/test_fft64/gglwe.rs index e4a566d..8327325 100644 --- a/core/src/test_fft64/gglwe.rs +++ b/core/src/test_fft64/gglwe.rs @@ -1,504 +1,503 @@ -use base2k::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, ZnxViewMut}; -use sampling::source::Source; - -use crate::{ - elem::{GetRow, Infos}, - ggsw_ciphertext::GGSWCiphertext, - glwe_ciphertext_fourier::GLWECiphertextFourier, - glwe_plaintext::GLWEPlaintext, - keys::{SecretKey, SecretKeyFourier}, - keyswitch_key::GLWESwitchingKey, - test_fft64::ggsw::noise_rgsw_product, -}; - -#[test] -fn encrypt_sk() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 8; - let log_k_ct: usize = 54; - let rows: usize = 4; - let rank: usize = 1; - let rank_out: usize = 1; - - let sigma: f64 = 3.2; - - let mut ct: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, log_base2k, log_k_ct, rows, rank, rank_out); - let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); - let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct.size()) - | GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()), - ); - - let mut sk: SecretKey> = SecretKey::new(&module, rank); - // sk.fill_ternary_prob(0.5, &mut source_xs); - sk.fill_zero(); - - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk_dft.dft(&module, &sk); - - ct.encrypt_sk( - &module, - &pt_scalar, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_ct, rank); - - (0..ct.rows()).for_each(|row_i| { - ct.get_row(&module, row_i, 0, &mut ct_rlwe_dft); - ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_scalar, 0); - let std_pt: f64 = pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); - assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); - }); -} - -#[test] -fn keyswitch() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; - - let rank: usize = 1; - - let sigma: f64 = 3.2; - - let mut ct_grlwe_s0s1: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); - let mut ct_grlwe_s1s2: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); - let mut ct_grlwe_s0s2: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_s0s1.size()) - | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_s0s2.size()) - | GLWESwitchingKey::keyswitch_scratch_space( - &module, - ct_grlwe_s0s2.size(), - ct_grlwe_s0s1.size(), - ct_grlwe_s1s2.size(), - ), - ); - - let mut sk0: SecretKey> = SecretKey::new(&module, rank); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk0_dft.dft(&module, &sk0); - - let mut sk1: SecretKey> = SecretKey::new(&module, rank); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk1_dft.dft(&module, &sk1); - - let mut sk2: SecretKey> = SecretKey::new(&module, rank); - sk2.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk2_dft.dft(&module, &sk2); - - // GRLWE_{s1}(s0) = s0 -> s1 - ct_grlwe_s0s1.encrypt_sk( - &module, - &sk0.data, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - // GRLWE_{s2}(s1) -> s1 -> s2 - ct_grlwe_s1s2.encrypt_sk( - &module, - &sk1.data, - &sk2_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - // GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) - ct_grlwe_s0s2.keyswitch(&module, &ct_grlwe_s0s1, &ct_grlwe_s1s2, scratch.borrow()); - - let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe, rank); - let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); - - (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { - ct_grlwe_s0s2.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); - ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, 0); - - let noise_have: f64 = pt.data.std(0, log_base2k).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - log_k_grlwe, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - }); -} - -#[test] -fn keyswitch_inplace() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; - - let rank: usize = 1; - let rank_out: usize = 1; - - let sigma: f64 = 3.2; - - let mut ct_grlwe_s0s1: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); - let mut ct_grlwe_s1s2: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_s0s1.size()) - | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_s0s1.size()) - | GLWESwitchingKey::keyswitch_inplace_scratch_space(&module, ct_grlwe_s0s1.size(), ct_grlwe_s1s2.size()), - ); - - let mut sk0: SecretKey> = SecretKey::new(&module, rank); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk0_dft.dft(&module, &sk0); - - let mut sk1: SecretKey> = SecretKey::new(&module, rank); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk1_dft.dft(&module, &sk1); - - let mut sk2: SecretKey> = SecretKey::new(&module, rank); - sk2.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk2_dft.dft(&module, &sk2); - - // GRLWE_{s1}(s0) = s0 -> s1 - ct_grlwe_s0s1.encrypt_sk( - &module, - &sk0.data, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - // GRLWE_{s2}(s1) -> s1 -> s2 - ct_grlwe_s1s2.encrypt_sk( - &module, - &sk1.data, - &sk2_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - // GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) - ct_grlwe_s0s1.keyswitch_inplace(&module, &ct_grlwe_s1s2, scratch.borrow()); - - let ct_grlwe_s0s2: GLWESwitchingKey, FFT64> = ct_grlwe_s0s1; - - let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe, rank); - let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); - - (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { - ct_grlwe_s0s2.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); - ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, 0); - - let noise_have: f64 = pt.data.std(0, log_base2k).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - log_k_grlwe, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - }); -} - -#[test] -fn external_product() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; - - let rank: usize = 1; - let rank_out: usize = 1; - - let sigma: f64 = 3.2; - - let mut ct_grlwe_in: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); - let mut ct_grlwe_out: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); - - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_in.size()) - | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_out.size()) - | GLWESwitchingKey::external_product_scratch_space( - &module, - ct_grlwe_out.size(), - ct_grlwe_in.size(), - ct_rgsw.size(), - ) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()), - ); - - let k: usize = 1; - - pt_rgsw.raw_mut()[k] = 1; // X^{k} - - pt_grlwe.fill_ternary_prob(0, 0.5, &mut source_xs); - - let mut sk: SecretKey> = SecretKey::new(&module, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk_dft.dft(&module, &sk); - - // GRLWE_{s1}(s0) = s0 -> s1 - ct_grlwe_in.encrypt_sk( - &module, - &pt_grlwe, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rgsw.encrypt_sk( - &module, - &pt_rgsw, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - // GRLWE_(m) (x) RGSW_(X^k) = GRLWE_(m * X^k) - ct_grlwe_out.external_product(&module, &ct_grlwe_in, &ct_rgsw, scratch.borrow()); - - let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe, rank); - let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); - - (0..ct_grlwe_out.rows()).for_each(|row_i| { - ct_grlwe_out.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); - ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_grlwe, 0); - - let noise_have: f64 = pt.data.std(0, log_base2k).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_rgsw_product( - module.n() as f64, - log_base2k, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - log_k_grlwe, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - }); -} - -#[test] -fn external_product_inplace() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; - - let rank: usize = 1; - let rank_out: usize = 1; - - let sigma: f64 = 3.2; - - let mut ct_grlwe: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); - - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) - | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe.size()) - | GLWESwitchingKey::external_product_inplace_scratch_space(&module, ct_grlwe.size(), ct_rgsw.size()) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()), - ); - - let k: usize = 1; - - pt_rgsw.raw_mut()[k] = 1; // X^{k} - - pt_grlwe.fill_ternary_prob(0, 0.5, &mut source_xs); - - let mut sk: SecretKey> = SecretKey::new(&module, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk_dft.dft(&module, &sk); - - // GRLWE_{s1}(s0) = s0 -> s1 - ct_grlwe.encrypt_sk( - &module, - &pt_grlwe, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rgsw.encrypt_sk( - &module, - &pt_rgsw, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - // GRLWE_(m) (x) RGSW_(X^k) = GRLWE_(m * X^k) - ct_grlwe.external_product_inplace(&module, &ct_rgsw, scratch.borrow()); - - let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe, rank); - let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); - - (0..ct_grlwe.rows()).for_each(|row_i| { - ct_grlwe.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); - ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_grlwe, 0); - - let noise_have: f64 = pt.data.std(0, log_base2k).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_rgsw_product( - module.n() as f64, - log_base2k, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - log_k_grlwe, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); - }); -} - -pub(crate) fn noise_grlwe_rlwe_product( +// use base2k::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, ZnxViewMut}; +// use sampling::source::Source; +// +// use crate::{ +// elem::{GetRow, Infos}, +// ggsw_ciphertext::GGSWCiphertext, +// glwe_ciphertext_fourier::GLWECiphertextFourier, +// glwe_plaintext::GLWEPlaintext, +// keys::{SecretKey, SecretKeyFourier}, +// keyswitch_key::GLWESwitchingKey, +// test_fft64::ggsw::noise_rgsw_product, +// }; +// +// #[test] +// fn encrypt_sk() { +// let module: Module = Module::::new(2048); +// let log_base2k: usize = 8; +// let log_k_ct: usize = 54; +// let rows: usize = 4; +// let rank: usize = 1; +// let rank_out: usize = 1; +// +// let sigma: f64 = 3.2; +// +// let mut ct: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, log_base2k, log_k_ct, rows, rank, rank_out); +// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); +// let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); +// +// let mut source_xs: Source = Source::new([0u8; 32]); +// let mut source_xe: Source = Source::new([0u8; 32]); +// let mut source_xa: Source = Source::new([0u8; 32]); +// +// pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); +// +// let mut scratch: ScratchOwned = ScratchOwned::new( +// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct.size()) +// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()), +// ); +// +// let mut sk: SecretKey> = SecretKey::new(&module, rank); +// sk.fill_ternary_prob(0.5, &mut source_xs); +// sk.fill_zero(); +// +// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk_dft.dft(&module, &sk); +// +// ct.encrypt_sk( +// &module, +// &pt_scalar, +// &sk_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_ct, rank); +// +// (0..ct.rows()).for_each(|row_i| { +// ct.get_row(&module, row_i, 0, &mut ct_rlwe_dft); +// ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); +// module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_scalar, 0); +// let std_pt: f64 = pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); +// assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); +// }); +// } +// +// #[test] +// fn keyswitch() { +// let module: Module = Module::::new(2048); +// let log_base2k: usize = 12; +// let log_k_grlwe: usize = 60; +// let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; +// +// let rank: usize = 1; +// +// let sigma: f64 = 3.2; +// +// let mut ct_grlwe_s0s1: GLWESwitchingKey, FFT64> = +// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); +// let mut ct_grlwe_s1s2: GLWESwitchingKey, FFT64> = +// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); +// let mut ct_grlwe_s0s2: GLWESwitchingKey, FFT64> = +// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); +// +// let mut source_xs: Source = Source::new([0u8; 32]); +// let mut source_xe: Source = Source::new([0u8; 32]); +// let mut source_xa: Source = Source::new([0u8; 32]); +// +// let mut scratch: ScratchOwned = ScratchOwned::new( +// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_s0s1.size()) +// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_s0s2.size()) +// | GLWESwitchingKey::keyswitch_scratch_space( +// &module, +// ct_grlwe_s0s2.size(), +// ct_grlwe_s0s1.size(), +// ct_grlwe_s1s2.size(), +// ), +// ); +// +// let mut sk0: SecretKey> = SecretKey::new(&module, rank); +// sk0.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk0_dft.dft(&module, &sk0); +// +// let mut sk1: SecretKey> = SecretKey::new(&module, rank); +// sk1.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk1_dft.dft(&module, &sk1); +// +// let mut sk2: SecretKey> = SecretKey::new(&module, rank); +// sk2.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk2_dft.dft(&module, &sk2); +// +// GRLWE_{s1}(s0) = s0 -> s1 +// ct_grlwe_s0s1.encrypt_sk( +// &module, +// &sk0.data, +// &sk1_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// GRLWE_{s2}(s1) -> s1 -> s2 +// ct_grlwe_s1s2.encrypt_sk( +// &module, +// &sk1.data, +// &sk2_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) +// ct_grlwe_s0s2.keyswitch(&module, &ct_grlwe_s0s1, &ct_grlwe_s1s2, scratch.borrow()); +// +// let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = +// GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe, rank); +// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); +// +// (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { +// ct_grlwe_s0s2.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); +// ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); +// module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, 0); +// +// let noise_have: f64 = pt.data.std(0, log_base2k).log2(); +// let noise_want: f64 = noise_grlwe_rlwe_product( +// module.n() as f64, +// log_base2k, +// 0.5, +// 0.5, +// 0f64, +// sigma * sigma, +// 0f64, +// log_k_grlwe, +// log_k_grlwe, +// ); +// +// assert!( +// (noise_have - noise_want).abs() <= 0.1, +// "{} {}", +// noise_have, +// noise_want +// ); +// }); +// } +// +// #[test] +// fn keyswitch_inplace() { +// let module: Module = Module::::new(2048); +// let log_base2k: usize = 12; +// let log_k_grlwe: usize = 60; +// let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; +// +// let rank: usize = 1; +// let rank_out: usize = 1; +// +// let sigma: f64 = 3.2; +// +// let mut ct_grlwe_s0s1: GLWESwitchingKey, FFT64> = +// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); +// let mut ct_grlwe_s1s2: GLWESwitchingKey, FFT64> = +// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); +// +// let mut source_xs: Source = Source::new([0u8; 32]); +// let mut source_xe: Source = Source::new([0u8; 32]); +// let mut source_xa: Source = Source::new([0u8; 32]); +// +// let mut scratch: ScratchOwned = ScratchOwned::new( +// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_s0s1.size()) +// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_s0s1.size()) +// | GLWESwitchingKey::keyswitch_inplace_scratch_space(&module, ct_grlwe_s0s1.size(), ct_grlwe_s1s2.size()), +// ); +// +// let mut sk0: SecretKey> = SecretKey::new(&module, rank); +// sk0.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk0_dft.dft(&module, &sk0); +// +// let mut sk1: SecretKey> = SecretKey::new(&module, rank); +// sk1.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk1_dft.dft(&module, &sk1); +// +// let mut sk2: SecretKey> = SecretKey::new(&module, rank); +// sk2.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk2_dft.dft(&module, &sk2); +// +// GRLWE_{s1}(s0) = s0 -> s1 +// ct_grlwe_s0s1.encrypt_sk( +// &module, +// &sk0.data, +// &sk1_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// GRLWE_{s2}(s1) -> s1 -> s2 +// ct_grlwe_s1s2.encrypt_sk( +// &module, +// &sk1.data, +// &sk2_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) +// ct_grlwe_s0s1.keyswitch_inplace(&module, &ct_grlwe_s1s2, scratch.borrow()); +// +// let ct_grlwe_s0s2: GLWESwitchingKey, FFT64> = ct_grlwe_s0s1; +// +// let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = +// GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe, rank); +// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); +// +// (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { +// ct_grlwe_s0s2.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); +// ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); +// module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, 0); +// +// let noise_have: f64 = pt.data.std(0, log_base2k).log2(); +// let noise_want: f64 = noise_grlwe_rlwe_product( +// module.n() as f64, +// log_base2k, +// 0.5, +// 0.5, +// 0f64, +// sigma * sigma, +// 0f64, +// log_k_grlwe, +// log_k_grlwe, +// ); +// +// assert!( +// (noise_have - noise_want).abs() <= 0.1, +// "{} {}", +// noise_have, +// noise_want +// ); +// }); +// } +// +// #[test] +// fn external_product() { +// let module: Module = Module::::new(2048); +// let log_base2k: usize = 12; +// let log_k_grlwe: usize = 60; +// let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; +// +// let rank: usize = 1; +// let rank_out: usize = 1; +// +// let sigma: f64 = 3.2; +// +// let mut ct_grlwe_in: GLWESwitchingKey, FFT64> = +// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); +// let mut ct_grlwe_out: GLWESwitchingKey, FFT64> = +// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); +// let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); +// +// let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); +// let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); +// +// let mut source_xs: Source = Source::new([0u8; 32]); +// let mut source_xe: Source = Source::new([0u8; 32]); +// let mut source_xa: Source = Source::new([0u8; 32]); +// +// let mut scratch: ScratchOwned = ScratchOwned::new( +// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_in.size()) +// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_out.size()) +// | GLWESwitchingKey::external_product_scratch_space( +// &module, +// ct_grlwe_out.size(), +// ct_grlwe_in.size(), +// ct_rgsw.size(), +// ) +// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()), +// ); +// +// let k: usize = 1; +// +// pt_rgsw.raw_mut()[k] = 1; // X^{k} +// +// pt_grlwe.fill_ternary_prob(0, 0.5, &mut source_xs); +// +// let mut sk: SecretKey> = SecretKey::new(&module, rank); +// sk.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk_dft.dft(&module, &sk); +// +// GRLWE_{s1}(s0) = s0 -> s1 +// ct_grlwe_in.encrypt_sk( +// &module, +// &pt_grlwe, +// &sk_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rgsw.encrypt_sk( +// &module, +// &pt_rgsw, +// &sk_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// GRLWE_(m) (x) RGSW_(X^k) = GRLWE_(m * X^k) +// ct_grlwe_out.external_product(&module, &ct_grlwe_in, &ct_rgsw, scratch.borrow()); +// +// let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = +// GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe, rank); +// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); +// +// module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); +// +// (0..ct_grlwe_out.rows()).for_each(|row_i| { +// ct_grlwe_out.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); +// ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); +// module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_grlwe, 0); +// +// let noise_have: f64 = pt.data.std(0, log_base2k).log2(); +// +// let var_gct_err_lhs: f64 = sigma * sigma; +// let var_gct_err_rhs: f64 = 0f64; +// +// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} +// let var_a0_err: f64 = sigma * sigma; +// let var_a1_err: f64 = 1f64 / 12f64; +// +// let noise_want: f64 = noise_rgsw_product( +// module.n() as f64, +// log_base2k, +// 0.5, +// var_msg, +// var_a0_err, +// var_a1_err, +// var_gct_err_lhs, +// var_gct_err_rhs, +// log_k_grlwe, +// log_k_grlwe, +// ); +// +// assert!( +// (noise_have - noise_want).abs() <= 0.1, +// "{} {}", +// noise_have, +// noise_want +// ); +// }); +// } +// +// #[test] +// fn external_product_inplace() { +// let module: Module = Module::::new(2048); +// let log_base2k: usize = 12; +// let log_k_grlwe: usize = 60; +// let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; +// +// let rank: usize = 1; +// let rank_out: usize = 1; +// +// let sigma: f64 = 3.2; +// +// let mut ct_grlwe: GLWESwitchingKey, FFT64> = +// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); +// let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); +// +// let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); +// let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); +// +// let mut source_xs: Source = Source::new([0u8; 32]); +// let mut source_xe: Source = Source::new([0u8; 32]); +// let mut source_xa: Source = Source::new([0u8; 32]); +// +// let mut scratch: ScratchOwned = ScratchOwned::new( +// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) +// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe.size()) +// | GLWESwitchingKey::external_product_inplace_scratch_space(&module, ct_grlwe.size(), ct_rgsw.size()) +// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()), +// ); +// +// let k: usize = 1; +// +// pt_rgsw.raw_mut()[k] = 1; // X^{k} +// +// pt_grlwe.fill_ternary_prob(0, 0.5, &mut source_xs); +// +// let mut sk: SecretKey> = SecretKey::new(&module, rank); +// sk.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk_dft.dft(&module, &sk); +// +// GRLWE_{s1}(s0) = s0 -> s1 +// ct_grlwe.encrypt_sk( +// &module, +// &pt_grlwe, +// &sk_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rgsw.encrypt_sk( +// &module, +// &pt_rgsw, +// &sk_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// GRLWE_(m) (x) RGSW_(X^k) = GRLWE_(m * X^k) +// ct_grlwe.external_product_inplace(&module, &ct_rgsw, scratch.borrow()); +// +// let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = +// GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe, rank); +// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); +// +// module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); +// +// (0..ct_grlwe.rows()).for_each(|row_i| { +// ct_grlwe.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); +// ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); +// module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_grlwe, 0); +// +// let noise_have: f64 = pt.data.std(0, log_base2k).log2(); +// +// let var_gct_err_lhs: f64 = sigma * sigma; +// let var_gct_err_rhs: f64 = 0f64; +// +// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} +// let var_a0_err: f64 = sigma * sigma; +// let var_a1_err: f64 = 1f64 / 12f64; +// +// let noise_want: f64 = noise_rgsw_product( +// module.n() as f64, +// log_base2k, +// 0.5, +// var_msg, +// var_a0_err, +// var_a1_err, +// var_gct_err_lhs, +// var_gct_err_rhs, +// log_k_grlwe, +// log_k_grlwe, +// ); +// +// assert!( +// (noise_have - noise_want).abs() <= 0.1, +// "{} {}", +// noise_have, +// noise_want +// ); +// }); +// } +pub(crate) fn noise_gglwe_product( n: f64, log_base2k: usize, var_xs: f64, @@ -506,6 +505,7 @@ pub(crate) fn noise_grlwe_rlwe_product( var_a_err: f64, var_gct_err_lhs: f64, var_gct_err_rhs: f64, + rank_in: f64, a_logq: usize, b_logq: usize, ) -> f64 { @@ -522,6 +522,7 @@ pub(crate) fn noise_grlwe_rlwe_product( // rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs let mut noise: f64 = (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs); noise += var_msg * var_a_err * a_scale * a_scale * n; + noise *= rank_in; noise = noise.sqrt(); noise /= b_scale; noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}] diff --git a/core/src/test_fft64/ggsw.rs b/core/src/test_fft64/ggsw.rs index f1903c1..c514ef9 100644 --- a/core/src/test_fft64/ggsw.rs +++ b/core/src/test_fft64/ggsw.rs @@ -1,573 +1,572 @@ -use base2k::{ - FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, - VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, ZnxViewMut, ZnxZero, -}; -use sampling::source::Source; - -use crate::{ - elem::{GetRow, Infos}, - ggsw_ciphertext::GGSWCiphertext, - glwe_ciphertext_fourier::GLWECiphertextFourier, - glwe_plaintext::GLWEPlaintext, - keys::{SecretKey, SecretKeyFourier}, - keyswitch_key::GLWESwitchingKey, - test_fft64::gglwe::noise_grlwe_rlwe_product, -}; - -#[test] -fn encrypt_sk() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 8; - let log_k_ct: usize = 54; - let rows: usize = 4; - let rank: usize = 1; - - let sigma: f64 = 3.2; - - let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_ct, rows, rank); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); - let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct.size()) - | GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()), - ); - - let mut sk: SecretKey> = SecretKey::new(&module, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk_dft.dft(&module, &sk); - - ct.encrypt_sk( - &module, - &pt_scalar, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_ct, rank); - let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); - let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); - - (0..ct.rank()).for_each(|col_j| { - (0..ct.rows()).for_each(|row_i| { - module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0); - - if col_j == 1 { - module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0); - module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); - module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); - } - - ct.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); - - ct_rlwe_dft.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let std_pt: f64 = pt_have.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); - assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); - - pt_want.data.zero(); - }); - }); -} - -#[test] -fn keyswitch() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rgsw_in: usize = 45; - let log_k_rgsw_out: usize = 45; - let rows: usize = (log_k_rgsw_in + log_base2k - 1) / log_base2k; - - let rank: usize = 1; - - let sigma: f64 = 3.2; - - let mut ct_grlwe: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); - let mut ct_rgsw_in: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_in, rows, rank); - let mut ct_rgsw_out: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_out, rows, rank); - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_rgsw.fill_ternary_prob(0, 0.5, &mut source_xs); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) - | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_out.size()) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw_in.size()) - | GGSWCiphertext::keyswitch_scratch_space( - &module, - ct_rgsw_out.size(), - ct_rgsw_in.size(), - ct_grlwe.size(), - ), - ); - - let mut sk0: SecretKey> = SecretKey::new(&module, rank); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk0_dft.dft(&module, &sk0); - - let mut sk1: SecretKey> = SecretKey::new(&module, rank); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk1_dft.dft(&module, &sk1); - - ct_grlwe.encrypt_sk( - &module, - &sk0.data, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rgsw_in.encrypt_sk( - &module, - &pt_rgsw, - &sk0_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rgsw_out.keyswitch(&module, &ct_rgsw_in, &ct_grlwe, scratch.borrow()); - - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_out, rank); - let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_out); - let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_out.size()); - let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_out.size()); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_out); - - (0..ct_rgsw_out.rank()).for_each(|col_j| { - (0..ct_rgsw_out.rows()).for_each(|row_i| { - module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw, 0); - - if col_j == 1 { - module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk0_dft, 0); - module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); - module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); - } - - ct_rgsw_out.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); - ct_rlwe_dft.decrypt(&module, &mut pt, &sk1_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); - - let noise_have: f64 = pt.data.std(0, log_base2k).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - log_k_grlwe, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.2, - "have: {} want: {}", - noise_have, - noise_want - ); - - pt_want.data.zero(); - }); - }); -} - -#[test] -fn keyswitch_inplace() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rgsw: usize = 45; - let rows: usize = (log_k_rgsw + log_base2k - 1) / log_base2k; - let rank: usize = 1; - - let sigma: f64 = 3.2; - - let mut ct_grlwe: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw, rows, rank); - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_rgsw.fill_ternary_prob(0, 0.5, &mut source_xs); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) - | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw.size()) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) - | GGSWCiphertext::keyswitch_inplace_scratch_space(&module, ct_rgsw.size(), ct_grlwe.size()), - ); - - let mut sk0: SecretKey> = SecretKey::new(&module, rank); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk0_dft.dft(&module, &sk0); - - let mut sk1: SecretKey> = SecretKey::new(&module, rank); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk1_dft.dft(&module, &sk1); - - ct_grlwe.encrypt_sk( - &module, - &sk0.data, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rgsw.encrypt_sk( - &module, - &pt_rgsw, - &sk0_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rgsw.keyswitch_inplace(&module, &ct_grlwe, scratch.borrow()); - - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw, rank); - let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw); - let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw.size()); - let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw.size()); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw); - - (0..ct_rgsw.rank()).for_each(|col_j| { - (0..ct_rgsw.rows()).for_each(|row_i| { - module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw, 0); - - if col_j == 1 { - module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk0_dft, 0); - module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); - module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); - } - - ct_rgsw.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); - ct_rlwe_dft.decrypt(&module, &mut pt, &sk1_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); - - let noise_have: f64 = pt.data.std(0, log_base2k).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - log_k_grlwe, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.2, - "have: {} want: {}", - noise_have, - noise_want - ); - - pt_want.data.zero(); - }); - }); -} - -#[test] -fn external_product() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_rgsw_rhs: usize = 60; - let log_k_rgsw_lhs_in: usize = 45; - let log_k_rgsw_lhs_out: usize = 45; - let rows: usize = (log_k_rgsw_lhs_in + log_base2k - 1) / log_base2k; - let rank: usize = 1; - - let sigma: f64 = 3.2; - - let mut ct_rgsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_rhs, rows, rank); - let mut ct_rgsw_lhs_in: GGSWCiphertext, FFT64> = - GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs_in, rows, rank); - let mut ct_rgsw_lhs_out: GGSWCiphertext, FFT64> = - GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs_out, rows, rank); - let mut pt_rgsw_lhs: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_rgsw_rhs: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_rgsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs); - - let k: usize = 1; - - pt_rgsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_rgsw_rhs.size()) - | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_lhs_out.size()) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw_lhs_in.size()) - | GGSWCiphertext::external_product_scratch_space( - &module, - ct_rgsw_lhs_out.size(), - ct_rgsw_lhs_in.size(), - ct_rgsw_rhs.size(), - ), - ); - - let mut sk: SecretKey> = SecretKey::new(&module, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk_dft.dft(&module, &sk); - - ct_rgsw_rhs.encrypt_sk( - &module, - &pt_rgsw_rhs, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rgsw_lhs_in.encrypt_sk( - &module, - &pt_rgsw_lhs, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rgsw_lhs_out.external_product(&module, &ct_rgsw_lhs_in, &ct_rgsw_rhs, scratch.borrow()); - - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_lhs_out, rank); - let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs_out); - let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_lhs_out.size()); - let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_lhs_out.size()); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs_out); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_rgsw_lhs, 0); - - (0..ct_rgsw_lhs_out.rank()).for_each(|col_j| { - (0..ct_rgsw_lhs_out.rows()).for_each(|row_i| { - module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw_lhs, 0); - - if col_j == 1 { - module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0); - module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); - module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); - } - - ct_rgsw_lhs_out.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); - ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); - - let noise_have: f64 = pt.data.std(0, log_base2k).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_rgsw_product( - module.n() as f64, - log_base2k, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - log_k_rgsw_lhs_in, - log_k_rgsw_rhs, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "have: {} want: {}", - noise_have, - noise_want - ); - - pt_want.data.zero(); - }); - }); -} - -#[test] -fn external_product_inplace() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_rgsw_rhs: usize = 60; - let log_k_rgsw_lhs: usize = 45; - let rows: usize = (log_k_rgsw_lhs + log_base2k - 1) / log_base2k; - let rank: usize = 1; - - let sigma: f64 = 3.2; - - let mut ct_rgsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_rhs, rows, rank); - let mut ct_rgsw_lhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs, rows, rank); - let mut pt_rgsw_lhs: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_rgsw_rhs: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_rgsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs); - - let k: usize = 1; - - pt_rgsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_rgsw_rhs.size()) - | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_lhs.size()) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw_lhs.size()) - | GGSWCiphertext::external_product_inplace_scratch_space(&module, ct_rgsw_lhs.size(), ct_rgsw_rhs.size()), - ); - - let mut sk: SecretKey> = SecretKey::new(&module, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk_dft.dft(&module, &sk); - - ct_rgsw_rhs.encrypt_sk( - &module, - &pt_rgsw_rhs, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rgsw_lhs.encrypt_sk( - &module, - &pt_rgsw_lhs, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rgsw_lhs.external_product_inplace(&module, &ct_rgsw_rhs, scratch.borrow()); - - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_lhs, rank); - let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs); - let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_lhs.size()); - let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_lhs.size()); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_rgsw_lhs, 0); - - (0..ct_rgsw_lhs.rank()).for_each(|col_j| { - (0..ct_rgsw_lhs.rows()).for_each(|row_i| { - module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw_lhs, 0); - - if col_j == 1 { - module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0); - module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); - module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); - } - - ct_rgsw_lhs.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); - ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); - - let noise_have: f64 = pt.data.std(0, log_base2k).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_rgsw_product( - module.n() as f64, - log_base2k, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - log_k_rgsw_lhs, - log_k_rgsw_rhs, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "have: {} want: {}", - noise_have, - noise_want - ); - - pt_want.data.zero(); - }); - }); -} - +// use base2k::{ +// FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, +// VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, ZnxViewMut, ZnxZero, +// }; +// use sampling::source::Source; +// +// use crate::{ +// elem::{GetRow, Infos}, +// ggsw_ciphertext::GGSWCiphertext, +// glwe_ciphertext_fourier::GLWECiphertextFourier, +// glwe_plaintext::GLWEPlaintext, +// keys::{SecretKey, SecretKeyFourier}, +// keyswitch_key::GLWESwitchingKey, +// test_fft64::gglwe::noise_grlwe_rlwe_product, +// }; +// +// #[test] +// fn encrypt_sk() { +// let module: Module = Module::::new(2048); +// let log_base2k: usize = 8; +// let log_k_ct: usize = 54; +// let rows: usize = 4; +// let rank: usize = 1; +// +// let sigma: f64 = 3.2; +// +// let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_ct, rows, rank); +// let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); +// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); +// let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); +// +// let mut source_xs: Source = Source::new([0u8; 32]); +// let mut source_xe: Source = Source::new([0u8; 32]); +// let mut source_xa: Source = Source::new([0u8; 32]); +// +// pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); +// +// let mut scratch: ScratchOwned = ScratchOwned::new( +// GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct.size()) +// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()), +// ); +// +// let mut sk: SecretKey> = SecretKey::new(&module, rank); +// sk.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk_dft.dft(&module, &sk); +// +// ct.encrypt_sk( +// &module, +// &pt_scalar, +// &sk_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_ct, rank); +// let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); +// let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); +// +// (0..ct.rank()).for_each(|col_j| { +// (0..ct.rows()).for_each(|row_i| { +// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0); +// +// if col_j == 1 { +// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); +// module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0); +// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); +// module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); +// } +// +// ct.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); +// +// ct_rlwe_dft.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); +// +// module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); +// +// let std_pt: f64 = pt_have.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); +// assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); +// +// pt_want.data.zero(); +// }); +// }); +// } +// +// #[test] +// fn keyswitch() { +// let module: Module = Module::::new(2048); +// let log_base2k: usize = 12; +// let log_k_grlwe: usize = 60; +// let log_k_rgsw_in: usize = 45; +// let log_k_rgsw_out: usize = 45; +// let rows: usize = (log_k_rgsw_in + log_base2k - 1) / log_base2k; +// +// let rank: usize = 1; +// +// let sigma: f64 = 3.2; +// +// let mut ct_grlwe: GLWESwitchingKey, FFT64> = +// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); +// let mut ct_rgsw_in: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_in, rows, rank); +// let mut ct_rgsw_out: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_out, rows, rank); +// let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); +// +// let mut source_xs: Source = Source::new([0u8; 32]); +// let mut source_xe: Source = Source::new([0u8; 32]); +// let mut source_xa: Source = Source::new([0u8; 32]); +// +// Random input plaintext +// pt_rgsw.fill_ternary_prob(0, 0.5, &mut source_xs); +// +// let mut scratch: ScratchOwned = ScratchOwned::new( +// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) +// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_out.size()) +// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw_in.size()) +// | GGSWCiphertext::keyswitch_scratch_space( +// &module, +// ct_rgsw_out.size(), +// ct_rgsw_in.size(), +// ct_grlwe.size(), +// ), +// ); +// +// let mut sk0: SecretKey> = SecretKey::new(&module, rank); +// sk0.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk0_dft.dft(&module, &sk0); +// +// let mut sk1: SecretKey> = SecretKey::new(&module, rank); +// sk1.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk1_dft.dft(&module, &sk1); +// +// ct_grlwe.encrypt_sk( +// &module, +// &sk0.data, +// &sk1_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rgsw_in.encrypt_sk( +// &module, +// &pt_rgsw, +// &sk0_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rgsw_out.keyswitch(&module, &ct_rgsw_in, &ct_grlwe, scratch.borrow()); +// +// let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = +// GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_out, rank); +// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_out); +// let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_out.size()); +// let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_out.size()); +// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_out); +// +// (0..ct_rgsw_out.rank()).for_each(|col_j| { +// (0..ct_rgsw_out.rows()).for_each(|row_i| { +// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw, 0); +// +// if col_j == 1 { +// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); +// module.svp_apply_inplace(&mut pt_dft, 0, &sk0_dft, 0); +// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); +// module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); +// } +// +// ct_rgsw_out.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); +// ct_rlwe_dft.decrypt(&module, &mut pt, &sk1_dft, scratch.borrow()); +// +// module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); +// +// let noise_have: f64 = pt.data.std(0, log_base2k).log2(); +// let noise_want: f64 = noise_grlwe_rlwe_product( +// module.n() as f64, +// log_base2k, +// 0.5, +// 0.5, +// 0f64, +// sigma * sigma, +// 0f64, +// log_k_grlwe, +// log_k_grlwe, +// ); +// +// assert!( +// (noise_have - noise_want).abs() <= 0.2, +// "have: {} want: {}", +// noise_have, +// noise_want +// ); +// +// pt_want.data.zero(); +// }); +// }); +// } +// +// #[test] +// fn keyswitch_inplace() { +// let module: Module = Module::::new(2048); +// let log_base2k: usize = 12; +// let log_k_grlwe: usize = 60; +// let log_k_rgsw: usize = 45; +// let rows: usize = (log_k_rgsw + log_base2k - 1) / log_base2k; +// let rank: usize = 1; +// +// let sigma: f64 = 3.2; +// +// let mut ct_grlwe: GLWESwitchingKey, FFT64> = +// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); +// let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw, rows, rank); +// let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); +// +// let mut source_xs: Source = Source::new([0u8; 32]); +// let mut source_xe: Source = Source::new([0u8; 32]); +// let mut source_xa: Source = Source::new([0u8; 32]); +// +// Random input plaintext +// pt_rgsw.fill_ternary_prob(0, 0.5, &mut source_xs); +// +// let mut scratch: ScratchOwned = ScratchOwned::new( +// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) +// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw.size()) +// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) +// | GGSWCiphertext::keyswitch_inplace_scratch_space(&module, ct_rgsw.size(), ct_grlwe.size()), +// ); +// +// let mut sk0: SecretKey> = SecretKey::new(&module, rank); +// sk0.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk0_dft.dft(&module, &sk0); +// +// let mut sk1: SecretKey> = SecretKey::new(&module, rank); +// sk1.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk1_dft.dft(&module, &sk1); +// +// ct_grlwe.encrypt_sk( +// &module, +// &sk0.data, +// &sk1_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rgsw.encrypt_sk( +// &module, +// &pt_rgsw, +// &sk0_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rgsw.keyswitch_inplace(&module, &ct_grlwe, scratch.borrow()); +// +// let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = +// GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw, rank); +// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw); +// let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw.size()); +// let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw.size()); +// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw); +// +// (0..ct_rgsw.rank()).for_each(|col_j| { +// (0..ct_rgsw.rows()).for_each(|row_i| { +// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw, 0); +// +// if col_j == 1 { +// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); +// module.svp_apply_inplace(&mut pt_dft, 0, &sk0_dft, 0); +// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); +// module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); +// } +// +// ct_rgsw.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); +// ct_rlwe_dft.decrypt(&module, &mut pt, &sk1_dft, scratch.borrow()); +// +// module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); +// +// let noise_have: f64 = pt.data.std(0, log_base2k).log2(); +// let noise_want: f64 = noise_grlwe_rlwe_product( +// module.n() as f64, +// log_base2k, +// 0.5, +// 0.5, +// 0f64, +// sigma * sigma, +// 0f64, +// log_k_grlwe, +// log_k_grlwe, +// ); +// +// assert!( +// (noise_have - noise_want).abs() <= 0.2, +// "have: {} want: {}", +// noise_have, +// noise_want +// ); +// +// pt_want.data.zero(); +// }); +// }); +// } +// +// #[test] +// fn external_product() { +// let module: Module = Module::::new(2048); +// let log_base2k: usize = 12; +// let log_k_rgsw_rhs: usize = 60; +// let log_k_rgsw_lhs_in: usize = 45; +// let log_k_rgsw_lhs_out: usize = 45; +// let rows: usize = (log_k_rgsw_lhs_in + log_base2k - 1) / log_base2k; +// let rank: usize = 1; +// +// let sigma: f64 = 3.2; +// +// let mut ct_rgsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_rhs, rows, rank); +// let mut ct_rgsw_lhs_in: GGSWCiphertext, FFT64> = +// GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs_in, rows, rank); +// let mut ct_rgsw_lhs_out: GGSWCiphertext, FFT64> = +// GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs_out, rows, rank); +// let mut pt_rgsw_lhs: ScalarZnx> = module.new_scalar_znx(1); +// let mut pt_rgsw_rhs: ScalarZnx> = module.new_scalar_znx(1); +// +// let mut source_xs: Source = Source::new([0u8; 32]); +// let mut source_xe: Source = Source::new([0u8; 32]); +// let mut source_xa: Source = Source::new([0u8; 32]); +// +// Random input plaintext +// pt_rgsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs); +// +// let k: usize = 1; +// +// pt_rgsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} +// +// let mut scratch: ScratchOwned = ScratchOwned::new( +// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_rgsw_rhs.size()) +// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_lhs_out.size()) +// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw_lhs_in.size()) +// | GGSWCiphertext::external_product_scratch_space( +// &module, +// ct_rgsw_lhs_out.size(), +// ct_rgsw_lhs_in.size(), +// ct_rgsw_rhs.size(), +// ), +// ); +// +// let mut sk: SecretKey> = SecretKey::new(&module, rank); +// sk.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk_dft.dft(&module, &sk); +// +// ct_rgsw_rhs.encrypt_sk( +// &module, +// &pt_rgsw_rhs, +// &sk_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rgsw_lhs_in.encrypt_sk( +// &module, +// &pt_rgsw_lhs, +// &sk_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rgsw_lhs_out.external_product(&module, &ct_rgsw_lhs_in, &ct_rgsw_rhs, scratch.borrow()); +// +// let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = +// GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_lhs_out, rank); +// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs_out); +// let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_lhs_out.size()); +// let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_lhs_out.size()); +// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs_out); +// +// module.vec_znx_rotate_inplace(k as i64, &mut pt_rgsw_lhs, 0); +// +// (0..ct_rgsw_lhs_out.rank()).for_each(|col_j| { +// (0..ct_rgsw_lhs_out.rows()).for_each(|row_i| { +// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw_lhs, 0); +// +// if col_j == 1 { +// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); +// module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0); +// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); +// module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); +// } +// +// ct_rgsw_lhs_out.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); +// ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); +// +// module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); +// +// let noise_have: f64 = pt.data.std(0, log_base2k).log2(); +// +// let var_gct_err_lhs: f64 = sigma * sigma; +// let var_gct_err_rhs: f64 = 0f64; +// +// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} +// let var_a0_err: f64 = sigma * sigma; +// let var_a1_err: f64 = 1f64 / 12f64; +// +// let noise_want: f64 = noise_rgsw_product( +// module.n() as f64, +// log_base2k, +// 0.5, +// var_msg, +// var_a0_err, +// var_a1_err, +// var_gct_err_lhs, +// var_gct_err_rhs, +// log_k_rgsw_lhs_in, +// log_k_rgsw_rhs, +// ); +// +// assert!( +// (noise_have - noise_want).abs() <= 0.1, +// "have: {} want: {}", +// noise_have, +// noise_want +// ); +// +// pt_want.data.zero(); +// }); +// }); +// } +// +// #[test] +// fn external_product_inplace() { +// let module: Module = Module::::new(2048); +// let log_base2k: usize = 12; +// let log_k_rgsw_rhs: usize = 60; +// let log_k_rgsw_lhs: usize = 45; +// let rows: usize = (log_k_rgsw_lhs + log_base2k - 1) / log_base2k; +// let rank: usize = 1; +// +// let sigma: f64 = 3.2; +// +// let mut ct_rgsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_rhs, rows, rank); +// let mut ct_rgsw_lhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs, rows, rank); +// let mut pt_rgsw_lhs: ScalarZnx> = module.new_scalar_znx(1); +// let mut pt_rgsw_rhs: ScalarZnx> = module.new_scalar_znx(1); +// +// let mut source_xs: Source = Source::new([0u8; 32]); +// let mut source_xe: Source = Source::new([0u8; 32]); +// let mut source_xa: Source = Source::new([0u8; 32]); +// +// Random input plaintext +// pt_rgsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs); +// +// let k: usize = 1; +// +// pt_rgsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} +// +// let mut scratch: ScratchOwned = ScratchOwned::new( +// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_rgsw_rhs.size()) +// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_lhs.size()) +// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw_lhs.size()) +// | GGSWCiphertext::external_product_inplace_scratch_space(&module, ct_rgsw_lhs.size(), ct_rgsw_rhs.size()), +// ); +// +// let mut sk: SecretKey> = SecretKey::new(&module, rank); +// sk.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk_dft.dft(&module, &sk); +// +// ct_rgsw_rhs.encrypt_sk( +// &module, +// &pt_rgsw_rhs, +// &sk_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rgsw_lhs.encrypt_sk( +// &module, +// &pt_rgsw_lhs, +// &sk_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rgsw_lhs.external_product_inplace(&module, &ct_rgsw_rhs, scratch.borrow()); +// +// let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = +// GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_lhs, rank); +// let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs); +// let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_lhs.size()); +// let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_lhs.size()); +// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs); +// +// module.vec_znx_rotate_inplace(k as i64, &mut pt_rgsw_lhs, 0); +// +// (0..ct_rgsw_lhs.rank()).for_each(|col_j| { +// (0..ct_rgsw_lhs.rows()).for_each(|row_i| { +// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw_lhs, 0); +// +// if col_j == 1 { +// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); +// module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0); +// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); +// module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); +// } +// +// ct_rgsw_lhs.get_row(&module, row_i, col_j, &mut ct_rlwe_dft); +// ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); +// +// module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); +// +// let noise_have: f64 = pt.data.std(0, log_base2k).log2(); +// +// let var_gct_err_lhs: f64 = sigma * sigma; +// let var_gct_err_rhs: f64 = 0f64; +// +// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} +// let var_a0_err: f64 = sigma * sigma; +// let var_a1_err: f64 = 1f64 / 12f64; +// +// let noise_want: f64 = noise_rgsw_product( +// module.n() as f64, +// log_base2k, +// 0.5, +// var_msg, +// var_a0_err, +// var_a1_err, +// var_gct_err_lhs, +// var_gct_err_rhs, +// log_k_rgsw_lhs, +// log_k_rgsw_rhs, +// ); +// +// assert!( +// (noise_have - noise_want).abs() <= 0.1, +// "have: {} want: {}", +// noise_have, +// noise_want +// ); +// +// pt_want.data.zero(); +// }); +// }); +// } pub(crate) fn noise_rgsw_product( n: f64, log_base2k: usize, diff --git a/core/src/test_fft64/glwe.rs b/core/src/test_fft64/glwe.rs index dca899b..5f2c876 100644 --- a/core/src/test_fft64/glwe.rs +++ b/core/src/test_fft64/glwe.rs @@ -13,7 +13,7 @@ use crate::{ glwe_plaintext::GLWEPlaintext, keys::{GLWEPublicKey, SecretKey, SecretKeyFourier}, keyswitch_key::GLWESwitchingKey, - test_fft64::{gglwe::noise_grlwe_rlwe_product, ggsw::noise_rgsw_product}, + test_fft64::{gglwe::noise_gglwe_product, ggsw::noise_rgsw_product}, }; #[test] @@ -197,21 +197,32 @@ fn test_encrypt_pk(log_n: usize, basek: usize, k_ct: usize, k_pk: usize, sigma: #[test] fn keyswitch() { - let module: Module = Module::::new(2048); - let basek: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe_in: usize = 45; - let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + basek - 1) / basek; - let rank: usize = 1; + (1..4).for_each(|rank_in| { + (1..4).for_each(|rank_out| { + println!("test keyswitch rank_in: {} rank_out: {}", rank_in, rank_out); + test_keyswitch(12, 12, 60, 45, 60, rank_in, rank_out, 3.2); + }); + }); +} - let sigma: f64 = 3.2; +fn test_keyswitch( + log_n: usize, + basek: usize, + k_keyswitch: usize, + k_ct_in: usize, + k_ct_out: usize, + rank_in: usize, + rank_out: usize, + sigma: f64, +) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = (k_ct_in + basek - 1) / basek; - let mut ct_grlwe: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank); - let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, basek, log_k_rlwe_in, rank); - let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, basek, log_k_rlwe_out, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_rlwe_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_rlwe_out); + let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, k_keyswitch, rows, rank_in, rank_out); + let mut ct_in: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct_in, rank_in); + let mut ct_out: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct_out, rank_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct_out); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -223,57 +234,59 @@ fn keyswitch() { .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) - | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe_in.size()) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_in, ksk.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct_out.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, rank_out, ct_in.size()) | GLWECiphertext::keyswitch_scratch_space( &module, - ct_rlwe_out.size(), - ct_rlwe_in.size(), - ct_grlwe.size(), + ct_out.size(), + ct_in.size(), + ksk.size(), + rank_in, + rank_out, ), ); - let mut sk0: SecretKey> = SecretKey::new(&module, rank); - sk0.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_in: SecretKey> = SecretKey::new(&module, rank_in); + sk_in.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk0_dft.dft(&module, &sk0); + let mut sk_in_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_in); + sk_in_dft.dft(&module, &sk_in); - let mut sk1: SecretKey> = SecretKey::new(&module, rank); - sk1.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_out: SecretKey> = SecretKey::new(&module, rank_out); + sk_out.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk1_dft.dft(&module, &sk1); + let mut sk_out_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_out); + sk_out_dft.dft(&module, &sk_out); - ct_grlwe.encrypt_sk( + ksk.encrypt_sk( &module, - &sk0.data, - &sk1_dft, + &sk_in, + &sk_out_dft, &mut source_xa, &mut source_xe, sigma, scratch.borrow(), ); - ct_rlwe_in.encrypt_sk( + ct_in.encrypt_sk( &module, &pt_want, - &sk0_dft, + &sk_in_dft, &mut source_xa, &mut source_xe, sigma, scratch.borrow(), ); - ct_rlwe_out.keyswitch(&module, &ct_rlwe_in, &ct_grlwe, scratch.borrow()); + ct_out.keyswitch(&module, &ct_in, &ksk, scratch.borrow()); - ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); + ct_out.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); let noise_have: f64 = pt_have.data.std(0, basek).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( + let noise_want: f64 = noise_gglwe_product( module.n() as f64, basek, 0.5, @@ -281,8 +294,9 @@ fn keyswitch() { 0f64, sigma * sigma, 0f64, - log_k_rlwe_in, - log_k_grlwe, + rank_in as f64, + k_ct_in, + k_keyswitch, ); assert!( @@ -322,7 +336,7 @@ fn keyswich_inplace() { GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe.size()) - | GLWECiphertext::keyswitch_inplace_scratch_space(&module, ct_rlwe.size(), ct_grlwe.size()), + | GLWECiphertext::keyswitch_inplace_scratch_space(&module, ct_rlwe.size(), ct_grlwe.size(), rank), ); let mut sk0: SecretKey> = SecretKey::new(&module, rank); @@ -339,7 +353,7 @@ fn keyswich_inplace() { ct_grlwe.encrypt_sk( &module, - &sk0.data, + &sk0, &sk1_dft, &mut source_xa, &mut source_xe, @@ -364,7 +378,7 @@ fn keyswich_inplace() { module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); let noise_have: f64 = pt_have.data.std(0, basek).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( + let noise_want: f64 = noise_gglwe_product( module.n() as f64, basek, 0.5, @@ -372,6 +386,7 @@ fn keyswich_inplace() { 0f64, sigma * sigma, 0f64, + rank as f64, log_k_rlwe, log_k_grlwe, ); @@ -427,6 +442,7 @@ fn external_product() { ct_rlwe_out.size(), ct_rlwe_in.size(), ct_rgsw.size(), + rank, ), ); @@ -531,7 +547,7 @@ fn external_product_inplace() { GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe.size()) - | GLWECiphertext::external_product_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), + | GLWECiphertext::external_product_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size(), rank), ); let mut sk: SecretKey> = SecretKey::new(&module, rank); diff --git a/core/src/test_fft64/glwe_fourier.rs b/core/src/test_fft64/glwe_fourier.rs index 16f9eca..f25bac9 100644 --- a/core/src/test_fft64/glwe_fourier.rs +++ b/core/src/test_fft64/glwe_fourier.rs @@ -1,438 +1,438 @@ -use crate::{ - elem::Infos, - ggsw_ciphertext::GGSWCiphertext, - glwe_ciphertext::GLWECiphertext, - glwe_ciphertext_fourier::GLWECiphertextFourier, - glwe_plaintext::GLWEPlaintext, - keys::{SecretKey, SecretKeyFourier}, - keyswitch_key::GLWESwitchingKey, - test_fft64::{gglwe::noise_grlwe_rlwe_product, ggsw::noise_rgsw_product}, -}; -use base2k::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, ZnxViewMut}; -use sampling::source::Source; - -#[test] -fn keyswitch() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe_in: usize = 45; - let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; - - let rank: usize = 1; - - let sigma: f64 = 3.2; - - let mut ct_grlwe: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); - let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); - let mut ct_rlwe_in_dft: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in, rank); - let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out, rank); - let mut ct_rlwe_out_dft: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_out, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) - | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe_in.size()) - | GLWECiphertextFourier::keyswitch_scratch_space( - &module, - ct_rlwe_out.size(), - ct_rlwe_in.size(), - ct_grlwe.size(), - ), - ); - - let mut sk0: SecretKey> = SecretKey::new(&module, rank); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk0_dft.dft(&module, &sk0); - - let mut sk1: SecretKey> = SecretKey::new(&module, rank); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk1_dft.dft(&module, &sk1); - - ct_grlwe.encrypt_sk( - &module, - &sk0.data, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rlwe_in.encrypt_sk( - &module, - &pt_want, - &sk0_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rlwe_in.dft(&module, &mut ct_rlwe_in_dft); - ct_rlwe_out_dft.keyswitch(&module, &ct_rlwe_in_dft, &ct_grlwe, scratch.borrow()); - ct_rlwe_out_dft.idft(&module, &mut ct_rlwe_out, scratch.borrow()); - - ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - log_k_rlwe_in, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); -} - -#[test] -fn keyswich_inplace() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe: usize = 45; - let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k; - let rank: usize = 1; - - let sigma: f64 = 3.2; - - let mut ct_grlwe: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); - let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe, rank); - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) - | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe.size()) - | GLWECiphertextFourier::keyswitch_inplace_scratch_space(&module, ct_rlwe_dft.size(), ct_grlwe.size()), - ); - - let mut sk0: SecretKey> = SecretKey::new(&module, rank); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk0_dft.dft(&module, &sk0); - - let mut sk1: SecretKey> = SecretKey::new(&module, rank); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk1_dft.dft(&module, &sk1); - - ct_grlwe.encrypt_sk( - &module, - &sk0.data, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rlwe.encrypt_sk( - &module, - &pt_want, - &sk0_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rlwe.dft(&module, &mut ct_rlwe_dft); - ct_rlwe_dft.keyswitch_inplace(&module, &ct_grlwe, scratch.borrow()); - ct_rlwe_dft.idft(&module, &mut ct_rlwe, scratch.borrow()); - - ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - let noise_want: f64 = noise_grlwe_rlwe_product( - module.n() as f64, - log_base2k, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - log_k_rlwe, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); -} - -#[test] -fn external_product() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe_in: usize = 45; - let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; - let rank: usize = 1; - - let sigma: f64 = 3.2; - - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); - let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); - let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out, rank); - let mut ct_rlwe_dft_in: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in, rank); - let mut ct_rlwe_dft_out: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_out, rank); - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - pt_want.to_mut().at_mut(0, 0)[1] = 1; - - let k: usize = 1; - - pt_rgsw.raw_mut()[k] = 1; // X^{k} - - let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) - | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe_in.size()) - | GLWECiphertext::external_product_scratch_space( - &module, - ct_rlwe_out.size(), - ct_rlwe_in.size(), - ct_rgsw.size(), - ), - ); - - let mut sk: SecretKey> = SecretKey::new(&module, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk_dft.dft(&module, &sk); - - ct_rgsw.encrypt_sk( - &module, - &pt_rgsw, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rlwe_in.encrypt_sk( - &module, - &pt_want, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rlwe_in.dft(&module, &mut ct_rlwe_dft_in); - ct_rlwe_dft_out.external_product(&module, &ct_rlwe_dft_in, &ct_rgsw, scratch.borrow()); - ct_rlwe_dft_out.idft(&module, &mut ct_rlwe_out, scratch.borrow()); - - ct_rlwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_rgsw_product( - module.n() as f64, - log_base2k, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - log_k_rlwe_in, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); -} - -#[test] -fn external_product_inplace() { - let module: Module = Module::::new(2048); - let log_base2k: usize = 12; - let log_k_grlwe: usize = 60; - let log_k_rlwe_in: usize = 45; - let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; - let rank: usize = 1; - - let sigma: f64 = 3.2; - - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); - let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in, rank); - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); - - pt_want.to_mut().at_mut(0, 0)[1] = 1; - - let k: usize = 1; - - pt_rgsw.raw_mut()[k] = 1; // X^{k} - - let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) - | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe.size()) - | GLWECiphertext::external_product_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), - ); - - let mut sk: SecretKey> = SecretKey::new(&module, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk_dft.dft(&module, &sk); - - ct_rgsw.encrypt_sk( - &module, - &pt_rgsw, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rlwe.encrypt_sk( - &module, - &pt_want, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rlwe.dft(&module, &mut ct_rlwe_dft); - ct_rlwe_dft.external_product_inplace(&module, &ct_rgsw, scratch.borrow()); - ct_rlwe_dft.idft(&module, &mut ct_rlwe, scratch.borrow()); - - ct_rlwe.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); - - module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_rgsw_product( - module.n() as f64, - log_base2k, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - log_k_rlwe_in, - log_k_grlwe, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.1, - "{} {}", - noise_have, - noise_want - ); -} +// use crate::{ +// elem::Infos, +// ggsw_ciphertext::GGSWCiphertext, +// glwe_ciphertext::GLWECiphertext, +// glwe_ciphertext_fourier::GLWECiphertextFourier, +// glwe_plaintext::GLWEPlaintext, +// keys::{SecretKey, SecretKeyFourier}, +// keyswitch_key::GLWESwitchingKey, +// test_fft64::{gglwe::noise_grlwe_rlwe_product, ggsw::noise_rgsw_product}, +// }; +// use base2k::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, ZnxViewMut}; +// use sampling::source::Source; +// +// #[test] +// fn keyswitch() { +// let module: Module = Module::::new(2048); +// let log_base2k: usize = 12; +// let log_k_grlwe: usize = 60; +// let log_k_rlwe_in: usize = 45; +// let log_k_rlwe_out: usize = 60; +// let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; +// +// let rank: usize = 1; +// +// let sigma: f64 = 3.2; +// +// let mut ct_grlwe: GLWESwitchingKey, FFT64> = +// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); +// let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); +// let mut ct_rlwe_in_dft: GLWECiphertextFourier, FFT64> = +// GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in, rank); +// let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out, rank); +// let mut ct_rlwe_out_dft: GLWECiphertextFourier, FFT64> = +// GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_out, rank); +// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); +// let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); +// +// let mut source_xs: Source = Source::new([0u8; 32]); +// let mut source_xe: Source = Source::new([0u8; 32]); +// let mut source_xa: Source = Source::new([0u8; 32]); +// +// Random input plaintext +// pt_want +// .data +// .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); +// +// let mut scratch: ScratchOwned = ScratchOwned::new( +// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) +// | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe_out.size()) +// | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe_in.size()) +// | GLWECiphertextFourier::keyswitch_scratch_space( +// &module, +// ct_rlwe_out.size(), +// ct_rlwe_in.size(), +// ct_grlwe.size(), +// ), +// ); +// +// let mut sk0: SecretKey> = SecretKey::new(&module, rank); +// sk0.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk0_dft.dft(&module, &sk0); +// +// let mut sk1: SecretKey> = SecretKey::new(&module, rank); +// sk1.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk1_dft.dft(&module, &sk1); +// +// ct_grlwe.encrypt_sk( +// &module, +// &sk0.data, +// &sk1_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rlwe_in.encrypt_sk( +// &module, +// &pt_want, +// &sk0_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rlwe_in.dft(&module, &mut ct_rlwe_in_dft); +// ct_rlwe_out_dft.keyswitch(&module, &ct_rlwe_in_dft, &ct_grlwe, scratch.borrow()); +// ct_rlwe_out_dft.idft(&module, &mut ct_rlwe_out, scratch.borrow()); +// +// ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); +// +// module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); +// +// let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); +// let noise_want: f64 = noise_grlwe_rlwe_product( +// module.n() as f64, +// log_base2k, +// 0.5, +// 0.5, +// 0f64, +// sigma * sigma, +// 0f64, +// log_k_rlwe_in, +// log_k_grlwe, +// ); +// +// assert!( +// (noise_have - noise_want).abs() <= 0.1, +// "{} {}", +// noise_have, +// noise_want +// ); +// } +// +// #[test] +// fn keyswich_inplace() { +// let module: Module = Module::::new(2048); +// let log_base2k: usize = 12; +// let log_k_grlwe: usize = 60; +// let log_k_rlwe: usize = 45; +// let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k; +// let rank: usize = 1; +// +// let sigma: f64 = 3.2; +// +// let mut ct_grlwe: GLWESwitchingKey, FFT64> = +// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); +// let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe, rank); +// let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = +// GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe, rank); +// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe); +// let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe); +// +// let mut source_xs: Source = Source::new([0u8; 32]); +// let mut source_xe: Source = Source::new([0u8; 32]); +// let mut source_xa: Source = Source::new([0u8; 32]); +// +// Random input plaintext +// pt_want +// .data +// .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); +// +// let mut scratch: ScratchOwned = ScratchOwned::new( +// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) +// | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) +// | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe.size()) +// | GLWECiphertextFourier::keyswitch_inplace_scratch_space(&module, ct_rlwe_dft.size(), ct_grlwe.size()), +// ); +// +// let mut sk0: SecretKey> = SecretKey::new(&module, rank); +// sk0.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk0_dft.dft(&module, &sk0); +// +// let mut sk1: SecretKey> = SecretKey::new(&module, rank); +// sk1.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk1_dft.dft(&module, &sk1); +// +// ct_grlwe.encrypt_sk( +// &module, +// &sk0.data, +// &sk1_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rlwe.encrypt_sk( +// &module, +// &pt_want, +// &sk0_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rlwe.dft(&module, &mut ct_rlwe_dft); +// ct_rlwe_dft.keyswitch_inplace(&module, &ct_grlwe, scratch.borrow()); +// ct_rlwe_dft.idft(&module, &mut ct_rlwe, scratch.borrow()); +// +// ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); +// +// module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); +// +// let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); +// let noise_want: f64 = noise_grlwe_rlwe_product( +// module.n() as f64, +// log_base2k, +// 0.5, +// 0.5, +// 0f64, +// sigma * sigma, +// 0f64, +// log_k_rlwe, +// log_k_grlwe, +// ); +// +// assert!( +// (noise_have - noise_want).abs() <= 0.1, +// "{} {}", +// noise_have, +// noise_want +// ); +// } +// +// #[test] +// fn external_product() { +// let module: Module = Module::::new(2048); +// let log_base2k: usize = 12; +// let log_k_grlwe: usize = 60; +// let log_k_rlwe_in: usize = 45; +// let log_k_rlwe_out: usize = 60; +// let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; +// let rank: usize = 1; +// +// let sigma: f64 = 3.2; +// +// let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); +// let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); +// let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out, rank); +// let mut ct_rlwe_dft_in: GLWECiphertextFourier, FFT64> = +// GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in, rank); +// let mut ct_rlwe_dft_out: GLWECiphertextFourier, FFT64> = +// GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_out, rank); +// let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); +// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); +// let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); +// +// let mut source_xs: Source = Source::new([0u8; 32]); +// let mut source_xe: Source = Source::new([0u8; 32]); +// let mut source_xa: Source = Source::new([0u8; 32]); +// +// Random input plaintext +// pt_want +// .data +// .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); +// +// pt_want.to_mut().at_mut(0, 0)[1] = 1; +// +// let k: usize = 1; +// +// pt_rgsw.raw_mut()[k] = 1; // X^{k} +// +// let mut scratch: ScratchOwned = ScratchOwned::new( +// GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) +// | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe_out.size()) +// | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe_in.size()) +// | GLWECiphertext::external_product_scratch_space( +// &module, +// ct_rlwe_out.size(), +// ct_rlwe_in.size(), +// ct_rgsw.size(), +// ), +// ); +// +// let mut sk: SecretKey> = SecretKey::new(&module, rank); +// sk.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk_dft.dft(&module, &sk); +// +// ct_rgsw.encrypt_sk( +// &module, +// &pt_rgsw, +// &sk_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rlwe_in.encrypt_sk( +// &module, +// &pt_want, +// &sk_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rlwe_in.dft(&module, &mut ct_rlwe_dft_in); +// ct_rlwe_dft_out.external_product(&module, &ct_rlwe_dft_in, &ct_rgsw, scratch.borrow()); +// ct_rlwe_dft_out.idft(&module, &mut ct_rlwe_out, scratch.borrow()); +// +// ct_rlwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); +// +// module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); +// +// module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); +// +// let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); +// +// let var_gct_err_lhs: f64 = sigma * sigma; +// let var_gct_err_rhs: f64 = 0f64; +// +// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} +// let var_a0_err: f64 = sigma * sigma; +// let var_a1_err: f64 = 1f64 / 12f64; +// +// let noise_want: f64 = noise_rgsw_product( +// module.n() as f64, +// log_base2k, +// 0.5, +// var_msg, +// var_a0_err, +// var_a1_err, +// var_gct_err_lhs, +// var_gct_err_rhs, +// log_k_rlwe_in, +// log_k_grlwe, +// ); +// +// assert!( +// (noise_have - noise_want).abs() <= 0.1, +// "{} {}", +// noise_have, +// noise_want +// ); +// } +// +// #[test] +// fn external_product_inplace() { +// let module: Module = Module::::new(2048); +// let log_base2k: usize = 12; +// let log_k_grlwe: usize = 60; +// let log_k_rlwe_in: usize = 45; +// let log_k_rlwe_out: usize = 60; +// let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; +// let rank: usize = 1; +// +// let sigma: f64 = 3.2; +// +// let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); +// let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); +// let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = +// GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in, rank); +// let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); +// let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); +// let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); +// +// let mut source_xs: Source = Source::new([0u8; 32]); +// let mut source_xe: Source = Source::new([0u8; 32]); +// let mut source_xa: Source = Source::new([0u8; 32]); +// +// Random input plaintext +// pt_want +// .data +// .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); +// +// pt_want.to_mut().at_mut(0, 0)[1] = 1; +// +// let k: usize = 1; +// +// pt_rgsw.raw_mut()[k] = 1; // X^{k} +// +// let mut scratch: ScratchOwned = ScratchOwned::new( +// GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) +// | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) +// | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe.size()) +// | GLWECiphertext::external_product_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), +// ); +// +// let mut sk: SecretKey> = SecretKey::new(&module, rank); +// sk.fill_ternary_prob(0.5, &mut source_xs); +// +// let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); +// sk_dft.dft(&module, &sk); +// +// ct_rgsw.encrypt_sk( +// &module, +// &pt_rgsw, +// &sk_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rlwe.encrypt_sk( +// &module, +// &pt_want, +// &sk_dft, +// &mut source_xa, +// &mut source_xe, +// sigma, +// scratch.borrow(), +// ); +// +// ct_rlwe.dft(&module, &mut ct_rlwe_dft); +// ct_rlwe_dft.external_product_inplace(&module, &ct_rgsw, scratch.borrow()); +// ct_rlwe_dft.idft(&module, &mut ct_rlwe, scratch.borrow()); +// +// ct_rlwe.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); +// +// module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); +// +// module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); +// +// let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); +// +// let var_gct_err_lhs: f64 = sigma * sigma; +// let var_gct_err_rhs: f64 = 0f64; +// +// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} +// let var_a0_err: f64 = sigma * sigma; +// let var_a1_err: f64 = 1f64 / 12f64; +// +// let noise_want: f64 = noise_rgsw_product( +// module.n() as f64, +// log_base2k, +// 0.5, +// var_msg, +// var_a0_err, +// var_a1_err, +// var_gct_err_lhs, +// var_gct_err_rhs, +// log_k_rlwe_in, +// log_k_grlwe, +// ); +// +// assert!( +// (noise_have - noise_want).abs() <= 0.1, +// "{} {}", +// noise_have, +// noise_want +// ); +// } diff --git a/core/src/vec_glwe_product.rs b/core/src/vec_glwe_product.rs index 63c4769..d3e6636 100644 --- a/core/src/vec_glwe_product.rs +++ b/core/src/vec_glwe_product.rs @@ -10,31 +10,53 @@ use crate::{ }; pub(crate) trait VecGLWEProductScratchSpace { - fn prod_with_glwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize; + fn prod_with_glwe_scratch_space( + module: &Module, + res_size: usize, + lhs: usize, + rhs: usize, + rank_in: usize, + rank_out: usize, + ) -> usize; - fn prod_with_glwe_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { - Self::prod_with_glwe_scratch_space(module, res_size, res_size, rhs) + fn prod_with_glwe_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { + Self::prod_with_glwe_scratch_space(module, res_size, res_size, rhs, rank, rank) } - fn prod_with_glwe_dft_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - (Self::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) | module.vec_znx_idft_tmp_bytes()) - + module.bytes_of_vec_znx(2, lhs) - + module.bytes_of_vec_znx(2, res_size) + fn prod_with_glwe_dft_scratch_space( + module: &Module, + res_size: usize, + lhs: usize, + rhs: usize, + rank_in: usize, + rank_out: usize, + ) -> usize { + (Self::prod_with_glwe_scratch_space(module, res_size, lhs, rhs, rank_in, rank_out) | module.vec_znx_idft_tmp_bytes()) + + module.bytes_of_vec_znx(rank_in, lhs) + + module.bytes_of_vec_znx(rank_out, res_size) } - fn prod_with_glwe_dft_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { - (Self::prod_with_glwe_inplace_scratch_space(module, res_size, rhs) | module.vec_znx_idft_tmp_bytes()) - + module.bytes_of_vec_znx(2, res_size) + fn prod_with_glwe_dft_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { + (Self::prod_with_glwe_inplace_scratch_space(module, res_size, rhs, rank) | module.vec_znx_idft_tmp_bytes()) + + module.bytes_of_vec_znx(rank + 1, res_size) } - fn prod_with_vec_glwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - Self::prod_with_glwe_dft_scratch_space(module, res_size, lhs, rhs) - + module.bytes_of_vec_znx_dft(2, lhs) - + module.bytes_of_vec_znx_dft(2, res_size) + fn prod_with_vec_glwe_scratch_space( + module: &Module, + res_size: usize, + lhs: usize, + rhs: usize, + rank_in: usize, + rank_out: usize, + ) -> usize { + Self::prod_with_glwe_dft_scratch_space(module, res_size, lhs, rhs, rank_in, rank_out) + + module.bytes_of_vec_znx_dft(rank_in + 1, lhs) + + module.bytes_of_vec_znx_dft(rank_out + 1, res_size) } - fn prod_with_vec_glwe_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { - Self::prod_with_glwe_dft_inplace_scratch_space(module, res_size, rhs) + module.bytes_of_vec_znx_dft(2, res_size) + fn prod_with_vec_glwe_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize, rank: usize) -> usize { + Self::prod_with_glwe_dft_inplace_scratch_space(module, res_size, rhs, rank) + + module.bytes_of_vec_znx_dft(rank + 1, res_size) } } @@ -78,7 +100,7 @@ pub(crate) trait VecGLWEProduct: Infos { assert_eq!(res.n(), module.n()); } - let (a_data, scratch_1) = scratch.tmp_vec_znx(module, 2, a.size()); + let (a_data, scratch_1) = scratch.tmp_vec_znx(module, a.rank() + 1, a.size()); let mut a_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { data: a_data, @@ -88,7 +110,7 @@ pub(crate) trait VecGLWEProduct: Infos { a.idft(module, &mut a_idft, scratch_1); - let (res_data, scratch_2) = scratch_1.tmp_vec_znx(module, 2, res.size()); + let (res_data, scratch_2) = scratch_1.tmp_vec_znx(module, res.rank() + 1, res.size()); let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { data: res_data, @@ -98,8 +120,7 @@ pub(crate) trait VecGLWEProduct: Infos { self.prod_with_glwe(module, &mut res_idft, &a_idft, scratch_2); - module.vec_znx_dft(res, 0, &res_idft, 0); - module.vec_znx_dft(res, 1, &res_idft, 1); + res_idft.dft(module, res); } fn prod_with_glwe_fourier_inplace( @@ -119,7 +140,7 @@ pub(crate) trait VecGLWEProduct: Infos { assert_eq!(res.n(), module.n()); } - let (res_data, scratch_1) = scratch.tmp_vec_znx(module, 2, res.size()); + let (res_data, scratch_1) = scratch.tmp_vec_znx(module, res.rank() + 1, res.size()); let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { data: res_data, @@ -131,8 +152,7 @@ pub(crate) trait VecGLWEProduct: Infos { self.prod_with_glwe_inplace(module, &mut res_idft, scratch_1); - module.vec_znx_dft(res, 0, &res_idft, 0); - module.vec_znx_dft(res, 1, &res_idft, 1); + res_idft.dft(module, res); } fn prod_with_vec_glwe(&self, module: &Module, res: &mut RES, a: &LHS, scratch: &mut Scratch) @@ -140,7 +160,7 @@ pub(crate) trait VecGLWEProduct: Infos { LHS: GetRow + Infos, RES: SetRow + Infos, { - let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, a.size()); + let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, a.cols(), a.size()); let mut tmp_a_row: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { data: tmp_row_data, @@ -148,7 +168,7 @@ pub(crate) trait VecGLWEProduct: Infos { k: a.k(), }; - let (tmp_res_data, scratch2) = scratch1.tmp_vec_znx_dft(module, 2, res.size()); + let (tmp_res_data, scratch2) = scratch1.tmp_vec_znx_dft(module, res.cols(), res.size()); let mut tmp_res_row: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { data: tmp_res_data, @@ -179,7 +199,7 @@ pub(crate) trait VecGLWEProduct: Infos { where RES: GetRow + SetRow + Infos, { - let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, res.size()); + let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, res.cols(), res.size()); let mut tmp_row: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { data: tmp_row_data,