Add cross-basek normalization (#90)

* added cross_basek_normalization

* updated method signatures to take layouts

* fixed cross-base normalization

fix #91
fix #93
This commit is contained in:
Jean-Philippe Bossuat
2025-09-30 14:40:10 +02:00
committed by GitHub
parent 4da790ea6a
commit 37e13b965c
216 changed files with 12481 additions and 7745 deletions

View File

@@ -1,42 +1,41 @@
use poulpy_hal::{
api::{
ScratchAvailable, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume,
VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply,
VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
VmpApplyDftToDftTmpBytes,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch},
};
use crate::layouts::{GGLWEAutomorphismKey, GGLWESwitchingKey, prepared::GGSWCiphertextPrepared};
use crate::layouts::{GGLWEAutomorphismKey, GGLWELayoutInfos, GGLWESwitchingKey, GGSWInfos, prepared::GGSWCiphertextPrepared};
impl GGLWEAutomorphismKey<Vec<u8>> {
#[allow(clippy::too_many_arguments)]
pub fn external_product_scratch_space<B: Backend>(
pub fn external_product_scratch_space<B: Backend, OUT, IN, GGSW>(
module: &Module<B>,
basek: usize,
k_out: usize,
k_in: usize,
ggsw_k: usize,
digits: usize,
rank: usize,
out_infos: &OUT,
in_infos: &IN,
ggsw_infos: &GGSW,
) -> usize
where
OUT: GGLWELayoutInfos,
IN: GGLWELayoutInfos,
GGSW: GGSWInfos,
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
{
GGLWESwitchingKey::external_product_scratch_space(module, basek, k_out, k_in, ggsw_k, digits, rank)
GGLWESwitchingKey::external_product_scratch_space(module, out_infos, in_infos, ggsw_infos)
}
pub fn external_product_inplace_scratch_space<B: Backend>(
pub fn external_product_inplace_scratch_space<B: Backend, OUT, GGSW>(
module: &Module<B>,
basek: usize,
k_out: usize,
ggsw_k: usize,
digits: usize,
rank: usize,
out_infos: &OUT,
ggsw_infos: &GGSW,
) -> usize
where
OUT: GGLWELayoutInfos,
GGSW: GGSWInfos,
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
{
GGLWESwitchingKey::external_product_inplace_scratch_space(module, basek, k_out, ggsw_k, digits, rank)
GGLWESwitchingKey::external_product_inplace_scratch_space(module, out_infos, ggsw_infos)
}
}
@@ -55,8 +54,9 @@ impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
+ VecZnxBigNormalize<B>
+ VecZnxNormalize<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
{
self.key.external_product(module, &lhs.key, rhs, scratch);
}
@@ -74,8 +74,9 @@ impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
+ VecZnxBigNormalize<B>
+ VecZnxNormalize<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
{
self.key.external_product_inplace(module, rhs, scratch);
}

View File

@@ -1,42 +1,46 @@
use poulpy_hal::{
api::{
ScratchAvailable, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume,
VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply,
VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
VmpApplyDftToDftTmpBytes,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero},
};
use crate::layouts::{GGLWESwitchingKey, GLWECiphertext, Infos, prepared::GGSWCiphertextPrepared};
use crate::layouts::{GGLWELayoutInfos, GGLWESwitchingKey, GGSWInfos, GLWECiphertext, prepared::GGSWCiphertextPrepared};
impl GGLWESwitchingKey<Vec<u8>> {
#[allow(clippy::too_many_arguments)]
pub fn external_product_scratch_space<B: Backend>(
pub fn external_product_scratch_space<B: Backend, OUT, IN, GGSW>(
module: &Module<B>,
basek: usize,
k_out: usize,
k_in: usize,
k_ggsw: usize,
digits: usize,
rank: usize,
out_infos: &OUT,
in_infos: &IN,
ggsw_infos: &GGSW,
) -> usize
where
OUT: GGLWELayoutInfos,
IN: GGLWELayoutInfos,
GGSW: GGSWInfos,
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
{
GLWECiphertext::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank)
GLWECiphertext::external_product_scratch_space(
module,
&out_infos.glwe_layout(),
&in_infos.glwe_layout(),
ggsw_infos,
)
}
pub fn external_product_inplace_scratch_space<B: Backend>(
pub fn external_product_inplace_scratch_space<B: Backend, OUT, GGSW>(
module: &Module<B>,
basek: usize,
k_out: usize,
k_ggsw: usize,
digits: usize,
rank: usize,
out_infos: &OUT,
ggsw_infos: &GGSW,
) -> usize
where
OUT: GGLWELayoutInfos,
GGSW: GGSWInfos,
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
{
GLWECiphertext::external_product_inplace_scratch_space(module, basek, k_out, k_ggsw, digits, rank)
GLWECiphertext::external_product_inplace_scratch_space(module, &out_infos.glwe_layout(), ggsw_infos)
}
}
@@ -55,11 +59,14 @@ impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
+ VecZnxBigNormalize<B>
+ VecZnxNormalize<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
{
#[cfg(debug_assertions)]
{
use crate::layouts::GLWEInfos;
assert_eq!(
self.rank_in(),
lhs.rank_in(),
@@ -83,15 +90,15 @@ impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
);
}
(0..self.rank_in()).for_each(|col_i| {
(0..self.rows()).for_each(|row_j| {
(0..self.rank_in().into()).for_each(|col_i| {
(0..self.rows().into()).for_each(|row_j| {
self.at_mut(row_j, col_i)
.external_product(module, &lhs.at(row_j, col_i), rhs, scratch);
});
});
(self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
(0..self.rank_in()).for_each(|col_j| {
(self.rows().min(lhs.rows()).into()..self.rows().into()).for_each(|row_i| {
(0..self.rank_in().into()).for_each(|col_j| {
self.at_mut(row_i, col_j).data.zero();
});
});
@@ -110,11 +117,14 @@ impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
+ VecZnxBigNormalize<B>
+ VecZnxNormalize<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
{
#[cfg(debug_assertions)]
{
use crate::layouts::GLWEInfos;
assert_eq!(
self.rank_out(),
rhs.rank(),
@@ -124,8 +134,8 @@ impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
);
}
(0..self.rank_in()).for_each(|col_i| {
(0..self.rows()).for_each(|row_j| {
(0..self.rank_in().into()).for_each(|col_i| {
(0..self.rows().into()).for_each(|row_j| {
self.at_mut(row_j, col_i)
.external_product_inplace(module, rhs, scratch);
});

View File

@@ -1,42 +1,47 @@
use poulpy_hal::{
api::{
ScratchAvailable, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume,
VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply,
VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
VmpApplyDftToDftTmpBytes,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero},
};
use crate::layouts::{GGSWCiphertext, GLWECiphertext, Infos, prepared::GGSWCiphertextPrepared};
use crate::layouts::{GGSWCiphertext, GGSWInfos, GLWECiphertext, GLWEInfos, prepared::GGSWCiphertextPrepared};
impl GGSWCiphertext<Vec<u8>> {
#[allow(clippy::too_many_arguments)]
pub fn external_product_scratch_space<B: Backend>(
pub fn external_product_scratch_space<B: Backend, OUT, IN, GGSW>(
module: &Module<B>,
basek: usize,
k_out: usize,
k_in: usize,
k_ggsw: usize,
digits: usize,
rank: usize,
out_infos: &OUT,
in_infos: &IN,
apply_infos: &GGSW,
) -> usize
where
OUT: GGSWInfos,
IN: GGSWInfos,
GGSW: GGSWInfos,
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
{
GLWECiphertext::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank)
GLWECiphertext::external_product_scratch_space(
module,
&out_infos.glwe_layout(),
&in_infos.glwe_layout(),
apply_infos,
)
}
pub fn external_product_inplace_scratch_space<B: Backend>(
pub fn external_product_inplace_scratch_space<B: Backend, OUT, GGSW>(
module: &Module<B>,
basek: usize,
k_out: usize,
k_ggsw: usize,
digits: usize,
rank: usize,
out_infos: &OUT,
apply_infos: &GGSW,
) -> usize
where
OUT: GGSWInfos,
GGSW: GGSWInfos,
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
{
GLWECiphertext::external_product_inplace_scratch_space(module, basek, k_out, k_ggsw, digits, rank)
GLWECiphertext::external_product_inplace_scratch_space(module, &out_infos.glwe_layout(), apply_infos)
}
}
@@ -55,12 +60,13 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
+ VecZnxBigNormalize<B>
+ VecZnxNormalize<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
{
#[cfg(debug_assertions)]
{
use crate::layouts::Infos;
use crate::layouts::LWEInfos;
assert_eq!(lhs.n(), self.n());
assert_eq!(rhs.n(), self.n());
@@ -80,28 +86,17 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
rhs.rank()
);
assert!(
scratch.available()
>= GGSWCiphertext::external_product_scratch_space(
module,
self.basek(),
self.k(),
lhs.k(),
rhs.k(),
rhs.digits(),
rhs.rank()
)
)
assert!(scratch.available() >= GGSWCiphertext::external_product_scratch_space(module, self, lhs, rhs))
}
let min_rows: usize = self.rows().min(lhs.rows());
let min_rows: usize = self.rows().min(lhs.rows()).into();
(0..self.rank() + 1).for_each(|col_i| {
(0..(self.rank() + 1).into()).for_each(|col_i| {
(0..min_rows).for_each(|row_j| {
self.at_mut(row_j, col_i)
.external_product(module, &lhs.at(row_j, col_i), rhs, scratch);
});
(min_rows..self.rows()).for_each(|row_i| {
(min_rows..self.rows().into()).for_each(|row_i| {
self.at_mut(row_i, col_i).data.zero();
});
});
@@ -120,11 +115,14 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
+ VecZnxBigNormalize<B>
+ VecZnxNormalize<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
{
#[cfg(debug_assertions)]
{
use crate::layouts::LWEInfos;
assert_eq!(rhs.n(), self.n());
assert_eq!(
self.rank(),
@@ -135,8 +133,8 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
);
}
(0..self.rank() + 1).for_each(|col_i| {
(0..self.rows()).for_each(|row_j| {
(0..(self.rank() + 1).into()).for_each(|col_i| {
(0..self.rows().into()).for_each(|row_j| {
self.at_mut(row_j, col_i)
.external_product_inplace(module, rhs, scratch);
});

View File

@@ -1,56 +1,65 @@
use poulpy_hal::{
api::{
ScratchAvailable, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume,
VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply,
VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
VmpApplyDftToDftTmpBytes,
},
layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch, VecZnxBig},
layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch, VecZnx, VecZnxBig},
};
use crate::layouts::{GLWECiphertext, Infos, prepared::GGSWCiphertextPrepared};
use crate::layouts::{GGSWInfos, GLWECiphertext, GLWEInfos, LWEInfos, prepared::GGSWCiphertextPrepared};
impl GLWECiphertext<Vec<u8>> {
#[allow(clippy::too_many_arguments)]
pub fn external_product_scratch_space<B: Backend>(
pub fn external_product_scratch_space<B: Backend, OUT, IN, GGSW>(
module: &Module<B>,
basek: usize,
k_out: usize,
k_in: usize,
k_ggsw: usize,
digits: usize,
rank: usize,
out_infos: &OUT,
in_infos: &IN,
apply_infos: &GGSW,
) -> usize
where
OUT: GLWEInfos,
IN: GLWEInfos,
GGSW: GGSWInfos,
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
{
let in_size: usize = k_in.div_ceil(basek).div_ceil(digits);
let out_size: usize = k_out.div_ceil(basek);
let ggsw_size: usize = k_ggsw.div_ceil(basek);
let res_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, ggsw_size);
let a_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, in_size);
let in_size: usize = in_infos
.k()
.div_ceil(apply_infos.base2k())
.div_ceil(apply_infos.digits().into()) as usize;
let out_size: usize = out_infos.size();
let ggsw_size: usize = apply_infos.size();
let res_dft: usize = module.vec_znx_dft_alloc_bytes((apply_infos.rank() + 1).into(), ggsw_size);
let a_dft: usize = module.vec_znx_dft_alloc_bytes((apply_infos.rank() + 1).into(), in_size);
let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes(
out_size,
in_size,
in_size, // rows
rank + 1, // cols in
rank + 1, // cols out
in_size, // rows
(apply_infos.rank() + 1).into(), // cols in
(apply_infos.rank() + 1).into(), // cols out
ggsw_size,
);
let normalize: usize = module.vec_znx_normalize_tmp_bytes();
res_dft + a_dft + (vmp | normalize)
let normalize_big: usize = module.vec_znx_normalize_tmp_bytes();
if in_infos.base2k() == apply_infos.base2k() {
res_dft + a_dft + (vmp | normalize_big)
} else {
let normalize_conv: usize = VecZnx::alloc_bytes(module.n(), (apply_infos.rank() + 1).into(), in_size);
res_dft + ((a_dft + normalize_conv + (module.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big)
}
}
pub fn external_product_inplace_scratch_space<B: Backend>(
pub fn external_product_inplace_scratch_space<B: Backend, OUT, GGSW>(
module: &Module<B>,
basek: usize,
k_out: usize,
k_ggsw: usize,
digits: usize,
rank: usize,
out_infos: &OUT,
apply_infos: &GGSW,
) -> usize
where
OUT: GLWEInfos,
GGSW: GGSWInfos,
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
{
Self::external_product_scratch_space(module, basek, k_out, k_out, k_ggsw, digits, rank)
Self::external_product_scratch_space(module, out_infos, out_infos, apply_infos)
}
}
@@ -69,10 +78,13 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
+ VecZnxBigNormalize<B>
+ VecZnxNormalize<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
{
let basek: usize = self.basek();
let basek_in: usize = lhs.base2k().into();
let basek_ggsw: usize = rhs.base2k().into();
let basek_out: usize = self.base2k().into();
#[cfg(debug_assertions)]
{
@@ -80,34 +92,22 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
assert_eq!(rhs.rank(), lhs.rank());
assert_eq!(rhs.rank(), self.rank());
assert_eq!(self.basek(), basek);
assert_eq!(lhs.basek(), basek);
assert_eq!(rhs.n(), self.n());
assert_eq!(lhs.n(), self.n());
assert!(
scratch.available()
>= GLWECiphertext::external_product_scratch_space(
module,
self.basek(),
self.k(),
lhs.k(),
rhs.k(),
rhs.digits(),
rhs.rank(),
)
);
assert!(scratch.available() >= GLWECiphertext::external_product_scratch_space(module, self, lhs, rhs));
}
let cols: usize = rhs.rank() + 1;
let digits: usize = rhs.digits();
let cols: usize = (rhs.rank() + 1).into();
let digits: usize = rhs.digits().into();
let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), cols, rhs.size()); // Todo optimise
let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n(), cols, lhs.size().div_ceil(digits));
let a_size: usize = (lhs.size() * basek_in).div_ceil(basek_ggsw);
let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), cols, rhs.size()); // Todo optimise
let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n().into(), cols, a_size.div_ceil(digits));
a_dft.data_mut().fill(0);
{
(0..digits).for_each(|di| {
if basek_in == basek_ggsw {
for di in 0..digits {
// (lhs.size() + di) / digits = (a - (digit - di - 1)).div_ceil(digits)
a_dft.set_size((lhs.size() + di) / digits);
@@ -120,22 +120,68 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
// noise is kept with respect to the ideal functionality.
res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize);
(0..cols).for_each(|col_i| {
module.vec_znx_dft_apply(digits, digits - 1 - di, &mut a_dft, col_i, &lhs.data, col_i);
});
for j in 0..cols {
module.vec_znx_dft_apply(digits, digits - 1 - di, &mut a_dft, j, &lhs.data, j);
}
if di == 0 {
module.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2);
} else {
module.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2);
}
});
}
} else {
let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(module.n(), cols, a_size);
for j in 0..cols {
module.vec_znx_normalize(
basek_ggsw,
&mut a_conv,
j,
basek_in,
&lhs.data,
j,
scratch_3,
);
}
for di in 0..digits {
// (a_size + di) / digits = (a_size - (digits - di - 1)).div_ceil(digits)
a_dft.set_size((a_size + di) / digits);
// Small optimization for digits > 2
// VMP produces some error e, and since we aggregate vmp * 2^{di * B},
// we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}.
// As such we can safely ignore the last digits-2 limbs of the sum of vmp products.
// It is possible to further ignore the last digits-1 limbs, but this introduces
// ~0.5 to 1 bit of additional noise, and is thus not done here, to ensure that the same
// noise is kept with respect to the ideal functionality.
res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize);
for j in 0..cols {
module.vec_znx_dft_apply(digits, digits - 1 - di, &mut a_dft, j, &a_conv, j);
}
if di == 0 {
module.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_3);
} else {
module.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_3);
}
}
}
let res_big: VecZnxBig<&mut [u8], B> = module.vec_znx_idft_apply_consume(res_dft);
(0..cols).for_each(|i| {
module.vec_znx_big_normalize(basek, &mut self.data, i, &res_big, i, scratch_1);
module.vec_znx_big_normalize(
basek_out,
&mut self.data,
i,
basek_ggsw,
&res_big,
i,
scratch_1,
);
});
}
@@ -152,42 +198,32 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
+ VecZnxBigNormalize<B>
+ VecZnxNormalize<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
{
let basek: usize = self.basek();
let basek_in: usize = self.base2k().into();
let basek_ggsw: usize = rhs.base2k().into();
#[cfg(debug_assertions)]
{
use poulpy_hal::api::ScratchAvailable;
assert_eq!(rhs.rank(), self.rank());
assert_eq!(self.basek(), basek);
assert_eq!(rhs.n(), self.n());
assert!(
scratch.available()
>= GLWECiphertext::external_product_scratch_space(
module,
self.basek(),
self.k(),
self.k(),
rhs.k(),
rhs.digits(),
rhs.rank(),
)
);
assert!(scratch.available() >= GLWECiphertext::external_product_inplace_scratch_space(module, self, rhs,));
}
let cols: usize = rhs.rank() + 1;
let digits: usize = rhs.digits();
let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), cols, rhs.size()); // Todo optimise
let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n(), cols, self.size().div_ceil(digits));
let cols: usize = (rhs.rank() + 1).into();
let digits: usize = rhs.digits().into();
let a_size: usize = (self.size() * basek_in).div_ceil(basek_ggsw);
let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), cols, rhs.size()); // Todo optimise
let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n().into(), cols, a_size.div_ceil(digits));
a_dft.data_mut().fill(0);
{
(0..digits).for_each(|di| {
if basek_in == basek_ggsw {
for di in 0..digits {
// (self.size() + di) / digits = (self.size() - (digits - di - 1)).div_ceil(digits)
a_dft.set_size((self.size() + di) / digits);
@@ -200,29 +236,68 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
// noise is kept with respect to the ideal functionality.
res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize);
(0..cols).for_each(|col_i| {
module.vec_znx_dft_apply(
digits,
digits - 1 - di,
&mut a_dft,
col_i,
&self.data,
col_i,
);
});
for j in 0..cols {
module.vec_znx_dft_apply(digits, digits - 1 - di, &mut a_dft, j, &self.data, j);
}
if di == 0 {
module.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2);
} else {
module.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2);
}
});
}
} else {
// Cross-base path: `self.data` uses base2k `basek_in` while `rhs` uses `basek_ggsw`.
// The input is first re-normalized into the temporary `a_conv` (of `a_size` limbs),
// and every subsequent step must read from `a_conv`, not from `self.data`.
let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(module.n(), cols, a_size);
for j in 0..cols {
    module.vec_znx_normalize(
        basek_ggsw,
        &mut a_conv,
        j,
        basek_in,
        &self.data,
        j,
        scratch_3,
    );
}
for di in 0..digits {
    // (a_size + di) / digits = (a_size - (digits - di - 1)).div_ceil(digits)
    // `a_dft` was taken with a_size.div_ceil(digits) limbs, so its working size is
    // derived from `a_size` (the size of `a_conv`), not from self.size().
    a_dft.set_size((a_size + di) / digits);
    // Small optimization for digits > 2
    // VMP produces some error e, and since we aggregate vmp * 2^{di * B},
    // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}.
    // As such we can safely ignore the last digits-2 limbs of the sum of vmp products.
    // It is possible to further ignore the last digits-1 limbs, but this introduces
    // ~0.5 to 1 bit of additional noise, and is thus not done here, to ensure that the same
    // noise is kept with respect to the ideal functionality.
    res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize);
    for j in 0..cols {
        // Source is the re-normalized copy, matching the non-inplace cross-base path.
        module.vec_znx_dft_apply(digits, digits - 1 - di, &mut a_dft, j, &a_conv, j);
    }
    // `a_conv` is still in use, so the VMP gets the scratch left after it (`scratch_3`).
    if di == 0 {
        module.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_3);
    } else {
        module.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_3);
    }
}
}
let res_big: VecZnxBig<&mut [u8], B> = module.vec_znx_idft_apply_consume(res_dft);
(0..cols).for_each(|i| {
module.vec_znx_big_normalize(basek, &mut self.data, i, &res_big, i, scratch_1);
});
for j in 0..cols {
module.vec_znx_big_normalize(
basek_in,
&mut self.data,
j,
basek_ggsw,
&res_big,
j,
scratch_1,
);
}
}
}