Small optimization + more fixes

This commit is contained in:
Jean-Philippe Bossuat
2025-06-11 14:31:32 +02:00
parent a673b84047
commit 655b22ef21
7 changed files with 33 additions and 27 deletions

View File

@@ -150,7 +150,7 @@ impl Scratch {
unsafe { &mut *(data as *mut [u8] as *mut Self) }
}
pub fn zero(&mut self){
pub fn zero(&mut self) {
self.data.fill(0);
}

View File

@@ -660,7 +660,7 @@ mod tests {
(0..*digits).for_each(|di| {
(0..a_cols).for_each(|col_i| {
module.vec_znx_dft(digits - 1 - di, *digits, &mut a_dft, col_i, &a, col_i);
module.vec_znx_dft(*digits, digits - 1 - di, &mut a_dft, col_i, &a, col_i);
});
if di == 0 {

View File

@@ -2,9 +2,7 @@ use std::marker::PhantomData;
use crate::ffi::vec_znx_dft;
use crate::znx_base::ZnxInfos;
use crate::{
Backend, DataView, DataViewMut, FFT64, Module, VecZnxBig, ZnxSliceSize, ZnxView, alloc_aligned,
};
use crate::{Backend, DataView, DataViewMut, FFT64, Module, VecZnxBig, ZnxSliceSize, ZnxView, alloc_aligned};
use std::fmt;
pub struct VecZnxDft<D, B: Backend> {
@@ -62,11 +60,15 @@ impl<D: AsRef<[u8]>> ZnxView for VecZnxDft<D, FFT64> {
type Scalar = f64;
}
impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnxDft<D, FFT64>{
pub fn set_size(&mut self, size: usize){
impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnxDft<D, FFT64> {
pub fn set_size(&mut self, size: usize) {
assert!(size <= self.data.as_ref().len() / (self.n * self.cols()));
self.size = size
}
pub fn max_size(&mut self) -> usize {
self.data.as_ref().len() / (self.n * self.cols)
}
}
pub(crate) fn bytes_of_vec_znx_dft<B: Backend>(module: &Module<B>, cols: usize, size: usize) -> usize {

View File

@@ -163,7 +163,7 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
(0..min_steps).for_each(|j| {
let limb: usize = offset + j * step;
if limb < a_ref.size(){
if limb < a_ref.size() {
res_mut
.at_mut(res_col, j)
.copy_from_slice(a_ref.at(a_col, limb));

View File

@@ -348,7 +348,6 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
let (mut tmp_dft_i, scratch1) = scratch.tmp_vec_znx_dft(module, cols, tsk.size());
let (mut tmp_a, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, (ci_dft.size() + digits - 1) / digits);
let res_size: usize = res.to_mut().size();
{
// Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2
@@ -363,12 +362,10 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
// =
// (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2)
(1..cols).for_each(|col_i| {
let pmat: &MatZnxDft<DataTsk, FFT64> = &tsk.at(col_i - 1, col_j - 1).0.data; // Selects Enc(s[i]s[j])
// Extracts a[i] and multipies with Enc(s[i]s[j])
(0..digits).for_each(|di| {
tmp_a.set_size((ci_dft.size() + di) / digits);
// Small optimization for digits > 2
@@ -378,7 +375,7 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
// It is possible to further ignore the last digits-1 limbs, but this introduce
// ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
// noise is kept with respect to the ideal functionality.
//tmp_dft_i.set_size(res_size - ((digits - di) as isize - 2).max(0) as usize);
tmp_dft_i.set_size(tsk.size() - ((digits - di) as isize - 2).max(0) as usize);
module.vec_znx_dft_copy(digits, digits - 1 - di, &mut tmp_a, 0, ci_dft, col_i);
if di == 0 && col_i == 1 {

View File

@@ -1,5 +1,7 @@
use backend::{
AddNormal, Backend, FillUniform, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxAlloc, ScalarZnxDftAlloc, ScalarZnxDftOps, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxZero, FFT64
AddNormal, Backend, FFT64, FillUniform, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxAlloc, ScalarZnxDftAlloc,
ScalarZnxDftOps, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc,
VecZnxDftOps, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxZero,
};
use sampling::source::Source;
@@ -500,7 +502,6 @@ impl<DataSelf: AsRef<[u8]> + AsMut<[u8]>> GLWECiphertext<DataSelf> {
ai_dft.zero();
{
(0..digits).for_each(|di| {
ai_dft.set_size((lhs.size() + di) / digits);
// Small optimization for digits > 2
@@ -510,7 +511,7 @@ impl<DataSelf: AsRef<[u8]> + AsMut<[u8]>> GLWECiphertext<DataSelf> {
// It is possible to further ignore the last digits-1 limbs, but this introduce
// ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
// noise is kept with respect to the ideal functionality.
//res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize);
res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize);
(0..cols_in).for_each(|col_i| {
module.vec_znx_dft(
@@ -598,7 +599,7 @@ impl<DataSelf: AsRef<[u8]> + AsMut<[u8]>> GLWECiphertext<DataSelf> {
let digits: usize = rhs.digits();
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); // Todo optimise
let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, (lhs.size() + digits-1) / digits);
let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, (lhs.size() + digits - 1) / digits);
{
(0..digits).for_each(|di| {
@@ -612,7 +613,7 @@ impl<DataSelf: AsRef<[u8]> + AsMut<[u8]>> GLWECiphertext<DataSelf> {
// It is possible to further ignore the last digits-1 limbs, but this introduce
// ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
// noise is kept with respect to the ideal functionality.
//res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize);
res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize);
(0..cols).for_each(|col_i| {
module.vec_znx_dft(digits, digits - 1 - di, &mut a_dft, col_i, &lhs.data, col_i);

View File

@@ -199,16 +199,22 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GLWECiphertextFourier<DataSelf, FFT64>
let cols: usize = rhs.rank() + 1;
let digits = rhs.digits();
// Space for VMP result in DFT domain and high precision
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size());
let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, (lhs.size() + digits - 1) / digits);
{
(0..digits).for_each(|di| {
a_dft.set_size((lhs.size() + di) / digits);
res_dft.set_size(rhs.size() - (digits - di - 1));
// Small optimization for digits > 2
// VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
// we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}.
// As such we can ignore the last digits-2 limbs safely of the sum of vmp products.
// It is possible to further ignore the last digits-1 limbs, but this introduce
// ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
// noise is kept with respect to the ideal functionality.
res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize);
(0..cols).for_each(|col_i| {
module.vec_znx_dft_copy(digits, digits - 1 - di, &mut a_dft, col_i, &lhs.data, col_i);