https://github.com/arnaucube/poulpy.git
Small optimization + more fixes
@@ -150,7 +150,7 @@ impl Scratch {
         unsafe { &mut *(data as *mut [u8] as *mut Self) }
     }
 
-    pub fn zero(&mut self){
+    pub fn zero(&mut self) {
         self.data.fill(0);
     }
 
@@ -660,7 +660,7 @@ mod tests {
 
         (0..*digits).for_each(|di| {
             (0..a_cols).for_each(|col_i| {
-                module.vec_znx_dft(digits - 1 - di, *digits, &mut a_dft, col_i, &a, col_i);
+                module.vec_znx_dft(*digits, digits - 1 - di, &mut a_dft, col_i, &a, col_i);
             });
 
             if di == 0 {
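Note on the test fix above: read together with the limb-selection code later in this commit (`let limb: usize = offset + j * step;`), the first two arguments of `vec_znx_dft` appear to act as the stride and the starting limb of the digit decomposition, so swapping them silently reads the wrong limbs even though the call still compiles. A minimal, self-contained sketch of that index arithmetic; the helper below is illustrative only and not part of the poulpy API:

```rust
// Illustrative only: models which source limbs a strided digit-decomposition
// pass visits, assuming a `limb = offset + j * step` indexing scheme.
fn selected_limbs(step: usize, offset: usize, total_limbs: usize) -> Vec<usize> {
    (0..)
        .map(|j| offset + j * step)
        .take_while(|&limb| limb < total_limbs)
        .collect()
}

fn main() {
    let (digits, di, total_limbs) = (3usize, 0usize, 6usize);
    // Argument order after the fix: stride first, then starting limb.
    let fixed = selected_limbs(digits, digits - 1 - di, total_limbs);
    // Argument order before the fix: the two values swap roles.
    let old = selected_limbs(digits - 1 - di, digits, total_limbs);
    println!("fixed -> {fixed:?}"); // [2, 5]
    println!("old   -> {old:?}");   // [3, 5] -- different limbs are read
}
```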
@@ -2,9 +2,7 @@ use std::marker::PhantomData;
 
 use crate::ffi::vec_znx_dft;
 use crate::znx_base::ZnxInfos;
-use crate::{
-    Backend, DataView, DataViewMut, FFT64, Module, VecZnxBig, ZnxSliceSize, ZnxView, alloc_aligned,
-};
+use crate::{Backend, DataView, DataViewMut, FFT64, Module, VecZnxBig, ZnxSliceSize, ZnxView, alloc_aligned};
 use std::fmt;
 
 pub struct VecZnxDft<D, B: Backend> {
@@ -62,11 +60,15 @@ impl<D: AsRef<[u8]>> ZnxView for VecZnxDft<D, FFT64> {
     type Scalar = f64;
 }
 
-impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnxDft<D, FFT64>{
-    pub fn set_size(&mut self, size: usize){
+impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnxDft<D, FFT64> {
+    pub fn set_size(&mut self, size: usize) {
         assert!(size <= self.data.as_ref().len() / (self.n * self.cols()));
         self.size = size
     }
+
+    pub fn max_size(&mut self) -> usize {
+        self.data.as_ref().len() / (self.n * self.cols)
+    }
 }
 
 pub(crate) fn bytes_of_vec_znx_dft<B: Backend>(module: &Module<B>, cols: usize, size: usize) -> usize {
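The new `max_size` helper exposes the same bound that `set_size` already asserts: the number of limbs a `VecZnxDft` can present is capped by `data.len() / (n * cols)`. A toy model of that size/capacity bookkeeping, using a stand-in struct rather than the real type:

```rust
// Toy model of the size/capacity invariant behind set_size()/max_size().
// Field names mimic the diff; this is not the actual poulpy type.
struct ToyVecDft {
    data: Vec<f64>, // flat backing buffer
    n: usize,       // entries per limb, per column
    cols: usize,    // number of columns
    size: usize,    // limbs currently exposed to callers
}

impl ToyVecDft {
    /// Largest `size` the buffer can back: limbs = len / (n * cols).
    fn max_size(&self) -> usize {
        self.data.len() / (self.n * self.cols)
    }

    /// Shrink or grow the exposed limb count, never past the capacity.
    fn set_size(&mut self, size: usize) {
        assert!(size <= self.max_size());
        self.size = size;
    }
}

fn main() {
    let mut v = ToyVecDft { data: vec![0.0; 4 * 2 * 6], n: 4, cols: 2, size: 6 };
    assert_eq!(v.max_size(), 6);
    v.set_size(3); // temporarily expose fewer limbs, e.g. per digit
    assert_eq!(v.size, 3);
    v.set_size(6); // restore the full view; 7 would panic
}
```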
@@ -163,10 +163,10 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
 
         (0..min_steps).for_each(|j| {
             let limb: usize = offset + j * step;
-            if limb < a_ref.size(){
+            if limb < a_ref.size() {
                 res_mut
                     .at_mut(res_col, j)
                     .copy_from_slice(a_ref.at(a_col, limb));
             }
         });
         (min_steps..res_mut.size()).for_each(|j| {
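The guard `if limb < a_ref.size()` is needed because the strided index `offset + j * step` can run past the last limb of `a`; the follow-up loop over `(min_steps..res_mut.size())` then covers the result limbs that received no source limb (its body is not shown in this hunk, presumably clearing them). A standalone sketch of that copy-then-pad pattern over plain slices, not the library routine itself:

```rust
// Illustrative copy-then-pad pattern: gather limbs `offset + j * step` from
// `src` into `dst`, and clear the destination limbs that have no source limb.
// Each "limb" is a contiguous chunk of values.
fn gather_limbs(dst: &mut [Vec<f64>], src: &[Vec<f64>], offset: usize, step: usize) {
    for (j, out) in dst.iter_mut().enumerate() {
        let limb = offset + j * step;
        if limb < src.len() {
            out.copy_from_slice(&src[limb]); // source limb exists: copy it
        } else {
            out.iter_mut().for_each(|x| *x = 0.0); // past the end: clear
        }
    }
}

fn main() {
    let src: Vec<Vec<f64>> = (0..5).map(|i| vec![i as f64; 2]).collect();
    let mut dst = vec![vec![9.0; 2]; 4];
    gather_limbs(&mut dst, &src, 1, 3); // picks limbs 1 and 4, clears the rest
    assert_eq!(dst, vec![vec![1.0; 2], vec![4.0; 2], vec![0.0; 2], vec![0.0; 2]]);
}
```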
@@ -348,7 +348,6 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
 
         let (mut tmp_dft_i, scratch1) = scratch.tmp_vec_znx_dft(module, cols, tsk.size());
         let (mut tmp_a, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, (ci_dft.size() + digits - 1) / digits);
-        let res_size: usize = res.to_mut().size();
 
         {
             // Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2
@@ -363,12 +362,10 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
             // =
             // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2)
             (1..cols).for_each(|col_i| {
-
                 let pmat: &MatZnxDft<DataTsk, FFT64> = &tsk.at(col_i - 1, col_j - 1).0.data; // Selects Enc(s[i]s[j])
 
                 // Extracts a[i] and multipies with Enc(s[i]s[j])
                 (0..digits).for_each(|di| {
-
                     tmp_a.set_size((ci_dft.size() + di) / digits);
 
                     // Small optimization for digits > 2
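On the sizes used here: `tmp_a` is allocated once with the ceiling `(ci_dft.size() + digits - 1) / digits` limbs, and `set_size((ci_dft.size() + di) / digits)` then trims it per digit. Since digit `di` reads limbs `digits - 1 - di, 2*digits - 1 - di, ...`, that quotient is exactly the number of limbs the digit touches, and the per-digit counts partition the source limbs. A small worked example of the arithmetic, independent of the library:

```rust
fn main() {
    let size = 7usize;   // limbs of ci_dft in this example
    let digits = 3usize; // decomposition digits

    // One-time allocation: enough limbs for the largest digit.
    let alloc = (size + digits - 1) / digits;
    println!("allocated limbs: {alloc}"); // 3

    // Per-digit trimmed sizes, as done with tmp_a.set_size(...).
    let per_digit: Vec<usize> = (0..digits).map(|di| (size + di) / digits).collect();
    println!("per-digit sizes: {:?}", per_digit); // [2, 2, 3]

    // Sanity check: the per-digit limb counts partition the source limbs,
    // and none exceeds the allocation.
    assert_eq!(per_digit.iter().sum::<usize>(), size);
    assert!(per_digit.iter().all(|&s| s <= alloc));
}
```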
@@ -378,7 +375,7 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
                     // It is possible to further ignore the last digits-1 limbs, but this introduce
                     // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
                     // noise is kept with respect to the ideal functionality.
-                    //tmp_dft_i.set_size(res_size - ((digits - di) as isize - 2).max(0) as usize);
+                    tmp_dft_i.set_size(tsk.size() - ((digits - di) as isize - 2).max(0) as usize);
 
                     module.vec_znx_dft_copy(digits, digits - 1 - di, &mut tmp_a, 0, ci_dft, col_i);
                     if di == 0 && col_i == 1 {
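Making the comment concrete: `((digits - di) as isize - 2).max(0)` is the number of trailing accumulator limbs skipped at digit `di`. Only the low digits skip anything, the last two digits are always kept in full, and nothing is skipped when `digits <= 2`, which is why the comment calls it an optimization for `digits > 2`. A quick table:

```rust
fn main() {
    // How many trailing limbs of the accumulator are skipped at each digit,
    // per the formula used in the diff: max(digits - di - 2, 0).
    for digits in [2usize, 3, 4] {
        let dropped: Vec<usize> = (0..digits)
            .map(|di| ((digits - di) as isize - 2).max(0) as usize)
            .collect();
        println!("digits = {digits}: dropped per di = {:?}", dropped);
        // digits = 2: [0, 0]      -> no optimization below 3 digits
        // digits = 3: [1, 0, 0]
        // digits = 4: [2, 1, 0, 0]
    }
}
```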
@@ -1,5 +1,7 @@
 use backend::{
-    AddNormal, Backend, FillUniform, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxAlloc, ScalarZnxDftAlloc, ScalarZnxDftOps, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxZero, FFT64
+    AddNormal, Backend, FFT64, FillUniform, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxAlloc, ScalarZnxDftAlloc,
+    ScalarZnxDftOps, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc,
+    VecZnxDftOps, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxZero,
 };
 use sampling::source::Source;
 
@@ -500,7 +502,6 @@ impl<DataSelf: AsRef<[u8]> + AsMut<[u8]>> GLWECiphertext<DataSelf> {
         ai_dft.zero();
         {
             (0..digits).for_each(|di| {
-
                 ai_dft.set_size((lhs.size() + di) / digits);
 
                 // Small optimization for digits > 2
@@ -510,7 +511,7 @@ impl<DataSelf: AsRef<[u8]> + AsMut<[u8]>> GLWECiphertext<DataSelf> {
                 // It is possible to further ignore the last digits-1 limbs, but this introduce
                 // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
                 // noise is kept with respect to the ideal functionality.
-                //res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize);
+                res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize);
 
                 (0..cols_in).for_each(|col_i| {
                     module.vec_znx_dft(
@@ -598,7 +599,7 @@ impl<DataSelf: AsRef<[u8]> + AsMut<[u8]>> GLWECiphertext<DataSelf> {
         let digits: usize = rhs.digits();
 
         let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); // Todo optimise
-        let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, (lhs.size() + digits-1) / digits);
+        let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, (lhs.size() + digits - 1) / digits);
 
         {
             (0..digits).for_each(|di| {
@@ -612,7 +613,7 @@ impl<DataSelf: AsRef<[u8]> + AsMut<[u8]>> GLWECiphertext<DataSelf> {
                 // It is possible to further ignore the last digits-1 limbs, but this introduce
                 // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
                 // noise is kept with respect to the ideal functionality.
-                //res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize);
+                res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize);
 
                 (0..cols).for_each(|col_i| {
                     module.vec_znx_dft(digits, digits - 1 - di, &mut a_dft, col_i, &lhs.data, col_i);
@@ -199,16 +199,22 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GLWECiphertextFourier<DataSelf, FFT64>
         let cols: usize = rhs.rank() + 1;
         let digits = rhs.digits();
 
 
         // Space for VMP result in DFT domain and high precision
         let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size());
         let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, (lhs.size() + digits - 1) / digits);
 
         {
             (0..digits).for_each(|di| {
-
                 a_dft.set_size((lhs.size() + di) / digits);
-                res_dft.set_size(rhs.size() - (digits - di - 1));
+                // Small optimization for digits > 2
+                // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
+                // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}.
+                // As such we can ignore the last digits-2 limbs safely of the sum of vmp products.
+                // It is possible to further ignore the last digits-1 limbs, but this introduce
+                // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
+                // noise is kept with respect to the ideal functionality.
+                res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize);
 
                 (0..cols).for_each(|col_i| {
                     module.vec_znx_dft_copy(digits, digits - 1 - di, &mut a_dft, col_i, &lhs.data, col_i);
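The behavioral change in this last hunk: the old call truncated `res_dft` by `digits - di - 1` limbs, one limb more than the new `max(digits - di - 2, 0)` at every digit except the last, which is exactly the more aggressive truncation the comment says would cost roughly 0.5 to 1 extra bit of noise. A side-by-side of the resulting sizes with example numbers (`rhs_size = 6` is illustrative only):

```rust
fn main() {
    let (rhs_size, digits) = (6usize, 3usize); // example parameters only

    for di in 0..digits {
        // Old truncation: always drop digits - di - 1 limbs.
        let old = rhs_size - (digits - di - 1);
        // New truncation: drop one limb fewer, clamped at zero.
        let new = rhs_size - ((digits - di) as isize - 2).max(0) as usize;
        println!("di = {di}: old size = {old}, new size = {new}");
    }
    // di = 0: old size = 4, new size = 5
    // di = 1: old size = 5, new size = 6
    // di = 2: old size = 6, new size = 6
}
```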