diff --git a/backend/spqlios-arithmetic b/backend/spqlios-arithmetic index 6ac426c..173b980 160000 --- a/backend/spqlios-arithmetic +++ b/backend/spqlios-arithmetic @@ -1 +1 @@ -Subproject commit 6ac426c9006f123738d2fcaa6f5bb1c33cd890e3 +Subproject commit 173b980c7b8a4f0523d04c2aed061c2e046e846c diff --git a/backend/src/lib.rs b/backend/src/lib.rs index 9e71c8d..49b8420 100644 --- a/backend/src/lib.rs +++ b/backend/src/lib.rs @@ -150,6 +150,10 @@ impl Scratch { unsafe { &mut *(data as *mut [u8] as *mut Self) } } + pub fn zero(&mut self){ + self.data.fill(0); + } + pub fn available(&self) -> usize { let ptr: *const u8 = self.data.as_ptr(); let self_len: usize = self.data.len(); diff --git a/backend/src/vec_znx_dft.rs b/backend/src/vec_znx_dft.rs index 82b2cf4..e8bf782 100644 --- a/backend/src/vec_znx_dft.rs +++ b/backend/src/vec_znx_dft.rs @@ -62,6 +62,13 @@ impl> ZnxView for VecZnxDft { type Scalar = f64; } +impl + AsRef<[u8]>> VecZnxDft{ + pub fn set_size(&mut self, size: usize){ + assert!(size <= self.data.as_ref().len() / (self.n * self.cols())); + self.size = size + } +} + pub(crate) fn bytes_of_vec_znx_dft(module: &Module, cols: usize, size: usize) -> usize { unsafe { vec_znx_dft::bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols } } diff --git a/core/benches/keyswitch_glwe_fft64.rs b/core/benches/keyswitch_glwe_fft64.rs index 69651fd..51d8049 100644 --- a/core/benches/keyswitch_glwe_fft64.rs +++ b/core/benches/keyswitch_glwe_fft64.rs @@ -1,8 +1,8 @@ use backend::{FFT64, Module, ScratchOwned}; -use core::{GLWECiphertext, GLWESecret, GLWESwitchingKey, Infos}; +use core::{AutomorphismKey, GLWECiphertext, GLWESecret, GLWESwitchingKey, Infos}; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; use sampling::source::Source; -use std::hint::black_box; +use std::{hint::black_box, time::Duration}; fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { let mut group = c.benchmark_group("keyswitch_glwe_fft64"); @@ -13,6 +13,7 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { k_ct_in: usize, k_ct_out: usize, k_ksk: usize, + digits: usize, rank_in: usize, rank_out: usize, } @@ -26,13 +27,12 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { let k_grlwe: usize = p.k_ksk; let rank_in: usize = p.rank_in; let rank_out: usize = p.rank_out; - let digits: usize = 1; + let digits: usize = p.digits; - let rows: usize = (p.k_ct_in + p.basek - 1) / p.basek; + let rows: usize = (p.k_ct_in + (p.basek * digits) - 1) / (p.basek * digits); let sigma: f64 = 3.2; - let mut ksk: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::alloc(&module, basek, k_grlwe, rows, digits, rank_in, rank_out); + let mut ksk: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_grlwe, rows, digits, rank_out); let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_rlwe_in, rank_in); let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_rlwe_out, rank_out); @@ -63,8 +63,8 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { ksk.generate_from_sk( &module, + -1, &sk_in, - &sk_out, &mut source_xa, &mut source_xe, sigma, @@ -81,21 +81,20 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { ); move || { - ct_out.keyswitch( - black_box(&module), - black_box(&ct_in), - black_box(&ksk), - black_box(scratch.borrow()), - ); + ct_out.automorphism(&module, &ct_in, &ksk, scratch.borrow()); } } + let digits: usize = 1; + let basek: usize = 19; + let params_set: Vec = vec![Params { - log_n: 16, - basek: 50, - k_ct_in: 1250, - k_ct_out: 1250, - k_ksk: 1250 + 66, + log_n: 15, + basek: basek, + k_ct_in: 874 - digits * basek, + k_ct_out: 874 - digits * basek, + k_ksk: 874, + digits: digits, rank_in: 1, rank_out: 1, }]; @@ -103,6 +102,8 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { for params in params_set { let id = BenchmarkId::new("KEYSWITCH_GLWE_FFT64", ""); let mut runner = runner(params); + group.sample_size(500); + group.measurement_time(Duration::from_secs(40)); group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); } diff --git a/core/src/ggsw_ciphertext.rs b/core/src/ggsw_ciphertext.rs index b15f6fc..7b58771 100644 --- a/core/src/ggsw_ciphertext.rs +++ b/core/src/ggsw_ciphertext.rs @@ -344,7 +344,12 @@ impl + AsRef<[u8]>> GGSWCiphertext { // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M[i], c2 ) // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M[i]) + let digits: usize = tsk.digits(); + let (mut tmp_dft_i, scratch1) = scratch.tmp_vec_znx_dft(module, cols, tsk.size()); + let (mut tmp_a, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, (ci_dft.size() + digits - 1) / digits); + let res_size: usize = res.to_mut().size(); + { // Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2 // @@ -358,28 +363,30 @@ impl + AsRef<[u8]>> GGSWCiphertext { // = // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2) (1..cols).for_each(|col_i| { - let digits: usize = tsk.digits(); - + let pmat: &MatZnxDft = &tsk.at(col_i - 1, col_j - 1).0.data; // Selects Enc(s[i]s[j]) // Extracts a[i] and multipies with Enc(s[i]s[j]) - if col_i == 1 { - (0..digits).for_each(|di| { - let (mut tmp_a, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, (ci_dft.size() + di) / digits); - module.vec_znx_dft_copy(digits, digits - 1 - di, &mut tmp_a, 0, ci_dft, col_i); - if di == 0 { - module.vmp_apply(&mut tmp_dft_i, &tmp_a, pmat, scratch2); - } else { - module.vmp_apply_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch2); - } - }); - } else { - (0..digits).for_each(|di| { - let (mut tmp_a, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, (ci_dft.size() + di) / digits); - module.vec_znx_dft_copy(digits, digits - 1 - di, &mut tmp_a, 0, ci_dft, col_i); + (0..digits).for_each(|di| { + + tmp_a.set_size((ci_dft.size() + di) / digits); + + // Small optimization for digits > 2 + // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then + // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. + // As such we can ignore the last digits-2 limbs safely of the sum of vmp products. + // It is possible to further ignore the last digits-1 limbs, but this introduce + // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same + // noise is kept with respect to the ideal functionality. + //tmp_dft_i.set_size(res_size - ((digits - di) as isize - 2).max(0) as usize); + + module.vec_znx_dft_copy(digits, digits - 1 - di, &mut tmp_a, 0, ci_dft, col_i); + if di == 0 && col_i == 1 { + module.vmp_apply(&mut tmp_dft_i, &tmp_a, pmat, scratch2); + } else { module.vmp_apply_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch2); - }); - } + } + }); }); } diff --git a/core/src/glwe_ciphertext.rs b/core/src/glwe_ciphertext.rs index 38e639e..dc8b5f0 100644 --- a/core/src/glwe_ciphertext.rs +++ b/core/src/glwe_ciphertext.rs @@ -1,7 +1,5 @@ use backend::{ - AddNormal, Backend, FFT64, FillUniform, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxAlloc, ScalarZnxDftAlloc, - ScalarZnxDftOps, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, - VecZnxDftOps, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxZero, + AddNormal, Backend, FillUniform, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxAlloc, ScalarZnxDftAlloc, ScalarZnxDftOps, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxZero, FFT64 }; use sampling::source::Source; @@ -94,14 +92,15 @@ impl GLWECiphertext> { rank_in: usize, rank_out: usize, ) -> usize { - let res_dft: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank_out); + let res_dft: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank_out + 1); let in_size: usize = div_ceil(div_ceil(k_in, basek), digits); let out_size: usize = div_ceil(k_out, basek); let ksk_size: usize = div_ceil(k_ksk, basek); + let ai_dft: usize = module.bytes_of_vec_znx_dft(rank_in, in_size); let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, rank_in, rank_out + 1, ksk_size) + module.bytes_of_vec_znx_dft(rank_in, in_size); let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); - return res_dft + (vmp | normalize); + return res_dft + ((ai_dft + vmp) | normalize); } pub fn keyswitch_from_fourier_scratch_space( @@ -494,20 +493,29 @@ impl + AsMut<[u8]>> GLWECiphertext { let cols_in: usize = rhs.rank_in(); let cols_out: usize = rhs.rank_out() + 1; + let digits: usize = rhs.digits(); let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise - + let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, (lhs.size() + digits - 1) / digits); + ai_dft.zero(); { - let digits = rhs.digits(); - (0..digits).for_each(|di| { - // (lhs.size() + di) / digits = (a - (digit - di - 1) + digit - 1) / digits - let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, (lhs.size() + di) / digits); + + ai_dft.set_size((lhs.size() + di) / digits); + + // Small optimization for digits > 2 + // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then + // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. + // As such we can ignore the last digits-2 limbs safely of the sum of vmp products. + // It is possible to further ignore the last digits-1 limbs, but this introduce + // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same + // noise is kept with respect to the ideal functionality. + //res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize); (0..cols_in).for_each(|col_i| { module.vec_znx_dft( digits, - digits - 1 - di, + digits - di - 1, &mut ai_dft, col_i, &lhs.data, @@ -587,15 +595,24 @@ impl + AsMut<[u8]>> GLWECiphertext { } let cols: usize = rhs.rank() + 1; + let digits: usize = rhs.digits(); let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); // Todo optimise + let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, (lhs.size() + digits-1) / digits); { - let digits = rhs.digits(); - (0..digits).for_each(|di| { // (lhs.size() + di) / digits = (a - (digit - di - 1) + digit - 1) / digits - let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, (lhs.size() + di) / digits); + a_dft.set_size((lhs.size() + di) / digits); + + // Small optimization for digits > 2 + // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then + // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. + // As such we can ignore the last digits-2 limbs safely of the sum of vmp products. + // It is possible to further ignore the last digits-1 limbs, but this introduce + // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same + // noise is kept with respect to the ideal functionality. + //res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize); (0..cols).for_each(|col_i| { module.vec_znx_dft(digits, digits - 1 - di, &mut a_dft, col_i, &lhs.data, col_i); diff --git a/core/src/glwe_ciphertext_fourier.rs b/core/src/glwe_ciphertext_fourier.rs index c132a4a..c3cdc40 100644 --- a/core/src/glwe_ciphertext_fourier.rs +++ b/core/src/glwe_ciphertext_fourier.rs @@ -96,7 +96,7 @@ impl GLWECiphertextFourier, FFT64> { pub fn external_product_scratch_space( module: &Module, basek: usize, - k_out: usize, + _k_out: usize, k_in: usize, k_ggsw: usize, digits: usize, @@ -197,17 +197,19 @@ impl + AsRef<[u8]>> GLWECiphertextFourier } let cols: usize = rhs.rank() + 1; + let digits = rhs.digits(); + // Space for VMP result in DFT domain and high precision let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); + let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, (lhs.size() + digits - 1) / digits); { - let digits = rhs.digits(); - (0..digits).for_each(|di| { - // (lhs.size() + di) / digits = (a - (digit - di - 1) + digit - 1) / digits - let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, (lhs.size() + di) / digits); + a_dft.set_size((lhs.size() + di) / digits); + res_dft.set_size(rhs.size() - (digits - di - 1)); + (0..cols).for_each(|col_i| { module.vec_znx_dft_copy(digits, digits - 1 - di, &mut a_dft, col_i, &lhs.data, col_i); }); diff --git a/core/src/test_fft64/glwe.rs b/core/src/test_fft64/glwe.rs index bc1ca9e..ac68bff 100644 --- a/core/src/test_fft64/glwe.rs +++ b/core/src/test_fft64/glwe.rs @@ -557,7 +557,7 @@ fn test_automorphism( ); assert!( - (noise_have - noise_want).abs() <= 0.5, + noise_have <= noise_want + 1.0, "{} {}", noise_have, noise_want