mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
Distinguish between gglwe_to_ggsw key and tensor_key + update key repreentation
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
ModuleN, ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
|
||||
VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft,
|
||||
VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
|
||||
VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes,
|
||||
VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, ZnxInfos},
|
||||
layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VecZnxDftToRef, VmpPMat, ZnxInfos},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
@@ -45,46 +45,10 @@ impl<D: DataMut> GLWE<D> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<BE: Backend> GLWEKeyswitch<BE> for Module<BE> where
|
||||
Self: Sized
|
||||
+ ModuleN
|
||||
+ VecZnxDftBytesOf
|
||||
+ VmpApplyDftToDftTmpBytes
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ VecZnxDftBytesOf
|
||||
+ VmpApplyDftToDftTmpBytes
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<BE>
|
||||
+ VmpApplyDftToDftAdd<BE>
|
||||
+ VecZnxDftApply<BE>
|
||||
+ VecZnxIdftApplyConsume<BE>
|
||||
+ VecZnxBigAddSmallInplace<BE>
|
||||
+ VecZnxBigNormalize<BE>
|
||||
+ VecZnxNormalize<BE>
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
{
|
||||
}
|
||||
|
||||
pub trait GLWEKeyswitch<BE: Backend>
|
||||
impl<BE: Backend> GLWEKeyswitch<BE> for Module<BE>
|
||||
where
|
||||
Self: Sized
|
||||
+ ModuleN
|
||||
+ VecZnxDftBytesOf
|
||||
+ VmpApplyDftToDftTmpBytes
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ VecZnxDftBytesOf
|
||||
+ VmpApplyDftToDftTmpBytes
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<BE>
|
||||
+ VmpApplyDftToDftAdd<BE>
|
||||
+ VecZnxDftApply<BE>
|
||||
+ VecZnxIdftApplyConsume<BE>
|
||||
+ VecZnxBigAddSmallInplace<BE>
|
||||
+ VecZnxBigNormalize<BE>
|
||||
+ VecZnxNormalize<BE>
|
||||
+ VecZnxNormalizeTmpBytes,
|
||||
Self: Sized + GLWEKeySwitchInternal<BE> + VecZnxBigNormalizeTmpBytes + VecZnxBigNormalize<BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
fn glwe_keyswitch_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, key_infos: &B) -> usize
|
||||
where
|
||||
@@ -92,34 +56,10 @@ where
|
||||
A: GLWEInfos,
|
||||
B: GGLWEInfos,
|
||||
{
|
||||
let in_size: usize = a_infos
|
||||
.k()
|
||||
.div_ceil(key_infos.base2k())
|
||||
.div_ceil(key_infos.dsize().into()) as usize;
|
||||
let out_size: usize = res_infos.size();
|
||||
let ksk_size: usize = key_infos.size();
|
||||
let res_dft: usize = self.bytes_of_vec_znx_dft((key_infos.rank_out() + 1).into(), ksk_size); // TODO OPTIMIZE
|
||||
let ai_dft: usize = self.bytes_of_vec_znx_dft((key_infos.rank_in()).into(), in_size);
|
||||
let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes(
|
||||
out_size,
|
||||
in_size,
|
||||
in_size,
|
||||
(key_infos.rank_in()).into(),
|
||||
(key_infos.rank_out() + 1).into(),
|
||||
ksk_size,
|
||||
) + self.bytes_of_vec_znx_dft((key_infos.rank_in()).into(), in_size);
|
||||
let normalize_big: usize = self.vec_znx_big_normalize_tmp_bytes();
|
||||
if a_infos.base2k() == key_infos.base2k() {
|
||||
res_dft + ((ai_dft + vmp) | normalize_big)
|
||||
} else if key_infos.dsize() == 1 {
|
||||
// In this case, we only need one column, temporary, that we can drop once a_dft is computed.
|
||||
let normalize_conv: usize = VecZnx::bytes_of(self.n(), 1, in_size) + self.vec_znx_normalize_tmp_bytes();
|
||||
res_dft + (((ai_dft + normalize_conv) | vmp) | normalize_big)
|
||||
} else {
|
||||
// Since we stride over a to get a_dft when dsize > 1, we need to store the full columns of a with in the base conversion.
|
||||
let normalize_conv: usize = VecZnx::bytes_of(self.n(), (key_infos.rank_in()).into(), in_size);
|
||||
res_dft + ((ai_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big)
|
||||
}
|
||||
let cols: usize = res_infos.rank().as_usize() + 1;
|
||||
self.glwe_keyswitch_internal_tmp_bytes(res_infos, a_infos, key_infos)
|
||||
.max(self.vec_znx_big_normalize_tmp_bytes())
|
||||
+ self.bytes_of_vec_znx_dft(cols, key_infos.size())
|
||||
}
|
||||
|
||||
fn glwe_keyswitch<R, A, K>(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch<BE>)
|
||||
@@ -127,7 +67,6 @@ where
|
||||
R: GLWEToMut,
|
||||
A: GLWEToRef,
|
||||
K: GGLWEPreparedToRef<BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
|
||||
let a: &GLWE<&[u8]> = &a.to_ref();
|
||||
@@ -164,8 +103,8 @@ where
|
||||
let base2k_out: usize = b.base2k().into();
|
||||
|
||||
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), b.size()); // Todo optimise
|
||||
let res_big: VecZnxBig<&mut [u8], BE> = keyswitch_internal(self, res_dft, a, b, scratch_1);
|
||||
(0..(res.rank() + 1).into()).for_each(|i| {
|
||||
let res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, a, b, scratch_1);
|
||||
for i in 0..(res.rank() + 1).into() {
|
||||
self.vec_znx_big_normalize(
|
||||
basek_out,
|
||||
&mut res.data,
|
||||
@@ -175,37 +114,36 @@ where
|
||||
i,
|
||||
scratch_1,
|
||||
);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn glwe_keyswitch_inplace<R, K>(&self, res: &mut R, key: &K, scratch: &mut Scratch<BE>)
|
||||
where
|
||||
R: GLWEToMut,
|
||||
K: GGLWEPreparedToRef<BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
|
||||
let a: &GGLWEPrepared<&[u8], BE> = &key.to_ref();
|
||||
let key: &GGLWEPrepared<&[u8], BE> = &key.to_ref();
|
||||
|
||||
assert_eq!(
|
||||
res.rank(),
|
||||
a.rank_in(),
|
||||
key.rank_in(),
|
||||
"res.rank(): {} != a.rank_in(): {}",
|
||||
res.rank(),
|
||||
a.rank_in()
|
||||
key.rank_in()
|
||||
);
|
||||
assert_eq!(
|
||||
res.rank(),
|
||||
a.rank_out(),
|
||||
key.rank_out(),
|
||||
"res.rank(): {} != b.rank_out(): {}",
|
||||
res.rank(),
|
||||
a.rank_out()
|
||||
key.rank_out()
|
||||
);
|
||||
|
||||
assert_eq!(res.n(), self.n() as u32);
|
||||
assert_eq!(a.n(), self.n() as u32);
|
||||
assert_eq!(key.n(), self.n() as u32);
|
||||
|
||||
let scrach_needed: usize = self.glwe_keyswitch_tmp_bytes(res, res, a);
|
||||
let scrach_needed: usize = self.glwe_keyswitch_tmp_bytes(res, res, key);
|
||||
|
||||
assert!(
|
||||
scratch.available() >= scrach_needed,
|
||||
@@ -214,11 +152,11 @@ where
|
||||
);
|
||||
|
||||
let base2k_in: usize = res.base2k().into();
|
||||
let base2k_out: usize = a.base2k().into();
|
||||
let base2k_out: usize = key.base2k().into();
|
||||
|
||||
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), a.size()); // Todo optimise
|
||||
let res_big: VecZnxBig<&mut [u8], BE> = keyswitch_internal(self, res_dft, res, a, scratch_1);
|
||||
(0..(res.rank() + 1).into()).for_each(|i| {
|
||||
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // Todo optimise
|
||||
let res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1);
|
||||
for i in 0..(res.rank() + 1).into() {
|
||||
self.vec_znx_big_normalize(
|
||||
base2k_in,
|
||||
&mut res.data,
|
||||
@@ -228,143 +166,235 @@ where
|
||||
i,
|
||||
scratch_1,
|
||||
);
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl GLWE<Vec<u8>> {}
|
||||
pub trait GLWEKeyswitch<BE: Backend> {
|
||||
fn glwe_keyswitch_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, key_infos: &B) -> usize
|
||||
where
|
||||
R: GLWEInfos,
|
||||
A: GLWEInfos,
|
||||
B: GGLWEInfos;
|
||||
|
||||
impl<DataSelf: DataMut> GLWE<DataSelf> {}
|
||||
fn glwe_keyswitch<R, A, K>(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch<BE>)
|
||||
where
|
||||
R: GLWEToMut,
|
||||
A: GLWEToRef,
|
||||
K: GGLWEPreparedToRef<BE>;
|
||||
|
||||
pub(crate) fn keyswitch_internal<BE: Backend, M, DR, A, K>(
|
||||
module: &M,
|
||||
mut res: VecZnxDft<DR, BE>,
|
||||
a: &A,
|
||||
key: &K,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) -> VecZnxBig<DR, BE>
|
||||
where
|
||||
DR: DataMut,
|
||||
A: GLWEToRef,
|
||||
K: GGLWEPreparedToRef<BE>,
|
||||
M: ModuleN
|
||||
+ VecZnxDftBytesOf
|
||||
+ VmpApplyDftToDftTmpBytes
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDftTmpBytes
|
||||
+ VmpApplyDftToDft<BE>
|
||||
+ VmpApplyDftToDftAdd<BE>
|
||||
fn glwe_keyswitch_inplace<R, K>(&self, res: &mut R, key: &K, scratch: &mut Scratch<BE>)
|
||||
where
|
||||
R: GLWEToMut,
|
||||
K: GGLWEPreparedToRef<BE>;
|
||||
}
|
||||
|
||||
impl<BE: Backend> GLWEKeySwitchInternal<BE> for Module<BE> where
|
||||
Self: GGLWEProduct<BE>
|
||||
+ VecZnxDftApply<BE>
|
||||
+ VecZnxNormalize<BE>
|
||||
+ VecZnxIdftApplyConsume<BE>
|
||||
+ VecZnxBigAddSmallInplace<BE>
|
||||
+ VecZnxBigNormalize<BE>
|
||||
+ VecZnxNormalize<BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
{
|
||||
let a: &GLWE<&[u8]> = &a.to_ref();
|
||||
let key: &GGLWEPrepared<&[u8], BE> = &key.to_ref();
|
||||
}
|
||||
|
||||
let base2k_in: usize = a.base2k().into();
|
||||
let base2k_out: usize = key.base2k().into();
|
||||
let cols: usize = (a.rank() + 1).into();
|
||||
let a_size: usize = (a.size() * base2k_in).div_ceil(base2k_out);
|
||||
let pmat: &VmpPMat<&[u8], BE> = &key.data;
|
||||
pub(crate) trait GLWEKeySwitchInternal<BE: Backend>
|
||||
where
|
||||
Self: GGLWEProduct<BE>
|
||||
+ VecZnxDftApply<BE>
|
||||
+ VecZnxNormalize<BE>
|
||||
+ VecZnxIdftApplyConsume<BE>
|
||||
+ VecZnxBigAddSmallInplace<BE>
|
||||
+ VecZnxNormalizeTmpBytes,
|
||||
{
|
||||
fn glwe_keyswitch_internal_tmp_bytes<R, A, K>(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize
|
||||
where
|
||||
R: GLWEInfos,
|
||||
A: GLWEInfos,
|
||||
K: GGLWEInfos,
|
||||
{
|
||||
let cols: usize = (a_infos.rank() + 1).into();
|
||||
let a_size: usize = a_infos.size();
|
||||
|
||||
if key.dsize() == 1 {
|
||||
let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols - 1, a.size());
|
||||
let a_conv = if a_infos.base2k() == key_infos.base2k() {
|
||||
0
|
||||
} else {
|
||||
VecZnx::bytes_of(self.n(), 1, a_size) + self.vec_znx_normalize_tmp_bytes()
|
||||
};
|
||||
|
||||
self.gglwe_product_dft_tmp_bytes(res_infos.size(), a_size, key_infos) + self.bytes_of_vec_znx_dft(cols, a_size) + a_conv
|
||||
}
|
||||
|
||||
fn glwe_keyswitch_internal<DR, A, K>(
|
||||
&self,
|
||||
mut res: VecZnxDft<DR, BE>,
|
||||
a: &A,
|
||||
key: &K,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) -> VecZnxBig<DR, BE>
|
||||
where
|
||||
DR: DataMut,
|
||||
A: GLWEToRef,
|
||||
K: GGLWEPreparedToRef<BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
let a: &GLWE<&[u8]> = &a.to_ref();
|
||||
let key: &GGLWEPrepared<&[u8], BE> = &key.to_ref();
|
||||
|
||||
let base2k_in: usize = a.base2k().into();
|
||||
let base2k_out: usize = key.base2k().into();
|
||||
let cols: usize = (a.rank() + 1).into();
|
||||
let a_size: usize = (a.size() * base2k_in).div_ceil(base2k_out);
|
||||
|
||||
let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols - 1, a_size);
|
||||
|
||||
if base2k_in == base2k_out {
|
||||
(0..cols - 1).for_each(|col_i| {
|
||||
module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, a.data(), col_i + 1);
|
||||
});
|
||||
for col_i in 0..cols - 1 {
|
||||
self.vec_znx_dft_apply(1, 0, &mut a_dft, col_i, a.data(), col_i + 1);
|
||||
}
|
||||
} else {
|
||||
let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(module.n(), 1, a_size);
|
||||
(0..cols - 1).for_each(|col_i| {
|
||||
module.vec_znx_normalize(
|
||||
let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, a_size);
|
||||
for i in 0..cols - 1 {
|
||||
self.vec_znx_normalize(
|
||||
base2k_out,
|
||||
&mut a_conv,
|
||||
0,
|
||||
base2k_in,
|
||||
a.data(),
|
||||
col_i + 1,
|
||||
i + 1,
|
||||
scratch_2,
|
||||
);
|
||||
module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, &a_conv, 0);
|
||||
});
|
||||
self.vec_znx_dft_apply(1, 0, &mut a_dft, i, &a_conv, 0);
|
||||
}
|
||||
}
|
||||
|
||||
module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_1);
|
||||
} else {
|
||||
let dsize: usize = key.dsize().into();
|
||||
self.gglwe_product_dft(&mut res, &a_dft, key, scratch_1);
|
||||
|
||||
let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols - 1, a_size.div_ceil(dsize));
|
||||
ai_dft.data_mut().fill(0);
|
||||
let mut res_big: VecZnxBig<DR, BE> = self.vec_znx_idft_apply_consume(res);
|
||||
self.vec_znx_big_add_small_inplace(&mut res_big, 0, a.data(), 0);
|
||||
res_big
|
||||
}
|
||||
}
|
||||
|
||||
if base2k_in == base2k_out {
|
||||
for di in 0..dsize {
|
||||
ai_dft.set_size((a_size + di) / dsize);
|
||||
impl<BE: Backend> GGLWEProduct<BE> for Module<BE> where
|
||||
Self: Sized
|
||||
+ ModuleN
|
||||
+ VecZnxDftBytesOf
|
||||
+ VmpApplyDftToDftTmpBytes
|
||||
+ VmpApplyDftToDft<BE>
|
||||
+ VmpApplyDftToDftAdd<BE>
|
||||
+ VecZnxDftCopy<BE>
|
||||
{
|
||||
}
|
||||
|
||||
// Small optimization for dsize > 2
|
||||
// VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
|
||||
// we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
|
||||
// As such we can ignore the last dsize-2 limbs safely of the sum of vmp products.
|
||||
// It is possible to further ignore the last dsize-1 limbs, but this introduce
|
||||
// ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
|
||||
// noise is kept with respect to the ideal functionality.
|
||||
res.set_size(pmat.size() - ((dsize - di) as isize - 2).max(0) as usize);
|
||||
pub(crate) trait GGLWEProduct<BE: Backend>
|
||||
where
|
||||
Self: Sized
|
||||
+ ModuleN
|
||||
+ VecZnxDftBytesOf
|
||||
+ VmpApplyDftToDftTmpBytes
|
||||
+ VmpApplyDftToDft<BE>
|
||||
+ VmpApplyDftToDftAdd<BE>
|
||||
+ VecZnxDftCopy<BE>,
|
||||
{
|
||||
fn gglwe_product_dft_tmp_bytes<K>(&self, res_size: usize, a_size: usize, key_infos: &K) -> usize
|
||||
where
|
||||
K: GGLWEInfos,
|
||||
{
|
||||
let dsize: usize = key_infos.dsize().as_usize();
|
||||
|
||||
for j in 0..cols - 1 {
|
||||
module.vec_znx_dft_apply(dsize, dsize - di - 1, &mut ai_dft, j, a.data(), j + 1);
|
||||
}
|
||||
|
||||
if di == 0 {
|
||||
module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_1);
|
||||
} else {
|
||||
module.vmp_apply_dft_to_dft_add(&mut res, &ai_dft, pmat, di, scratch_1);
|
||||
}
|
||||
}
|
||||
if dsize == 1 {
|
||||
self.vmp_apply_dft_to_dft_tmp_bytes(
|
||||
res_size,
|
||||
a_size,
|
||||
key_infos.dnum().into(),
|
||||
(key_infos.rank_in()).into(),
|
||||
(key_infos.rank_out() + 1).into(),
|
||||
key_infos.size(),
|
||||
)
|
||||
} else {
|
||||
let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(module.n(), cols - 1, a_size);
|
||||
for j in 0..cols - 1 {
|
||||
module.vec_znx_normalize(
|
||||
base2k_out,
|
||||
&mut a_conv,
|
||||
j,
|
||||
base2k_in,
|
||||
a.data(),
|
||||
j + 1,
|
||||
scratch_2,
|
||||
);
|
||||
}
|
||||
let dnum: usize = key_infos.dnum().into();
|
||||
let a_size: usize = a_size.div_ceil(dsize).min(dnum);
|
||||
let ai_dft: usize = self.bytes_of_vec_znx_dft(key_infos.rank_in().into(), a_size);
|
||||
|
||||
for di in 0..dsize {
|
||||
ai_dft.set_size((a_size + di) / dsize);
|
||||
let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes(
|
||||
res_size,
|
||||
a_size,
|
||||
dnum,
|
||||
(key_infos.rank_in()).into(),
|
||||
(key_infos.rank_out() + 1).into(),
|
||||
key_infos.size(),
|
||||
);
|
||||
|
||||
// Small optimization for dsize > 2
|
||||
// VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
|
||||
// we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
|
||||
// As such we can ignore the last dsize-2 limbs safely of the sum of vmp products.
|
||||
// It is possible to further ignore the last dsize-1 limbs, but this introduce
|
||||
// ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
|
||||
// noise is kept with respect to the ideal functionality.
|
||||
res.set_size(pmat.size() - ((dsize - di) as isize - 2).max(0) as usize);
|
||||
|
||||
for j in 0..cols - 1 {
|
||||
module.vec_znx_dft_apply(dsize, dsize - di - 1, &mut ai_dft, j, &a_conv, j);
|
||||
}
|
||||
|
||||
if di == 0 {
|
||||
module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_2);
|
||||
} else {
|
||||
module.vmp_apply_dft_to_dft_add(&mut res, &ai_dft, pmat, di, scratch_2);
|
||||
}
|
||||
}
|
||||
ai_dft + vmp
|
||||
}
|
||||
|
||||
res.set_size(res.max_size());
|
||||
}
|
||||
|
||||
let mut res_big: VecZnxBig<DR, BE> = module.vec_znx_idft_apply_consume(res);
|
||||
module.vec_znx_big_add_small_inplace(&mut res_big, 0, a.data(), 0);
|
||||
res_big
|
||||
fn gglwe_product_dft<DR, A, K>(&self, res: &mut VecZnxDft<DR, BE>, a: &A, key: &K, scratch: &mut Scratch<BE>)
|
||||
where
|
||||
DR: DataMut,
|
||||
A: VecZnxDftToRef<BE>,
|
||||
K: GGLWEPreparedToRef<BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
let a: &VecZnxDft<&[u8], BE> = &a.to_ref();
|
||||
let key: &GGLWEPrepared<&[u8], BE> = &key.to_ref();
|
||||
|
||||
let cols: usize = a.cols();
|
||||
let a_size: usize = a.size();
|
||||
let pmat: &VmpPMat<&[u8], BE> = &key.data;
|
||||
|
||||
// If dsize == 1, then the digit decomposition is equal to Base2K and we can simply
|
||||
// can the vmp API.
|
||||
if key.dsize() == 1 {
|
||||
self.vmp_apply_dft_to_dft(res, a, pmat, scratch);
|
||||
// If dsize != 1, then the digit decomposition is k * Base2K with k > 1.
|
||||
// As such we need to perform a bivariate polynomial convolution in (X, Y) / (X^{N}+1) with Y = 2^-K
|
||||
// (instead of yn univariate one in X).
|
||||
//
|
||||
// Since the basis in Y is small (in practice degree 6-7 max), we perform it naiveley.
|
||||
// To do so, we group the different limbs of ai_dft by their respective degree in Y
|
||||
// which are multiples of the current digit.
|
||||
// For example if dsize = 3, with ai_dft = [a0, a1, a2, a3, a4, a5, a6],
|
||||
// we group them as [[a0, a3, a5], [a1, a4, a6], [a2, a5, 0]]
|
||||
// and evaluate sum(a_di * pmat * 2^{di*Base2k})
|
||||
} else {
|
||||
let dsize: usize = key.dsize().into();
|
||||
let dnum: usize = key.dnum().into();
|
||||
|
||||
// We bound ai_dft size by the number of rows of the matrix
|
||||
let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize).min(dnum));
|
||||
ai_dft.data_mut().fill(0);
|
||||
|
||||
for di in 0..dsize {
|
||||
// Sets ai_dft size according to the current digit (if dsize does not divides a_size),
|
||||
// bounded by the number of rows (digits) in the prepared matrix.
|
||||
ai_dft.set_size(((a_size + di) / dsize).min(dnum));
|
||||
|
||||
// Small optimization for dsize > 2
|
||||
// VMP produce some error e, and since we aggregate vmp * 2^{di * Base2k}, then
|
||||
// we also aggregate ei * 2^{di * Base2k}, with the largest error being ei * 2^{(dsize-1) * Base2k}.
|
||||
// As such we can ignore the last dsize-2 limbs safely of the sum of vmp products.
|
||||
// It is possible to further ignore the last dsize-1 limbs, but this introduce
|
||||
// ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
|
||||
// noise is kept with respect to the ideal functionality.
|
||||
res.set_size(pmat.size() - ((dsize - di) as isize - 2).max(0) as usize);
|
||||
|
||||
for j in 0..cols {
|
||||
self.vec_znx_dft_copy(dsize, dsize - di - 1, &mut ai_dft, j, a, j);
|
||||
}
|
||||
|
||||
if di == 0 {
|
||||
// res = pmat * ai_dft
|
||||
self.vmp_apply_dft_to_dft(res, &ai_dft, pmat, scratch_1);
|
||||
} else {
|
||||
// res = (pmat * ai_dft) * 2^{di * Base2k}
|
||||
self.vmp_apply_dft_to_dft_add(res, &ai_dft, pmat, di, scratch_1);
|
||||
}
|
||||
}
|
||||
|
||||
res.set_size(res.max_size());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user