From 655b22ef21718e8500431d90098128c751ef5fbd Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 11 Jun 2025 14:31:32 +0200 Subject: [PATCH] Small optimization + more fixes --- backend/src/lib.rs | 2 +- backend/src/mat_znx_dft_ops.rs | 2 +- backend/src/vec_znx_dft.rs | 12 +++++++----- backend/src/vec_znx_dft_ops.rs | 6 +++--- core/src/ggsw_ciphertext.rs | 9 +++------ core/src/glwe_ciphertext.rs | 15 ++++++++------- core/src/glwe_ciphertext_fourier.rs | 14 ++++++++++---- 7 files changed, 33 insertions(+), 27 deletions(-) diff --git a/backend/src/lib.rs b/backend/src/lib.rs index 49b8420..dcf4325 100644 --- a/backend/src/lib.rs +++ b/backend/src/lib.rs @@ -150,7 +150,7 @@ impl Scratch { unsafe { &mut *(data as *mut [u8] as *mut Self) } } - pub fn zero(&mut self){ + pub fn zero(&mut self) { self.data.fill(0); } diff --git a/backend/src/mat_znx_dft_ops.rs b/backend/src/mat_znx_dft_ops.rs index d831f73..9656dfb 100644 --- a/backend/src/mat_znx_dft_ops.rs +++ b/backend/src/mat_znx_dft_ops.rs @@ -660,7 +660,7 @@ mod tests { (0..*digits).for_each(|di| { (0..a_cols).for_each(|col_i| { - module.vec_znx_dft(digits - 1 - di, *digits, &mut a_dft, col_i, &a, col_i); + module.vec_znx_dft(*digits, digits - 1 - di, &mut a_dft, col_i, &a, col_i); }); if di == 0 { diff --git a/backend/src/vec_znx_dft.rs b/backend/src/vec_znx_dft.rs index e8bf782..d5c0ad5 100644 --- a/backend/src/vec_znx_dft.rs +++ b/backend/src/vec_znx_dft.rs @@ -2,9 +2,7 @@ use std::marker::PhantomData; use crate::ffi::vec_znx_dft; use crate::znx_base::ZnxInfos; -use crate::{ - Backend, DataView, DataViewMut, FFT64, Module, VecZnxBig, ZnxSliceSize, ZnxView, alloc_aligned, -}; +use crate::{Backend, DataView, DataViewMut, FFT64, Module, VecZnxBig, ZnxSliceSize, ZnxView, alloc_aligned}; use std::fmt; pub struct VecZnxDft { @@ -62,11 +60,15 @@ impl> ZnxView for VecZnxDft { type Scalar = f64; } -impl + AsRef<[u8]>> VecZnxDft{ - pub fn set_size(&mut self, size: usize){ +impl + AsRef<[u8]>> VecZnxDft { + pub fn set_size(&mut self, size: usize) { assert!(size <= self.data.as_ref().len() / (self.n * self.cols())); self.size = size } + + pub fn max_size(&mut self) -> usize { + self.data.as_ref().len() / (self.n * self.cols) + } } pub(crate) fn bytes_of_vec_znx_dft(module: &Module, cols: usize, size: usize) -> usize { diff --git a/backend/src/vec_znx_dft_ops.rs b/backend/src/vec_znx_dft_ops.rs index 6bfa9fb..963de18 100644 --- a/backend/src/vec_znx_dft_ops.rs +++ b/backend/src/vec_znx_dft_ops.rs @@ -163,10 +163,10 @@ impl VecZnxDftOps for Module { (0..min_steps).for_each(|j| { let limb: usize = offset + j * step; - if limb < a_ref.size(){ + if limb < a_ref.size() { res_mut - .at_mut(res_col, j) - .copy_from_slice(a_ref.at(a_col, limb)); + .at_mut(res_col, j) + .copy_from_slice(a_ref.at(a_col, limb)); } }); (min_steps..res_mut.size()).for_each(|j| { diff --git a/core/src/ggsw_ciphertext.rs b/core/src/ggsw_ciphertext.rs index 7b58771..8a1aa70 100644 --- a/core/src/ggsw_ciphertext.rs +++ b/core/src/ggsw_ciphertext.rs @@ -348,7 +348,6 @@ impl + AsRef<[u8]>> GGSWCiphertext { let (mut tmp_dft_i, scratch1) = scratch.tmp_vec_znx_dft(module, cols, tsk.size()); let (mut tmp_a, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, (ci_dft.size() + digits - 1) / digits); - let res_size: usize = res.to_mut().size(); { // Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2 @@ -363,23 +362,21 @@ impl + AsRef<[u8]>> GGSWCiphertext { // = // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2) (1..cols).for_each(|col_i| { - let pmat: &MatZnxDft = &tsk.at(col_i - 1, col_j - 1).0.data; // Selects Enc(s[i]s[j]) // Extracts a[i] and multipies with Enc(s[i]s[j]) (0..digits).for_each(|di| { - tmp_a.set_size((ci_dft.size() + di) / digits); // Small optimization for digits > 2 // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. // As such we can ignore the last digits-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last digits-1 limbs, but this introduce + // It is possible to further ignore the last digits-1 limbs, but this introduce // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same // noise is kept with respect to the ideal functionality. - //tmp_dft_i.set_size(res_size - ((digits - di) as isize - 2).max(0) as usize); - + tmp_dft_i.set_size(tsk.size() - ((digits - di) as isize - 2).max(0) as usize); + module.vec_znx_dft_copy(digits, digits - 1 - di, &mut tmp_a, 0, ci_dft, col_i); if di == 0 && col_i == 1 { module.vmp_apply(&mut tmp_dft_i, &tmp_a, pmat, scratch2); diff --git a/core/src/glwe_ciphertext.rs b/core/src/glwe_ciphertext.rs index dc8b5f0..868de43 100644 --- a/core/src/glwe_ciphertext.rs +++ b/core/src/glwe_ciphertext.rs @@ -1,5 +1,7 @@ use backend::{ - AddNormal, Backend, FillUniform, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxAlloc, ScalarZnxDftAlloc, ScalarZnxDftOps, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxZero, FFT64 + AddNormal, Backend, FFT64, FillUniform, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxAlloc, ScalarZnxDftAlloc, + ScalarZnxDftOps, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, + VecZnxDftOps, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxZero, }; use sampling::source::Source; @@ -500,17 +502,16 @@ impl + AsMut<[u8]>> GLWECiphertext { ai_dft.zero(); { (0..digits).for_each(|di| { - ai_dft.set_size((lhs.size() + di) / digits); // Small optimization for digits > 2 // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. // As such we can ignore the last digits-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last digits-1 limbs, but this introduce + // It is possible to further ignore the last digits-1 limbs, but this introduce // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same // noise is kept with respect to the ideal functionality. - //res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize); + res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize); (0..cols_in).for_each(|col_i| { module.vec_znx_dft( @@ -598,7 +599,7 @@ impl + AsMut<[u8]>> GLWECiphertext { let digits: usize = rhs.digits(); let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); // Todo optimise - let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, (lhs.size() + digits-1) / digits); + let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, (lhs.size() + digits - 1) / digits); { (0..digits).for_each(|di| { @@ -609,10 +610,10 @@ impl + AsMut<[u8]>> GLWECiphertext { // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. // As such we can ignore the last digits-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last digits-1 limbs, but this introduce + // It is possible to further ignore the last digits-1 limbs, but this introduce // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same // noise is kept with respect to the ideal functionality. - //res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize); + res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize); (0..cols).for_each(|col_i| { module.vec_znx_dft(digits, digits - 1 - di, &mut a_dft, col_i, &lhs.data, col_i); diff --git a/core/src/glwe_ciphertext_fourier.rs b/core/src/glwe_ciphertext_fourier.rs index c3cdc40..cfe2381 100644 --- a/core/src/glwe_ciphertext_fourier.rs +++ b/core/src/glwe_ciphertext_fourier.rs @@ -199,17 +199,23 @@ impl + AsRef<[u8]>> GLWECiphertextFourier let cols: usize = rhs.rank() + 1; let digits = rhs.digits(); - // Space for VMP result in DFT domain and high precision let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, (lhs.size() + digits - 1) / digits); { (0..digits).for_each(|di| { - a_dft.set_size((lhs.size() + di) / digits); - res_dft.set_size(rhs.size() - (digits - di - 1)); - + + // Small optimization for digits > 2 + // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then + // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. + // As such we can ignore the last digits-2 limbs safely of the sum of vmp products. + // It is possible to further ignore the last digits-1 limbs, but this introduce + // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same + // noise is kept with respect to the ideal functionality. + res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize); + (0..cols).for_each(|col_i| { module.vec_znx_dft_copy(digits, digits - 1 - di, &mut a_dft, col_i, &lhs.data, col_i); });