Small optimization + more fixes

This commit is contained in:
Jean-Philippe Bossuat
2025-06-11 14:31:32 +02:00
parent a673b84047
commit 655b22ef21
7 changed files with 33 additions and 27 deletions

View File

@@ -150,7 +150,7 @@ impl Scratch {
unsafe { &mut *(data as *mut [u8] as *mut Self) } unsafe { &mut *(data as *mut [u8] as *mut Self) }
} }
pub fn zero(&mut self){ pub fn zero(&mut self) {
self.data.fill(0); self.data.fill(0);
} }

View File

@@ -660,7 +660,7 @@ mod tests {
(0..*digits).for_each(|di| { (0..*digits).for_each(|di| {
(0..a_cols).for_each(|col_i| { (0..a_cols).for_each(|col_i| {
module.vec_znx_dft(digits - 1 - di, *digits, &mut a_dft, col_i, &a, col_i); module.vec_znx_dft(*digits, digits - 1 - di, &mut a_dft, col_i, &a, col_i);
}); });
if di == 0 { if di == 0 {

View File

@@ -2,9 +2,7 @@ use std::marker::PhantomData;
use crate::ffi::vec_znx_dft; use crate::ffi::vec_znx_dft;
use crate::znx_base::ZnxInfos; use crate::znx_base::ZnxInfos;
use crate::{ use crate::{Backend, DataView, DataViewMut, FFT64, Module, VecZnxBig, ZnxSliceSize, ZnxView, alloc_aligned};
Backend, DataView, DataViewMut, FFT64, Module, VecZnxBig, ZnxSliceSize, ZnxView, alloc_aligned,
};
use std::fmt; use std::fmt;
pub struct VecZnxDft<D, B: Backend> { pub struct VecZnxDft<D, B: Backend> {
@@ -62,11 +60,15 @@ impl<D: AsRef<[u8]>> ZnxView for VecZnxDft<D, FFT64> {
type Scalar = f64; type Scalar = f64;
} }
impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnxDft<D, FFT64>{ impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnxDft<D, FFT64> {
pub fn set_size(&mut self, size: usize){ pub fn set_size(&mut self, size: usize) {
assert!(size <= self.data.as_ref().len() / (self.n * self.cols())); assert!(size <= self.data.as_ref().len() / (self.n * self.cols()));
self.size = size self.size = size
} }
pub fn max_size(&mut self) -> usize {
self.data.as_ref().len() / (self.n * self.cols)
}
} }
pub(crate) fn bytes_of_vec_znx_dft<B: Backend>(module: &Module<B>, cols: usize, size: usize) -> usize { pub(crate) fn bytes_of_vec_znx_dft<B: Backend>(module: &Module<B>, cols: usize, size: usize) -> usize {

View File

@@ -163,10 +163,10 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
(0..min_steps).for_each(|j| { (0..min_steps).for_each(|j| {
let limb: usize = offset + j * step; let limb: usize = offset + j * step;
if limb < a_ref.size(){ if limb < a_ref.size() {
res_mut res_mut
.at_mut(res_col, j) .at_mut(res_col, j)
.copy_from_slice(a_ref.at(a_col, limb)); .copy_from_slice(a_ref.at(a_col, limb));
} }
}); });
(min_steps..res_mut.size()).for_each(|j| { (min_steps..res_mut.size()).for_each(|j| {

View File

@@ -348,7 +348,6 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
let (mut tmp_dft_i, scratch1) = scratch.tmp_vec_znx_dft(module, cols, tsk.size()); let (mut tmp_dft_i, scratch1) = scratch.tmp_vec_znx_dft(module, cols, tsk.size());
let (mut tmp_a, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, (ci_dft.size() + digits - 1) / digits); let (mut tmp_a, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, (ci_dft.size() + digits - 1) / digits);
let res_size: usize = res.to_mut().size();
{ {
// Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2 // Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2
@@ -363,23 +362,21 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
// = // =
// (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2) // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2)
(1..cols).for_each(|col_i| { (1..cols).for_each(|col_i| {
let pmat: &MatZnxDft<DataTsk, FFT64> = &tsk.at(col_i - 1, col_j - 1).0.data; // Selects Enc(s[i]s[j]) let pmat: &MatZnxDft<DataTsk, FFT64> = &tsk.at(col_i - 1, col_j - 1).0.data; // Selects Enc(s[i]s[j])
// Extracts a[i] and multipies with Enc(s[i]s[j]) // Extracts a[i] and multipies with Enc(s[i]s[j])
(0..digits).for_each(|di| { (0..digits).for_each(|di| {
tmp_a.set_size((ci_dft.size() + di) / digits); tmp_a.set_size((ci_dft.size() + di) / digits);
// Small optimization for digits > 2 // Small optimization for digits > 2
// VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
// we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}.
// As such we can ignore the last digits-2 limbs safely of the sum of vmp products. // As such we can ignore the last digits-2 limbs safely of the sum of vmp products.
// It is possible to further ignore the last digits-1 limbs, but this introduce // It is possible to further ignore the last digits-1 limbs, but this introduce
// ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
// noise is kept with respect to the ideal functionality. // noise is kept with respect to the ideal functionality.
//tmp_dft_i.set_size(res_size - ((digits - di) as isize - 2).max(0) as usize); tmp_dft_i.set_size(tsk.size() - ((digits - di) as isize - 2).max(0) as usize);
module.vec_znx_dft_copy(digits, digits - 1 - di, &mut tmp_a, 0, ci_dft, col_i); module.vec_znx_dft_copy(digits, digits - 1 - di, &mut tmp_a, 0, ci_dft, col_i);
if di == 0 && col_i == 1 { if di == 0 && col_i == 1 {
module.vmp_apply(&mut tmp_dft_i, &tmp_a, pmat, scratch2); module.vmp_apply(&mut tmp_dft_i, &tmp_a, pmat, scratch2);

View File

@@ -1,5 +1,7 @@
use backend::{ use backend::{
AddNormal, Backend, FillUniform, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxAlloc, ScalarZnxDftAlloc, ScalarZnxDftOps, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxZero, FFT64 AddNormal, Backend, FFT64, FillUniform, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxAlloc, ScalarZnxDftAlloc,
ScalarZnxDftOps, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc,
VecZnxDftOps, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxZero,
}; };
use sampling::source::Source; use sampling::source::Source;
@@ -500,17 +502,16 @@ impl<DataSelf: AsRef<[u8]> + AsMut<[u8]>> GLWECiphertext<DataSelf> {
ai_dft.zero(); ai_dft.zero();
{ {
(0..digits).for_each(|di| { (0..digits).for_each(|di| {
ai_dft.set_size((lhs.size() + di) / digits); ai_dft.set_size((lhs.size() + di) / digits);
// Small optimization for digits > 2 // Small optimization for digits > 2
// VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
// we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}.
// As such we can ignore the last digits-2 limbs safely of the sum of vmp products. // As such we can ignore the last digits-2 limbs safely of the sum of vmp products.
// It is possible to further ignore the last digits-1 limbs, but this introduce // It is possible to further ignore the last digits-1 limbs, but this introduce
// ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
// noise is kept with respect to the ideal functionality. // noise is kept with respect to the ideal functionality.
//res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize); res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize);
(0..cols_in).for_each(|col_i| { (0..cols_in).for_each(|col_i| {
module.vec_znx_dft( module.vec_znx_dft(
@@ -598,7 +599,7 @@ impl<DataSelf: AsRef<[u8]> + AsMut<[u8]>> GLWECiphertext<DataSelf> {
let digits: usize = rhs.digits(); let digits: usize = rhs.digits();
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); // Todo optimise let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); // Todo optimise
let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, (lhs.size() + digits-1) / digits); let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, (lhs.size() + digits - 1) / digits);
{ {
(0..digits).for_each(|di| { (0..digits).for_each(|di| {
@@ -609,10 +610,10 @@ impl<DataSelf: AsRef<[u8]> + AsMut<[u8]>> GLWECiphertext<DataSelf> {
// VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
// we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}.
// As such we can ignore the last digits-2 limbs safely of the sum of vmp products. // As such we can ignore the last digits-2 limbs safely of the sum of vmp products.
// It is possible to further ignore the last digits-1 limbs, but this introduce // It is possible to further ignore the last digits-1 limbs, but this introduce
// ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
// noise is kept with respect to the ideal functionality. // noise is kept with respect to the ideal functionality.
//res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize); res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize);
(0..cols).for_each(|col_i| { (0..cols).for_each(|col_i| {
module.vec_znx_dft(digits, digits - 1 - di, &mut a_dft, col_i, &lhs.data, col_i); module.vec_znx_dft(digits, digits - 1 - di, &mut a_dft, col_i, &lhs.data, col_i);

View File

@@ -199,17 +199,23 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GLWECiphertextFourier<DataSelf, FFT64>
let cols: usize = rhs.rank() + 1; let cols: usize = rhs.rank() + 1;
let digits = rhs.digits(); let digits = rhs.digits();
// Space for VMP result in DFT domain and high precision // Space for VMP result in DFT domain and high precision
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size());
let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, (lhs.size() + digits - 1) / digits); let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, (lhs.size() + digits - 1) / digits);
{ {
(0..digits).for_each(|di| { (0..digits).for_each(|di| {
a_dft.set_size((lhs.size() + di) / digits); a_dft.set_size((lhs.size() + di) / digits);
res_dft.set_size(rhs.size() - (digits - di - 1));
// Small optimization for digits > 2
// VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
// we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}.
// As such we can ignore the last digits-2 limbs safely of the sum of vmp products.
// It is possible to further ignore the last digits-1 limbs, but this introduce
// ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
// noise is kept with respect to the ideal functionality.
res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize);
(0..cols).for_each(|col_i| { (0..cols).for_each(|col_i| {
module.vec_znx_dft_copy(digits, digits - 1 - di, &mut a_dft, col_i, &lhs.data, col_i); module.vec_znx_dft_copy(digits, digits - 1 - di, &mut a_dft, col_i, &lhs.data, col_i);
}); });