diff --git a/core/src/glwe_ciphertext.rs b/core/src/glwe_ciphertext.rs index ed8a39e..56ca1a9 100644 --- a/core/src/glwe_ciphertext.rs +++ b/core/src/glwe_ciphertext.rs @@ -435,17 +435,30 @@ where VecZnx: VecZnxToMut, ScalarZnxDft: ScalarZnxDftToRef, { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), sk_dft.rank()); + assert_eq!(self.n(), module.n()); + assert_eq!(pt.n(), module.n()); + assert_eq!(sk_dft.n(), module.n()); + } + + let cols: usize = self.rank() + 1; + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, self.size()); // TODO optimize size when pt << ct + c0_big.zero(); { - let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct - module.vec_znx_dft(&mut c0_dft, 0, self, 1); + (1..cols).for_each(|i| { + // ci_dft = DFT(a[i]) * DFT(s[i]) + let (mut ci_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct + module.vec_znx_dft(&mut ci_dft, 0, self, i); + module.svp_apply_inplace(&mut ci_dft, 0, sk_dft, i - 1); + let ci_big = module.vec_znx_idft_consume(ci_dft); - // c0_dft = DFT(a) * DFT(s) - module.svp_apply_inplace(&mut c0_dft, 0, sk_dft, 0); - - // c0_big = IDFT(c0_dft) - module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); + // c0_big += a[i] * s[i] + module.vec_znx_big_add_inplace(&mut c0_big, 0, &ci_big, 0); + }); } // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) diff --git a/core/src/test_fft64/glwe.rs b/core/src/test_fft64/glwe.rs index 48b6cb6..5ffe2dd 100644 --- a/core/src/test_fft64/glwe.rs +++ b/core/src/test_fft64/glwe.rs @@ -1,6 +1,6 @@ use base2k::{ Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, - ZnxViewMut, ZnxZero, + ZnxView, ZnxViewMut, ZnxZero, }; use itertools::izip; use sampling::source::Source; @@ -17,18 +17,26 @@ use crate::{ }; #[test] -fn encrypt_sk() { - let module: Module = Module::::new(32); - let log_base2k: usize = 8; - let log_k_ct: usize = 54; - let log_k_pt: usize = 30; - let rank: usize = 1; +fn encrypt_sk_rank_1() { + encrypt_sk(11, 8, 54, 30, 3.2, 1); +} - let sigma: f64 = 3.2; +#[test] +fn encrypt_sk_rank_2() { + encrypt_sk(5, 8, 54, 30, 3.2, 2); +} + +#[test] +fn encrypt_sk_rank_3() { + encrypt_sk(11, 8, 54, 30, 3.2, 3); +} + +fn encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma: f64, rank: usize) { + let module: Module = Module::::new(1 << log_n); let bound: f64 = sigma * 6.0; - let mut ct: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_ct, rank); - let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_pt); + let mut ct: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct, rank); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_pt); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -51,8 +59,7 @@ fn encrypt_sk() { .iter_mut() .for_each(|x| *x = source_xa.next_i64() & 0xFF); - pt.data - .encode_vec_i64(0, log_base2k, log_k_pt, &data_want, 10); + pt.data.encode_vec_i64(0, basek, k_pt, &data_want, 10); ct.encrypt_sk( &module, @@ -72,10 +79,10 @@ fn encrypt_sk() { let mut data_have: Vec = vec![0i64; module.n()]; pt.data - .decode_vec_i64(0, log_base2k, pt.size() * log_base2k, &mut data_have); + .decode_vec_i64(0, basek, pt.size() * basek, &mut data_have); // TODO: properly assert the decryption noise through std(dec(ct) - pt) - let scale: f64 = (1 << (pt.size() * log_base2k - log_k_pt)) as f64; + let scale: f64 = (1 << (pt.size() * basek - k_pt)) as f64; izip!(data_want.iter(), data_have.iter()).for_each(|(a, b)| { let b_scaled = (*b as f64) / scale; assert!( @@ -90,14 +97,14 @@ fn encrypt_sk() { #[test] fn encrypt_zero_sk() { let module: Module = Module::::new(1024); - let log_base2k: usize = 8; - let log_k_ct: usize = 55; + let basek: usize = 8; + let k_ct: usize = 55; let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([1u8; 32]); @@ -108,7 +115,7 @@ fn encrypt_zero_sk() { let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); - let mut ct_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_ct, rank); + let mut ct_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ct, rank); let mut scratch: ScratchOwned = ScratchOwned::new( GLWECiphertextFourier::decrypt_scratch_space(&module, ct_dft.size()) @@ -126,22 +133,22 @@ fn encrypt_zero_sk() { ); ct_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - assert!((sigma - pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2()) <= 0.2); + assert!((sigma - pt.data.std(0, basek) * (k_ct as f64).exp2()) <= 0.2); } #[test] fn encrypt_pk() { let module: Module = Module::::new(32); - let log_base2k: usize = 8; - let log_k_ct: usize = 54; + let basek: usize = 8; + let k_ct: usize = 54; let log_k_pk: usize = 64; let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_ct, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); + let mut ct: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -153,7 +160,7 @@ fn encrypt_pk() { let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); - let mut pk: GLWEPublicKey, FFT64> = GLWEPublicKey::new(&module, log_base2k, log_k_pk, rank); + let mut pk: GLWEPublicKey, FFT64> = GLWEPublicKey::new(&module, basek, log_k_pk, rank); pk.generate( &module, &sk_dft, @@ -175,9 +182,7 @@ fn encrypt_pk() { .iter_mut() .for_each(|x| *x = source_xa.next_i64() & 0); - pt_want - .data - .encode_vec_i64(0, log_base2k, log_k_ct, &data_want, 10); + pt_want.data.encode_vec_i64(0, basek, k_ct, &data_want, 10); ct.encrypt_pk( &module, @@ -190,34 +195,33 @@ fn encrypt_pk() { scratch.borrow(), ); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); module.vec_znx_sub_ab_inplace(&mut pt_want, 0, &pt_have, 0); - assert!(((1.0f64 / 12.0).sqrt() - pt_want.data.std(0, log_base2k) * (log_k_ct as f64).exp2()).abs() < 0.2); + assert!(((1.0f64 / 12.0).sqrt() - pt_want.data.std(0, basek) * (k_ct as f64).exp2()).abs() < 0.2); } #[test] fn keyswitch() { let module: Module = Module::::new(2048); - let log_base2k: usize = 12; + let basek: usize = 12; let log_k_grlwe: usize = 60; let log_k_rlwe_in: usize = 45; let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + let rows: usize = (log_k_rlwe_in + basek - 1) / basek; let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); - let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); - let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); + let mut ct_grlwe: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank); + let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, basek, log_k_rlwe_in, rank); + let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, basek, log_k_rlwe_out, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_rlwe_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_rlwe_out); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -226,7 +230,7 @@ fn keyswitch() { // Random input plaintext pt_want .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::new( GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) @@ -280,10 +284,10 @@ fn keyswitch() { module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + let noise_have: f64 = pt_have.data.std(0, basek).log2(); let noise_want: f64 = noise_grlwe_rlwe_product( module.n() as f64, - log_base2k, + basek, 0.5, 0.5, 0f64, @@ -304,20 +308,19 @@ fn keyswitch() { #[test] fn keyswich_inplace() { let module: Module = Module::::new(2048); - let log_base2k: usize = 12; + let basek: usize = 12; let log_k_grlwe: usize = 60; let log_k_rlwe: usize = 45; - let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k; + let rows: usize = (log_k_rlwe + basek - 1) / basek; let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); - let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe); + let mut ct_grlwe: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank); + let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, basek, log_k_rlwe, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_rlwe); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_rlwe); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -326,7 +329,7 @@ fn keyswich_inplace() { // Random input plaintext pt_want .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::new( GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) @@ -375,10 +378,10 @@ fn keyswich_inplace() { module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + let noise_have: f64 = pt_have.data.std(0, basek).log2(); let noise_want: f64 = noise_grlwe_rlwe_product( module.n() as f64, - log_base2k, + basek, 0.5, 0.5, 0f64, @@ -399,22 +402,22 @@ fn keyswich_inplace() { #[test] fn external_product() { let module: Module = Module::::new(2048); - let log_base2k: usize = 12; + let basek: usize = 12; let log_k_grlwe: usize = 60; let log_k_rlwe_in: usize = 45; let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + let rows: usize = (log_k_rlwe_in + basek - 1) / basek; let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); - let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); - let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out, rank); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, log_k_grlwe, rows, rank); + let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, basek, log_k_rlwe_in, rank); + let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, basek, log_k_rlwe_out, rank); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_rlwe_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_rlwe_out); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -423,7 +426,7 @@ fn external_product() { // Random input plaintext pt_want .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); pt_want.to_mut().at_mut(0, 0)[1] = 1; @@ -479,7 +482,7 @@ fn external_product() { module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + let noise_have: f64 = pt_have.data.std(0, basek).log2(); let var_gct_err_lhs: f64 = sigma * sigma; let var_gct_err_rhs: f64 = 0f64; @@ -490,7 +493,7 @@ fn external_product() { let noise_want: f64 = noise_rgsw_product( module.n() as f64, - log_base2k, + basek, 0.5, var_msg, var_a0_err, @@ -512,21 +515,21 @@ fn external_product() { #[test] fn external_product_inplace() { let module: Module = Module::::new(2048); - let log_base2k: usize = 12; + let basek: usize = 12; let log_k_grlwe: usize = 60; let log_k_rlwe_in: usize = 45; let log_k_rlwe_out: usize = 60; - let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + let rows: usize = (log_k_rlwe_in + basek - 1) / basek; let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); - let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, log_k_grlwe, rows, rank); + let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, basek, log_k_rlwe_in, rank); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_rlwe_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, log_k_rlwe_out); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -535,7 +538,7 @@ fn external_product_inplace() { // Random input plaintext pt_want .data - .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); pt_want.to_mut().at_mut(0, 0)[1] = 1; @@ -586,7 +589,7 @@ fn external_product_inplace() { module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); - let noise_have: f64 = pt_have.data.std(0, log_base2k).log2(); + let noise_have: f64 = pt_have.data.std(0, basek).log2(); let var_gct_err_lhs: f64 = sigma * sigma; let var_gct_err_rhs: f64 = 0f64; @@ -597,7 +600,7 @@ fn external_product_inplace() { let noise_want: f64 = noise_rgsw_product( module.n() as f64, - log_base2k, + basek, 0.5, var_msg, var_a0_err,