mirror of https://github.com/arnaucube/poulpy.git
Improve cmux speed
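In brief (summarizing the diff below): the GLWE external product is split into a new `GLWEExternalProductInternal` step that stops in the big-coefficient domain and returns the un-normalized `VecZnxBig` accumulator. `GLWEExternalProduct` is re-implemented as a thin wrapper over it, and `cmux`/`cmux_inplace` call the internal step directly, fusing the final addition of the `f` branch (`vec_znx_big_add_small_inplace`) into a single `vec_znx_big_normalize` pass. The old flow normalized three times per cmux (after the subtraction, inside the external product, and after the addition); the new one normalizes once.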
@@ -1,5 +1,5 @@
 use poulpy_hal::{
-    api::ScratchAvailable,
+    api::{ModuleN, ScratchAvailable},
     layouts::{Backend, DataMut, Module, Scratch, ZnxZero},
 };

@@ -13,7 +13,7 @@ use crate::{

 pub trait GGSWExternalProduct<BE: Backend>
 where
-    Self: GLWEExternalProduct<BE>,
+    Self: GLWEExternalProduct<BE> + ModuleN,
 {
     fn ggsw_external_product_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
     where

@@ -1,9 +1,10 @@
 use poulpy_hal::{
     api::{
-        ModuleN, ScratchTakeBasic, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize,
-        VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
+        ModuleN, ScratchTakeBasic, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftApply, VecZnxDftBytesOf,
+        VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
+        VmpApplyDftToDftTmpBytes,
     },
-    layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig},
+    layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft},
 };

 use crate::{
@@ -30,7 +31,7 @@ impl<DataSelf: DataMut> GLWE<DataSelf> {
     pub fn external_product<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>)
     where
         A: GLWEToRef,
-        B: GGSWPreparedToRef<BE>,
+        B: GGSWPreparedToRef<BE> + GGSWInfos,
         M: GLWEExternalProduct<BE>,
         Scratch<BE>: ScratchTakeCore<BE>,
     {
@@ -39,7 +40,7 @@ impl<DataSelf: DataMut> GLWE<DataSelf> {

     pub fn external_product_inplace<A, M, BE: Backend>(&mut self, module: &M, a: &A, scratch: &mut Scratch<BE>)
     where
-        A: GGSWPreparedToRef<BE>,
+        A: GGSWPreparedToRef<BE> + GGSWInfos,
         M: GLWEExternalProduct<BE>,
         Scratch<BE>: ScratchTakeCore<BE>,
     {
@@ -47,19 +48,30 @@ impl<DataSelf: DataMut> GLWE<DataSelf> {
     }
 }

-pub trait GLWEExternalProduct<BE: Backend>
+pub trait GLWEExternalProduct<BE: Backend> {
     fn glwe_external_product_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
     where
-        Self: Sized
-            + ModuleN
-            + VecZnxDftBytesOf
-            + VmpApplyDftToDftTmpBytes
-            + VecZnxNormalizeTmpBytes
-            + VecZnxDftApply<BE>
-            + VmpApplyDftToDft<BE>
-            + VmpApplyDftToDftAdd<BE>
-            + VecZnxIdftApplyConsume<BE>
-            + VecZnxBigNormalize<BE>
-            + VecZnxNormalize<BE>,
         R: GLWEInfos,
         A: GLWEInfos,
         B: GGSWInfos;

     fn glwe_external_product_inplace<R, D>(&self, res: &mut R, a: &D, scratch: &mut Scratch<BE>)
     where
         R: GLWEToMut,
         D: GGSWPreparedToRef<BE> + GGSWInfos,
         Scratch<BE>: ScratchTakeCore<BE>;

     fn glwe_external_product<R, A, D>(&self, res: &mut R, lhs: &A, rhs: &D, scratch: &mut Scratch<BE>)
     where
         R: GLWEToMut,
         A: GLWEToRef,
         D: GGSWPreparedToRef<BE> + GGSWInfos,
         Scratch<BE>: ScratchTakeCore<BE>;
 }

+impl<BE: Backend> GLWEExternalProduct<BE> for Module<BE>
+where
+    Self: GLWEExternalProductInternal<BE> + VecZnxDftBytesOf + VecZnxBigNormalize<BE> + VecZnxBigNormalizeTmpBytes,
+{
     fn glwe_external_product_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
     where
@@ -67,36 +79,17 @@ where
         A: GLWEInfos,
         B: GGSWInfos,
     {
-        let in_size: usize = a_infos
-            .k()
-            .div_ceil(b_infos.base2k())
-            .div_ceil(b_infos.dsize().into()) as usize;
-        let out_size: usize = res_infos.size();
-        let ggsw_size: usize = b_infos.size();
-        let res_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), ggsw_size);
-        let a_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), in_size);
-        let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes(
-            out_size,
-            in_size,
-            in_size,                     // rows
-            (b_infos.rank() + 1).into(), // cols in
-            (b_infos.rank() + 1).into(), // cols out
-            ggsw_size,
-        );
-        let normalize_big: usize = self.vec_znx_normalize_tmp_bytes();
-
-        if a_infos.base2k() == b_infos.base2k() {
-            res_dft + a_dft + (vmp | normalize_big)
-        } else {
-            let normalize_conv: usize = VecZnx::bytes_of(self.n(), (b_infos.rank() + 1).into(), in_size);
-            res_dft + ((a_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big)
-        }
+        let res_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), b_infos.size());
+        res_dft
+            + self
+                .glwe_external_product_internal_tmp_bytes(res_infos, a_infos, b_infos)
+                .max(self.vec_znx_big_normalize_tmp_bytes())
     }

     fn glwe_external_product_inplace<R, D>(&self, res: &mut R, a: &D, scratch: &mut Scratch<BE>)
     where
         R: GLWEToMut,
-        D: GGSWPreparedToRef<BE>,
+        D: GGSWPreparedToRef<BE> + GGSWInfos,
         Scratch<BE>: ScratchTakeCore<BE>,
     {
         let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
@@ -114,81 +107,9 @@ where
             assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, res, rhs));
         }

-        let cols: usize = (rhs.rank() + 1).into();
-        let dsize: usize = rhs.dsize().into();
-        let a_size: usize = (res.size() * basek_in).div_ceil(basek_ggsw);
-
-        let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, rhs.size()); // Todo optimise
-        let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize));
-        a_dft.data_mut().fill(0);
-
-        if basek_in == basek_ggsw {
-            for di in 0..dsize {
-                // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
-                a_dft.set_size((res.size() + di) / dsize);
-
-                // Small optimization for dsize > 2
-                // VMP produces some error e, and since we aggregate vmp * 2^{di * B}, we
-                // also aggregate e_i * 2^{di * B}, with the largest error being e_i * 2^{(dsize-1) * B}.
-                // As such, we can safely ignore the last dsize-2 limbs of the sum of VMP products.
-                // It is possible to further ignore the last dsize-1 limbs, but this introduces
-                // ~0.5 to 1 bit of additional noise, and is thus not done here, so that the same
-                // noise is kept with respect to the ideal functionality.
-                res_dft.set_size(rhs.size() - ((dsize - di) as isize - 2).max(0) as usize);
-
-                for j in 0..cols {
-                    self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &res.data, j);
-                }
-
-                if di == 0 {
-                    self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2);
-                } else {
-                    self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2);
-                }
-            }
-        } else {
-            let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n(), cols, a_size);
-
-            for j in 0..cols {
-                self.vec_znx_normalize(
-                    basek_ggsw,
-                    &mut a_conv,
-                    j,
-                    basek_in,
-                    &res.data,
-                    j,
-                    scratch_3,
-                );
-            }
-
-            for di in 0..dsize {
-                // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
-                a_dft.set_size((res.size() + di) / dsize);
-
-                // Small optimization for dsize > 2
-                // VMP produces some error e, and since we aggregate vmp * 2^{di * B}, we
-                // also aggregate e_i * 2^{di * B}, with the largest error being e_i * 2^{(dsize-1) * B}.
-                // As such, we can safely ignore the last dsize-2 limbs of the sum of VMP products.
-                // It is possible to further ignore the last dsize-1 limbs, but this introduces
-                // ~0.5 to 1 bit of additional noise, and is thus not done here, so that the same
-                // noise is kept with respect to the ideal functionality.
-                res_dft.set_size(rhs.size() - ((dsize - di) as isize - 2).max(0) as usize);
-
-                for j in 0..cols {
-                    self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &res.data, j);
-                }
-
-                if di == 0 {
-                    self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2);
-                } else {
-                    self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2);
-                }
-            }
-        }
-
-        let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);
-
-        for j in 0..cols {
+        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), a.size()); // Todo optimise
+        let res_big = self.glwe_external_product_internal(res_dft, res, a, scratch_1);
+        for j in 0..(res.rank() + 1).into() {
             self.vec_znx_big_normalize(
                 basek_in,
                 &mut res.data,
@@ -213,7 +134,6 @@ where

         let rhs: &GGSWPrepared<&[u8], BE> = &rhs.to_ref();

-        let basek_in: usize = lhs.base2k().into();
         let basek_ggsw: usize = rhs.base2k().into();
         let basek_out: usize = res.base2k().into();

@@ -228,96 +148,45 @@
             assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, lhs, rhs));
         }

-        let cols: usize = (rhs.rank() + 1).into();
-        let dsize: usize = rhs.dsize().into();
+        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), rhs.size()); // Todo optimise
+        let res_big = self.glwe_external_product_internal(res_dft, lhs, rhs, scratch_1);

-        let a_size: usize = (lhs.size() * basek_in).div_ceil(basek_ggsw);
-
-        let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, rhs.size()); // Todo optimise
-        let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize));
-        a_dft.data_mut().fill(0);
-
-        if basek_in == basek_ggsw {
-            for di in 0..dsize {
-                // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
-                a_dft.set_size((lhs.size() + di) / dsize);
-
-                // Small optimization for dsize > 2
-                // VMP produces some error e, and since we aggregate vmp * 2^{di * B}, we
-                // also aggregate e_i * 2^{di * B}, with the largest error being e_i * 2^{(dsize-1) * B}.
-                // As such, we can safely ignore the last dsize-2 limbs of the sum of VMP products.
-                // It is possible to further ignore the last dsize-1 limbs, but this introduces
-                // ~0.5 to 1 bit of additional noise, and is thus not done here, so that the same
-                // noise is kept with respect to the ideal functionality.
-                res_dft.set_size(rhs.size() - ((dsize - di) as isize - 2).max(0) as usize);
-
-                for j in 0..cols {
-                    self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &lhs.data, j);
-                }
-
-                if di == 0 {
-                    self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2);
-                } else {
-                    self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2);
-                }
-            }
-        } else {
-            let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n(), cols, a_size);
-
-            for j in 0..cols {
-                self.vec_znx_normalize(
-                    basek_ggsw,
-                    &mut a_conv,
-                    j,
-                    basek_in,
-                    &lhs.data,
-                    j,
-                    scratch_3,
-                );
-            }
-
-            for di in 0..dsize {
-                // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
-                a_dft.set_size((a_size + di) / dsize);
-
-                // Small optimization for dsize > 2
-                // VMP produces some error e, and since we aggregate vmp * 2^{di * B}, we
-                // also aggregate e_i * 2^{di * B}, with the largest error being e_i * 2^{(dsize-1) * B}.
-                // As such, we can safely ignore the last dsize-2 limbs of the sum of VMP products.
-                // It is possible to further ignore the last dsize-1 limbs, but this introduces
-                // ~0.5 to 1 bit of additional noise, and is thus not done here, so that the same
-                // noise is kept with respect to the ideal functionality.
-                res_dft.set_size(rhs.size() - ((dsize - di) as isize - 2).max(0) as usize);
-
-                for j in 0..cols {
-                    self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &a_conv, j);
-                }
-
-                if di == 0 {
-                    self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_3);
-                } else {
-                    self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_3);
-                }
-            }
-        }
-
-        let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);
-
-        (0..cols).for_each(|i| {
+        for j in 0..(res.rank() + 1).into() {
             self.vec_znx_big_normalize(
                 basek_out,
-                res.data_mut(),
-                i,
+                &mut res.data,
+                j,
                 basek_ggsw,
                 &res_big,
-                i,
+                j,
                 scratch_1,
             );
-        });
+        }
     }
 }

-impl<BE: Backend> GLWEExternalProduct<BE> for Module<BE> where
+pub trait GLWEExternalProductInternal<BE: Backend> {
+    fn glwe_external_product_internal_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
+    where
+        R: GLWEInfos,
+        A: GLWEInfos,
+        B: GGSWInfos;
+    fn glwe_external_product_internal<DR, A, G>(
+        &self,
+        res_dft: VecZnxDft<DR, BE>,
+        a: &A,
+        ggsw: &G,
+        scratch: &mut Scratch<BE>,
+    ) -> VecZnxBig<DR, BE>
+    where
+        DR: DataMut,
+        A: GLWEToRef,
+        G: GGSWPreparedToRef<BE>,
+        Scratch<BE>: ScratchTakeCore<BE>;
+}

+impl<BE: Backend> GLWEExternalProductInternal<BE> for Module<BE>
+where
     Self: ModuleN
         + VecZnxDftBytesOf
         + VmpApplyDftToDftTmpBytes
@@ -330,6 +199,121 @@ impl<BE: Backend> GLWEExternalProduct<BE> for Module<BE> where
         + VecZnxNormalize<BE>
         + VecZnxDftBytesOf
         + VmpApplyDftToDftTmpBytes
-        + VecZnxNormalizeTmpBytes
+        + VecZnxNormalizeTmpBytes,
 {
+    fn glwe_external_product_internal_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
+    where
+        R: GLWEInfos,
+        A: GLWEInfos,
+        B: GGSWInfos,
+    {
+        let in_size: usize = a_infos
+            .k()
+            .div_ceil(b_infos.base2k())
+            .div_ceil(b_infos.dsize().into()) as usize;
+        let out_size: usize = res_infos.size();
+        let ggsw_size: usize = b_infos.size();
+        let a_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), in_size);
+        let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes(
+            out_size,
+            in_size,
+            in_size,                     // rows
+            (b_infos.rank() + 1).into(), // cols in
+            (b_infos.rank() + 1).into(), // cols out
+            ggsw_size,
+        );
+        let normalize_big: usize = self.vec_znx_normalize_tmp_bytes();
+
+        if a_infos.base2k() == b_infos.base2k() {
+            a_dft + (vmp | normalize_big)
+        } else {
+            let normalize_conv: usize = VecZnx::bytes_of(self.n(), (b_infos.rank() + 1).into(), in_size);
+            (a_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big
+        }
+    }
+
+    fn glwe_external_product_internal<DR, A, G>(
+        &self,
+        mut res_dft: VecZnxDft<DR, BE>,
+        a: &A,
+        ggsw: &G,
+        scratch: &mut Scratch<BE>,
+    ) -> VecZnxBig<DR, BE>
+    where
+        DR: DataMut,
+        A: GLWEToRef,
+        G: GGSWPreparedToRef<BE>,
+        Scratch<BE>: ScratchTakeCore<BE>,
+    {
+        let a: &GLWE<&[u8]> = &a.to_ref();
+        let ggsw: &GGSWPrepared<&[u8], BE> = &ggsw.to_ref();
+
+        let basek_in: usize = a.base2k().into();
+        let basek_ggsw: usize = ggsw.base2k().into();
+
+        let cols: usize = (ggsw.rank() + 1).into();
+        let dsize: usize = ggsw.dsize().into();
+        let a_size: usize = (a.size() * basek_in).div_ceil(basek_ggsw);
+
+        let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize));
+        a_dft.data_mut().fill(0);
+
+        if basek_in == basek_ggsw {
+            for di in 0..dsize {
+                // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
+                a_dft.set_size((a.size() + di) / dsize);
+
+                // Small optimization for dsize > 2
+                // VMP produces some error e, and since we aggregate vmp * 2^{di * B}, we
+                // also aggregate e_i * 2^{di * B}, with the largest error being e_i * 2^{(dsize-1) * B}.
+                // As such, we can safely ignore the last dsize-2 limbs of the sum of VMP products.
+                // It is possible to further ignore the last dsize-1 limbs, but this introduces
+                // ~0.5 to 1 bit of additional noise, and is thus not done here, so that the same
+                // noise is kept with respect to the ideal functionality.
+                res_dft.set_size(ggsw.size() - ((dsize - di) as isize - 2).max(0) as usize);
+
+                for j in 0..cols {
+                    self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &a.data, j);
+                }
+
+                if di == 0 {
+                    self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &ggsw.data, scratch_1);
+                } else {
+                    self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &ggsw.data, di, scratch_1);
+                }
+            }
+        } else {
+            let (mut a_conv, scratch_3) = scratch_1.take_vec_znx(self.n(), cols, a_size);
+
+            for j in 0..cols {
+                self.vec_znx_normalize(basek_ggsw, &mut a_conv, j, basek_in, &a.data, j, scratch_3);
+            }
+
+            for di in 0..dsize {
+                // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
+                a_dft.set_size((a.size() + di) / dsize);
+
+                // Small optimization for dsize > 2
+                // VMP produces some error e, and since we aggregate vmp * 2^{di * B}, we
+                // also aggregate e_i * 2^{di * B}, with the largest error being e_i * 2^{(dsize-1) * B}.
+                // As such, we can safely ignore the last dsize-2 limbs of the sum of VMP products.
+                // It is possible to further ignore the last dsize-1 limbs, but this introduces
+                // ~0.5 to 1 bit of additional noise, and is thus not done here, so that the same
+                // noise is kept with respect to the ideal functionality.
+                res_dft.set_size(ggsw.size() - ((dsize - di) as isize - 2).max(0) as usize);
+
+                for j in 0..cols {
+                    self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &a.data, j);
+                }
+
+                if di == 0 {
+                    self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &ggsw.data, scratch_1);
+                } else {
+                    self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &ggsw.data, di, scratch_1);
+                }
+            }
+        }
+
+        self.vec_znx_idft_apply_consume(res_dft)
+    }
 }
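The recurring limb-dropping comment above, restated (an editor's reading of the comment, with B = `base2k`, d = `dsize`, and e_i the error of the i-th VMP product):

    % Error aggregated across the d digit products, dominated by the top term:
    e_{\mathrm{tot}} \;=\; \sum_{i=0}^{d-1} e_i \, 2^{iB},
    \qquad |e_{\mathrm{tot}}| \approx |e_{d-1}| \, 2^{(d-1)B}.

Accumulator limbs below that magnitude carry no usable signal, which is why the loop sets `res_dft.set_size(ggsw.size() - max(dsize - di - 2, 0))`; per the comment, dropping one limb more (`dsize - di - 1`) would still work but costs roughly 0.5 to 1 bit of extra noise.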
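The `a_dft.set_size` lines above rely on the identity quoted in the inline comment, `(a.size() + di) / dsize == (a.size() - (dsize - di - 1)).div_ceil(dsize)`. A small self-contained check (an editor's example in plain Rust, no poulpy dependency; the concrete sizes are illustrative):

    // Checks the size identity used by the digit loop: for each digit di,
    // floor((a_size + di) / dsize) == ceil((a_size - (dsize - di - 1)) / dsize).
    fn main() {
        let (a_size, dsize) = (7usize, 3usize); // e.g. 7 input limbs, dsize = 3
        for di in 0..dsize {
            let lhs = (a_size + di) / dsize;
            let rhs = (a_size - (dsize - di - 1)).div_ceil(dsize);
            assert_eq!(lhs, rhs);
            println!("di = {di}: a_dft holds {lhs} limbs");
        }
        // Prints 2, 2, 3: later digits see one more limb of the decomposed input.
    }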
@@ -382,8 +382,8 @@ impl<D: DataMut, T: UnsignedInteger> FheUint<D, T> {

         let (mut out_bits, scratch_1) = scratch.take_glwe_slice(T::BITS as usize, self);

-        for i in 0..T::BITS as usize {
-            module.cmux(&mut out_bits[i], &one, &zero, &other.get_bit(i), scratch_1);
+        for (i, bits) in out_bits.iter_mut().enumerate().take(T::BITS as usize) {
+            module.cmux(bits, &one, &zero, &other.get_bit(i), scratch_1);
         }

         self.pack(module, out_bits, keys, scratch_1);

@@ -2,9 +2,13 @@ use core::panic;

 use itertools::Itertools;
 use poulpy_core::{
-    GLWEAdd, GLWECopy, GLWEExternalProduct, GLWENormalize, GLWESub, ScratchTakeCore, layouts::{GGSWInfos, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef}
+    GLWECopy, GLWEExternalProductInternal, GLWENormalize, GLWESub, ScratchTakeCore,
+    layouts::{GGSWInfos, GGSWPrepared, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef},
 };
-use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch, ZnxZero};
+use poulpy_hal::{
+    api::{ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftBytesOf},
+    layouts::{Backend, DataMut, Module, Scratch, VecZnxBig, ZnxZero},
+};

 use crate::tfhe::bdd_arithmetic::{GetGGSWBit, UnsignedInteger};

@@ -70,8 +74,18 @@ where
 {
     #[cfg(debug_assertions)]
     {
-        assert!(inputs.bit_size() >= circuit.input_size(), "inputs.bit_size(): {} < circuit.input_size():{}", inputs.bit_size(), circuit.input_size());
-        assert!(out.len() >= circuit.output_size(), "out.len(): {} < circuit.output_size(): {}", out.len(), circuit.output_size());
+        assert!(
+            inputs.bit_size() >= circuit.input_size(),
+            "inputs.bit_size(): {} < circuit.input_size():{}",
+            inputs.bit_size(),
+            circuit.input_size()
+        );
+        assert!(
+            out.len() >= circuit.output_size(),
+            "out.len(): {} < circuit.output_size(): {}",
+            out.len(),
+            circuit.output_size()
+        );
     }

     for (i, out_i) in out.iter_mut().enumerate().take(circuit.output_size()) {
@@ -164,7 +178,14 @@ pub enum Node {

 pub trait Cmux<BE: Backend>
 where
-    Self: GLWEExternalProduct<BE> + GLWESub + GLWEAdd + GLWENormalize<BE>,
+    Self: Sized
+        + GLWEExternalProductInternal<BE>
+        + GLWESub
+        + VecZnxBigAddSmallInplace<BE>
+        + GLWENormalize<BE>
+        + VecZnxDftBytesOf
+        + VecZnxBigNormalize<BE>
+        + VecZnxBigNormalizeTmpBytes,
 {
     fn cmux_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
     where
@@ -172,7 +193,11 @@ where
         A: GLWEInfos,
         B: GGSWInfos,
     {
-        self.glwe_external_product_tmp_bytes(res_infos, a_infos, b_infos)
+        let res_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), b_infos.size());
+        res_dft
+            + self
+                .glwe_external_product_internal_tmp_bytes(res_infos, a_infos, b_infos)
+                .max(self.vec_znx_big_normalize_tmp_bytes())
     }

     fn cmux<R, T, F, S>(&self, res: &mut R, t: &T, f: &F, s: &S, scratch: &mut Scratch<BE>)
@@ -180,34 +205,73 @@ where
         R: GLWEToMut,
         T: GLWEToRef,
         F: GLWEToRef,
-        S: GGSWPreparedToRef<BE>,
+        S: GGSWPreparedToRef<BE> + GGSWInfos,
         Scratch<BE>: ScratchTakeCore<BE>,
     {
-        self.glwe_sub(res, t, f);
-        self.glwe_normalize_inplace(res, scratch);
-        self.glwe_external_product_inplace(res, s, scratch);
-        self.glwe_add_inplace(res, f);
-        self.glwe_normalize_inplace(res, scratch);
+        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
+        let s: &GGSWPrepared<&[u8], BE> = &s.to_ref();
+        let f: GLWE<&[u8]> = f.to_ref();
+
+        let res_base2k: usize = res.base2k().into();
+        let ggsw_base2k: usize = s.base2k().into();
+
+        self.glwe_sub(res, t, &f);
+        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), s.size()); // Todo optimise
+        let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_external_product_internal(res_dft, res, s, scratch_1);
+        for j in 0..(res.rank() + 1).into() {
+            self.vec_znx_big_add_small_inplace(&mut res_big, j, f.data(), j);
+            self.vec_znx_big_normalize(
+                res_base2k,
+                res.data_mut(),
+                j,
+                ggsw_base2k,
+                &res_big,
+                j,
+                scratch_1,
+            );
+        }
     }

     fn cmux_inplace<R, A, S>(&self, res: &mut R, a: &A, s: &S, scratch: &mut Scratch<BE>)
     where
         R: GLWEToMut,
         A: GLWEToRef,
-        S: GGSWPreparedToRef<BE>,
+        S: GGSWPreparedToRef<BE> + GGSWInfos,
         Scratch<BE>: ScratchTakeCore<BE>,
     {
-        self.glwe_sub_inplace(res, a);
-        self.glwe_normalize_inplace(res, scratch);
-        self.glwe_external_product_inplace(res, s, scratch);
-        self.glwe_add_inplace(res, a);
-        self.glwe_normalize_inplace(res, scratch);
+        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
+        let s: &GGSWPrepared<&[u8], BE> = &s.to_ref();
+        let a: GLWE<&[u8]> = a.to_ref();
+        let res_base2k: usize = res.base2k().into();
+        let ggsw_base2k: usize = s.base2k().into();
+        self.glwe_sub_inplace(res, &a);
+        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), s.size()); // Todo optimise
+        let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_external_product_internal(res_dft, res, s, scratch_1);
+        for j in 0..(res.rank() + 1).into() {
+            self.vec_znx_big_add_small_inplace(&mut res_big, j, a.data(), j);
+            self.vec_znx_big_normalize(
+                res_base2k,
+                res.data_mut(),
+                j,
+                ggsw_base2k,
+                &res_big,
+                j,
+                scratch_1,
+            );
+        }
     }
 }

 impl<BE: Backend> Cmux<BE> for Module<BE>
 where
-    Self: GLWEExternalProduct<BE> + GLWESub + GLWEAdd + GLWENormalize<BE>,
+    Self: Sized
+        + GLWEExternalProductInternal<BE>
+        + GLWESub
+        + VecZnxBigAddSmallInplace<BE>
+        + GLWENormalize<BE>
+        + VecZnxDftBytesOf
+        + VecZnxBigNormalize<BE>
+        + VecZnxBigNormalizeTmpBytes,
     Scratch<BE>: ScratchTakeCore<BE>,
 {
 }
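Putting the pieces together, the fused cmux can be read as the following free function (an editor's sketch, not code from the commit: the function name and exact bound list are hypothetical, but every call and signature mirrors the diff above). Compared with the old path (sub, normalize, external product with its own normalization, add, normalize), the big-domain result is normalized exactly once:

    use poulpy_core::{
        GLWEExternalProductInternal, GLWESub, ScratchTakeCore,
        layouts::{GGSWInfos, GLWE, GLWEInfos, LWEInfos, prepared::GGSWPrepared},
    };
    use poulpy_hal::{
        api::{VecZnxBigAddSmallInplace, VecZnxBigNormalize},
        layouts::{Backend, Module, Scratch},
    };

    // res enters holding t and leaves holding cmux(s, t, f) = (t - f) * s + f.
    fn cmux_fused_sketch<BE: Backend>(
        module: &Module<BE>,
        res: &mut GLWE<&mut [u8]>,
        f: &GLWE<&[u8]>,
        s: &GGSWPrepared<&[u8], BE>,
        scratch: &mut Scratch<BE>,
    ) where
        Module<BE>: GLWEExternalProductInternal<BE> + GLWESub + VecZnxBigAddSmallInplace<BE> + VecZnxBigNormalize<BE>,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        let res_base2k: usize = res.base2k().into();
        let ggsw_base2k: usize = s.base2k().into();

        // res = t - f: the GGSW selector s multiplies the difference of the branches.
        module.glwe_sub_inplace(res, f);

        // One external product, left in the big-coefficient domain (not yet normalized).
        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(module, (res.rank() + 1).into(), s.size());
        let mut res_big = module.glwe_external_product_internal(res_dft, res, s, scratch_1);

        // Fuse "+ f" into the single closing normalization pass, column by column.
        for j in 0..(res.rank() + 1).into() {
            module.vec_znx_big_add_small_inplace(&mut res_big, j, f.data(), j);
            module.vec_znx_big_normalize(res_base2k, res.data_mut(), j, ggsw_base2k, &res_big, j, scratch_1);
        }
    }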