Mirror of https://github.com/arnaucube/poulpy.git, synced 2026-02-10 05:06:44 +01:00
Improve cmux speed
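At a high level, the commit speeds up CMUX by fusing its pipeline. The old path ran sub, normalize, external product (with its own internal normalization), add, normalize; the new path runs the external product only up to the inverse DFT, adds the false branch directly onto the un-normalized big-coefficient accumulator via vec_znx_big_add_small_inplace, and normalizes a single time per column. As a reminder of the functionality being optimized, here is a minimal plaintext sketch of the CMUX identity over toy integers; the names are illustrative and not part of poulpy:

// CMUX over plaintexts: with selector s in {0, 1},
// cmux(t, f, s) = f + s * (t - f) selects t when s == 1 and f when s == 0.
// Homomorphically, `s * _` becomes a GGSW external product and `+ f`
// becomes an addition performed in the un-normalized "big" domain.
fn cmux_plain(t: i64, f: i64, s: i64) -> i64 {
    assert!(s == 0 || s == 1);
    f + s * (t - f)
}

fn main() {
    assert_eq!(cmux_plain(7, 3, 1), 7); // selector 1 picks the "true" branch
    assert_eq!(cmux_plain(7, 3, 0), 3); // selector 0 picks the "false" branch
}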
@@ -1,5 +1,5 @@
 use poulpy_hal::{
-    api::ScratchAvailable,
+    api::{ModuleN, ScratchAvailable},
     layouts::{Backend, DataMut, Module, Scratch, ZnxZero},
 };
 
@@ -13,7 +13,7 @@ use crate::{
 
 pub trait GGSWExternalProduct<BE: Backend>
 where
-    Self: GLWEExternalProduct<BE>,
+    Self: GLWEExternalProduct<BE> + ModuleN,
 {
     fn ggsw_external_product_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
     where
@@ -1,9 +1,10 @@
 use poulpy_hal::{
     api::{
-        ModuleN, ScratchTakeBasic, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize,
-        VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
+        ModuleN, ScratchTakeBasic, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftApply, VecZnxDftBytesOf,
+        VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
+        VmpApplyDftToDftTmpBytes,
     },
-    layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig},
+    layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft},
 };
 
 use crate::{
@@ -30,7 +31,7 @@ impl<DataSelf: DataMut> GLWE<DataSelf> {
     pub fn external_product<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>)
     where
         A: GLWEToRef,
-        B: GGSWPreparedToRef<BE>,
+        B: GGSWPreparedToRef<BE> + GGSWInfos,
         M: GLWEExternalProduct<BE>,
         Scratch<BE>: ScratchTakeCore<BE>,
     {
@@ -39,7 +40,7 @@ impl<DataSelf: DataMut> GLWE<DataSelf> {
 
     pub fn external_product_inplace<A, M, BE: Backend>(&mut self, module: &M, a: &A, scratch: &mut Scratch<BE>)
     where
-        A: GGSWPreparedToRef<BE>,
+        A: GGSWPreparedToRef<BE> + GGSWInfos,
         M: GLWEExternalProduct<BE>,
         Scratch<BE>: ScratchTakeCore<BE>,
     {
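The recurring bound change in these hunks, adding GGSWInfos next to GGSWPreparedToRef<BE>, is what lets the rewritten bodies size scratch buffers from the operand itself: the new code below reads size(), rank() and base2k() off the prepared GGSW before taking any DFT buffer from scratch.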
@@ -47,19 +48,30 @@ impl<DataSelf: DataMut> GLWE<DataSelf> {
     }
 }
 
-pub trait GLWEExternalProduct<BE: Backend>
+pub trait GLWEExternalProduct<BE: Backend> {
+    fn glwe_external_product_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
+    where
+        R: GLWEInfos,
+        A: GLWEInfos,
+        B: GGSWInfos;
+
+    fn glwe_external_product_inplace<R, D>(&self, res: &mut R, a: &D, scratch: &mut Scratch<BE>)
+    where
+        R: GLWEToMut,
+        D: GGSWPreparedToRef<BE> + GGSWInfos,
+        Scratch<BE>: ScratchTakeCore<BE>;
+
+    fn glwe_external_product<R, A, D>(&self, res: &mut R, lhs: &A, rhs: &D, scratch: &mut Scratch<BE>)
+    where
+        R: GLWEToMut,
+        A: GLWEToRef,
+        D: GGSWPreparedToRef<BE> + GGSWInfos,
+        Scratch<BE>: ScratchTakeCore<BE>;
+}
+
+impl<BE: Backend> GLWEExternalProduct<BE> for Module<BE>
 where
-    Self: Sized
-        + ModuleN
-        + VecZnxDftBytesOf
-        + VmpApplyDftToDftTmpBytes
-        + VecZnxNormalizeTmpBytes
-        + VecZnxDftApply<BE>
-        + VmpApplyDftToDft<BE>
-        + VmpApplyDftToDftAdd<BE>
-        + VecZnxIdftApplyConsume<BE>
-        + VecZnxBigNormalize<BE>
-        + VecZnxNormalize<BE>,
+    Self: GLWEExternalProductInternal<BE> + VecZnxDftBytesOf + VecZnxBigNormalize<BE> + VecZnxBigNormalizeTmpBytes,
 {
     fn glwe_external_product_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
     where
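This hunk is the heart of the refactor: GLWEExternalProduct becomes a plain trait with three entry points, and its blanket impl for Module<BE> now delegates to a new GLWEExternalProductInternal trait (introduced further down) that stops right after the inverse DFT and hands back the un-normalized VecZnxBig accumulator. Callers that want to fuse more arithmetic into the big domain, such as cmux, call the internal method directly and normalize once at the end. A toy sketch of that split, with hypothetical names rather than poulpy's API:

// Toy model of the split: the "internal" trait returns a raw accumulator
// that is cheap to add into and must be normalized exactly once.
struct BigAcc(Vec<i64>);

fn normalize(acc: BigAcc, modulus: i64) -> Vec<i64> {
    acc.0.into_iter().map(|x| x.rem_euclid(modulus)).collect()
}

trait ProductInternal {
    fn product_internal(&self, a: &[i64]) -> BigAcc;
}

// The public trait keeps the old one-shot entry point as a thin wrapper.
trait Product: ProductInternal {
    fn product(&self, a: &[i64], modulus: i64) -> Vec<i64> {
        normalize(self.product_internal(a), modulus)
    }
}
impl<T: ProductInternal> Product for T {}

struct Doubler;
impl ProductInternal for Doubler {
    fn product_internal(&self, a: &[i64]) -> BigAcc {
        BigAcc(a.iter().map(|x| 2 * x).collect())
    }
}

fn main() {
    // A fused caller adds into the accumulator before the single normalize,
    // mirroring cmux's vec_znx_big_add_small_inplace + vec_znx_big_normalize.
    let mut acc = Doubler.product_internal(&[5, 6]);
    for x in acc.0.iter_mut() {
        *x += 1;
    }
    assert_eq!(normalize(acc, 7), vec![4, 6]);
    assert_eq!(Doubler.product(&[5, 6], 7), vec![3, 5]);
}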
@@ -67,36 +79,17 @@ where
         A: GLWEInfos,
         B: GGSWInfos,
     {
-        let in_size: usize = a_infos
-            .k()
-            .div_ceil(b_infos.base2k())
-            .div_ceil(b_infos.dsize().into()) as usize;
-        let out_size: usize = res_infos.size();
-        let ggsw_size: usize = b_infos.size();
-        let res_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), ggsw_size);
-        let a_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), in_size);
-        let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes(
-            out_size,
-            in_size,
-            in_size,                     // rows
-            (b_infos.rank() + 1).into(), // cols in
-            (b_infos.rank() + 1).into(), // cols out
-            ggsw_size,
-        );
-        let normalize_big: usize = self.vec_znx_normalize_tmp_bytes();
-
-        if a_infos.base2k() == b_infos.base2k() {
-            res_dft + a_dft + (vmp | normalize_big)
-        } else {
-            let normalize_conv: usize = VecZnx::bytes_of(self.n(), (b_infos.rank() + 1).into(), in_size);
-            res_dft + ((a_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big)
-        }
+        let res_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), b_infos.size());
+        res_dft
+            + self
+                .glwe_external_product_internal_tmp_bytes(res_infos, a_infos, b_infos)
+                .max(self.vec_znx_big_normalize_tmp_bytes())
     }
 
     fn glwe_external_product_inplace<R, D>(&self, res: &mut R, a: &D, scratch: &mut Scratch<BE>)
     where
         R: GLWEToMut,
-        D: GGSWPreparedToRef<BE>,
+        D: GGSWPreparedToRef<BE> + GGSWInfos,
         Scratch<BE>: ScratchTakeCore<BE>,
     {
         let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
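The slimmed-down tmp_bytes composes scratch requirements the way the lifetimes dictate: `+` for the res_dft buffer that lives across the whole call, and max for the internal scratch and the final normalization scratch, which never coexist. The old body (and the internal variant kept below) also uses `|` where a max would do; for non-negative sizes a | b >= max(a, b), so bitwise-or is a cheap, slightly conservative upper bound. A self-contained spot check of that inequality (illustrative only):

fn main() {
    // For usize sizes, a | b is always >= max(a, b): OR can only set bits,
    // so it never falls below either operand. It may overshoot, e.g. 6 | 5 = 7,
    // which is why it is a conservative stand-in for max in scratch sizing.
    for a in 0usize..256 {
        for b in 0usize..256 {
            assert!(a | b >= a.max(b));
        }
    }
    println!("a | b >= max(a, b) holds on the sampled range");
}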
@@ -114,81 +107,9 @@ where
             assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, res, rhs));
         }
 
-        let cols: usize = (rhs.rank() + 1).into();
-        let dsize: usize = rhs.dsize().into();
-        let a_size: usize = (res.size() * basek_in).div_ceil(basek_ggsw);
-
-        let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, rhs.size()); // Todo optimise
-        let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize));
-        a_dft.data_mut().fill(0);
-
-        if basek_in == basek_ggsw {
-            for di in 0..dsize {
-                // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
-                a_dft.set_size((res.size() + di) / dsize);
-
-                // Small optimization for dsize > 2
-                // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
-                // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
-                // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products.
-                // It is possible to further ignore the last dsize-1 limbs, but this introduce
-                // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
-                // noise is kept with respect to the ideal functionality.
-                res_dft.set_size(rhs.size() - ((dsize - di) as isize - 2).max(0) as usize);
-
-                for j in 0..cols {
-                    self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &res.data, j);
-                }
-
-                if di == 0 {
-                    self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2);
-                } else {
-                    self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2);
-                }
-            }
-        } else {
-            let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n(), cols, a_size);
-
-            for j in 0..cols {
-                self.vec_znx_normalize(
-                    basek_ggsw,
-                    &mut a_conv,
-                    j,
-                    basek_in,
-                    &res.data,
-                    j,
-                    scratch_3,
-                );
-            }
-
-            for di in 0..dsize {
-                // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
-                a_dft.set_size((res.size() + di) / dsize);
-
-                // Small optimization for dsize > 2
-                // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
-                // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
-                // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products.
-                // It is possible to further ignore the last dsize-1 limbs, but this introduce
-                // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
-                // noise is kept with respect to the ideal functionality.
-                res_dft.set_size(rhs.size() - ((dsize - di) as isize - 2).max(0) as usize);
-
-                for j in 0..cols {
-                    self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &res.data, j);
-                }
-
-                if di == 0 {
-                    self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2);
-                } else {
-                    self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2);
-                }
-            }
-        }
-
-        let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);
-
-        for j in 0..cols {
+        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), a.size()); // Todo optimise
+        let res_big = self.glwe_external_product_internal(res_dft, res, a, scratch_1);
+        for j in 0..(res.rank() + 1).into() {
             self.vec_znx_big_normalize(
                 basek_in,
                 &mut res.data,
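The removed loop (re-homed below in the internal trait) carries the comment `(lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)`: the number of limbs that digit di of the base-2^B decomposition touches. With a the limb count and di in 0..dsize, the identity holds exactly and can be spot-checked numerically; a quick sketch:

fn main() {
    // Checks (a + di) / dsize == (a - (dsize - di - 1)).div_ceil(dsize)
    // for limb counts a and digit indices di in 0..dsize, matching the
    // comment carried by a_dft.set_size in the decomposition loop.
    for dsize in 1usize..=8 {
        for a in dsize..64 {
            for di in 0..dsize {
                let lhs = (a + di) / dsize;
                let rhs = (a - (dsize - di - 1)).div_ceil(dsize);
                assert_eq!(lhs, rhs, "a={a} dsize={dsize} di={di}");
            }
        }
    }
    println!("limb-count identity verified");
}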
@@ -213,7 +134,6 @@ where
 
         let rhs: &GGSWPrepared<&[u8], BE> = &rhs.to_ref();
 
-        let basek_in: usize = lhs.base2k().into();
         let basek_ggsw: usize = rhs.base2k().into();
         let basek_out: usize = res.base2k().into();
 
@@ -228,96 +148,45 @@ where
             assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, lhs, rhs));
         }
 
-        let cols: usize = (rhs.rank() + 1).into();
-        let dsize: usize = rhs.dsize().into();
-
-        let a_size: usize = (lhs.size() * basek_in).div_ceil(basek_ggsw);
-
-        let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, rhs.size()); // Todo optimise
-        let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize));
-        a_dft.data_mut().fill(0);
-
-        if basek_in == basek_ggsw {
-            for di in 0..dsize {
-                // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
-                a_dft.set_size((lhs.size() + di) / dsize);
-
-                // Small optimization for dsize > 2
-                // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
-                // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
-                // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products.
-                // It is possible to further ignore the last dsize-1 limbs, but this introduce
-                // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
-                // noise is kept with respect to the ideal functionality.
-                res_dft.set_size(rhs.size() - ((dsize - di) as isize - 2).max(0) as usize);
-
-                for j in 0..cols {
-                    self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &lhs.data, j);
-                }
-
-                if di == 0 {
-                    self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2);
-                } else {
-                    self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2);
-                }
-            }
-        } else {
-            let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n(), cols, a_size);
-
-            for j in 0..cols {
-                self.vec_znx_normalize(
-                    basek_ggsw,
-                    &mut a_conv,
-                    j,
-                    basek_in,
-                    &lhs.data,
-                    j,
-                    scratch_3,
-                );
-            }
-
-            for di in 0..dsize {
-                // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
-                a_dft.set_size((a_size + di) / dsize);
-
-                // Small optimization for dsize > 2
-                // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
-                // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
-                // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products.
-                // It is possible to further ignore the last dsize-1 limbs, but this introduce
-                // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
-                // noise is kept with respect to the ideal functionality.
-                res_dft.set_size(rhs.size() - ((dsize - di) as isize - 2).max(0) as usize);
-
-                for j in 0..cols {
-                    self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &a_conv, j);
-                }
-
-                if di == 0 {
-                    self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_3);
-                } else {
-                    self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_3);
-                }
-            }
-        }
-
-        let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);
-
-        (0..cols).for_each(|i| {
+        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), rhs.size()); // Todo optimise
+        let res_big = self.glwe_external_product_internal(res_dft, lhs, rhs, scratch_1);
+        for j in 0..(res.rank() + 1).into() {
             self.vec_znx_big_normalize(
                 basek_out,
-                res.data_mut(),
-                i,
+                &mut res.data,
+                j,
                 basek_ggsw,
                 &res_big,
-                i,
+                j,
                 scratch_1,
             );
-        });
+        }
     }
 }
 
-impl<BE: Backend> GLWEExternalProduct<BE> for Module<BE> where
+pub trait GLWEExternalProductInternal<BE: Backend> {
+    fn glwe_external_product_internal_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
+    where
+        R: GLWEInfos,
+        A: GLWEInfos,
+        B: GGSWInfos;
+    fn glwe_external_product_internal<DR, A, G>(
+        &self,
+        res_dft: VecZnxDft<DR, BE>,
+        a: &A,
+        ggsw: &G,
+        scratch: &mut Scratch<BE>,
+    ) -> VecZnxBig<DR, BE>
+    where
+        DR: DataMut,
+        A: GLWEToRef,
+        G: GGSWPreparedToRef<BE>,
+        Scratch<BE>: ScratchTakeCore<BE>;
+}
+
+impl<BE: Backend> GLWEExternalProductInternal<BE> for Module<BE>
+where
     Self: ModuleN
         + VecZnxDftBytesOf
         + VmpApplyDftToDftTmpBytes
@@ -330,6 +199,121 @@ impl<BE: Backend> GLWEExternalProduct<BE> for Module<BE> where
         + VecZnxNormalize<BE>
         + VecZnxDftBytesOf
         + VmpApplyDftToDftTmpBytes
-        + VecZnxNormalizeTmpBytes
+        + VecZnxNormalizeTmpBytes,
 {
+    fn glwe_external_product_internal_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
+    where
+        R: GLWEInfos,
+        A: GLWEInfos,
+        B: GGSWInfos,
+    {
+        let in_size: usize = a_infos
+            .k()
+            .div_ceil(b_infos.base2k())
+            .div_ceil(b_infos.dsize().into()) as usize;
+        let out_size: usize = res_infos.size();
+        let ggsw_size: usize = b_infos.size();
+        let a_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), in_size);
+        let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes(
+            out_size,
+            in_size,
+            in_size,                     // rows
+            (b_infos.rank() + 1).into(), // cols in
+            (b_infos.rank() + 1).into(), // cols out
+            ggsw_size,
+        );
+        let normalize_big: usize = self.vec_znx_normalize_tmp_bytes();
+
+        if a_infos.base2k() == b_infos.base2k() {
+            a_dft + (vmp | normalize_big)
+        } else {
+            let normalize_conv: usize = VecZnx::bytes_of(self.n(), (b_infos.rank() + 1).into(), in_size);
+            (a_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big
+        }
+    }
+
+    fn glwe_external_product_internal<DR, A, G>(
+        &self,
+        mut res_dft: VecZnxDft<DR, BE>,
+        a: &A,
+        ggsw: &G,
+        scratch: &mut Scratch<BE>,
+    ) -> VecZnxBig<DR, BE>
+    where
+        DR: DataMut,
+        A: GLWEToRef,
+        G: GGSWPreparedToRef<BE>,
+        Scratch<BE>: ScratchTakeCore<BE>,
+    {
+        let a: &GLWE<&[u8]> = &a.to_ref();
+        let ggsw: &GGSWPrepared<&[u8], BE> = &ggsw.to_ref();
+
+        let basek_in: usize = a.base2k().into();
+        let basek_ggsw: usize = ggsw.base2k().into();
+
+        let cols: usize = (ggsw.rank() + 1).into();
+        let dsize: usize = ggsw.dsize().into();
+        let a_size: usize = (a.size() * basek_in).div_ceil(basek_ggsw);
+
+        let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize));
+        a_dft.data_mut().fill(0);
+
+        if basek_in == basek_ggsw {
+            for di in 0..dsize {
+                // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
+                a_dft.set_size((a.size() + di) / dsize);
+
+                // Small optimization for dsize > 2
+                // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
+                // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
+                // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products.
+                // It is possible to further ignore the last dsize-1 limbs, but this introduce
+                // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
+                // noise is kept with respect to the ideal functionality.
+                res_dft.set_size(ggsw.size() - ((dsize - di) as isize - 2).max(0) as usize);
+
+                for j in 0..cols {
+                    self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &a.data, j);
+                }
+
+                if di == 0 {
+                    self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &ggsw.data, scratch_1);
+                } else {
+                    self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &ggsw.data, di, scratch_1);
+                }
+            }
+        } else {
+            let (mut a_conv, scratch_3) = scratch_1.take_vec_znx(self.n(), cols, a_size);
+
+            for j in 0..cols {
+                self.vec_znx_normalize(basek_ggsw, &mut a_conv, j, basek_in, &a.data, j, scratch_3);
+            }
+
+            for di in 0..dsize {
+                // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
+                a_dft.set_size((a.size() + di) / dsize);
+
+                // Small optimization for dsize > 2
+                // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
+                // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
+                // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products.
+                // It is possible to further ignore the last dsize-1 limbs, but this introduce
+                // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
+                // noise is kept with respect to the ideal functionality.
+                res_dft.set_size(ggsw.size() - ((dsize - di) as isize - 2).max(0) as usize);
+
+                for j in 0..cols {
+                    self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &a.data, j);
+                }
+
+                if di == 0 {
+                    self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &ggsw.data, scratch_1);
+                } else {
+                    self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &ggsw.data, di, scratch_1);
+                }
+            }
+        }
+
+        self.vec_znx_idft_apply_consume(res_dft)
+    }
 }
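The res_dft.set_size(ggsw.size() - ((dsize - di) as isize - 2).max(0) as usize) line implements the comment above it: while accumulating digit di, the trailing limbs of the sum lie entirely under noise scaled by 2^{di * B}, so up to dsize - 2 of them can be skipped for the earliest digits without changing the noise relative to the ideal functionality. A small table of the resulting sizes, using a hypothetical ggsw_size:

fn main() {
    // Reproduces res_dft.set_size for a hypothetical ggsw_size = 8:
    // early digits (small di) drop up to dsize - 2 trailing limbs,
    // and from di == dsize - 2 onward the full size is kept.
    let ggsw_size = 8usize;
    for dsize in 1usize..=4 {
        for di in 0..dsize {
            let dropped = ((dsize - di) as isize - 2).max(0) as usize;
            println!("dsize={dsize} di={di} -> size {}", ggsw_size - dropped);
        }
    }
}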
@@ -382,8 +382,8 @@ impl<D: DataMut, T: UnsignedInteger> FheUint<D, T> {
 
         let (mut out_bits, scratch_1) = scratch.take_glwe_slice(T::BITS as usize, self);
 
-        for i in 0..T::BITS as usize {
-            module.cmux(&mut out_bits[i], &one, &zero, &other.get_bit(i), scratch_1);
+        for (i, bits) in out_bits.iter_mut().enumerate().take(T::BITS as usize) {
+            module.cmux(bits, &one, &zero, &other.get_bit(i), scratch_1);
         }
 
         self.pack(module, out_bits, keys, scratch_1);
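The FheUint hunk is a mechanical lint fix: indexing out_bits[i] inside a 0..T::BITS range loop becomes iter_mut().enumerate(), which borrows each slot directly and sidesteps the per-iteration bounds check (clippy's needless_range_loop pattern). The shape of the rewrite, on a plain vector:

fn main() {
    let mut out = vec![0u32; 8];
    // Before: for i in 0..out.len() { out[i] = i as u32; }
    // After: iterate mutably and keep the index via enumerate().
    for (i, slot) in out.iter_mut().enumerate() {
        *slot = i as u32;
    }
    assert_eq!(out[3], 3);
}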
@@ -2,9 +2,13 @@ use core::panic;
 
 use itertools::Itertools;
 use poulpy_core::{
-    GLWEAdd, GLWECopy, GLWEExternalProduct, GLWENormalize, GLWESub, ScratchTakeCore, layouts::{GGSWInfos, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef}
+    GLWECopy, GLWEExternalProductInternal, GLWENormalize, GLWESub, ScratchTakeCore,
+    layouts::{GGSWInfos, GGSWPrepared, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef},
+};
+use poulpy_hal::{
+    api::{ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftBytesOf},
+    layouts::{Backend, DataMut, Module, Scratch, VecZnxBig, ZnxZero},
 };
-use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch, ZnxZero};
 
 use crate::tfhe::bdd_arithmetic::{GetGGSWBit, UnsignedInteger};
 
@@ -70,8 +74,18 @@ where
 {
     #[cfg(debug_assertions)]
    {
-        assert!(inputs.bit_size() >= circuit.input_size(), "inputs.bit_size(): {} < circuit.input_size():{}", inputs.bit_size(), circuit.input_size());
-        assert!(out.len() >= circuit.output_size(), "out.len(): {} < circuit.output_size(): {}", out.len(), circuit.output_size());
+        assert!(
+            inputs.bit_size() >= circuit.input_size(),
+            "inputs.bit_size(): {} < circuit.input_size():{}",
+            inputs.bit_size(),
+            circuit.input_size()
+        );
+        assert!(
+            out.len() >= circuit.output_size(),
+            "out.len(): {} < circuit.output_size(): {}",
+            out.len(),
+            circuit.output_size()
+        );
     }
 
     for (i, out_i) in out.iter_mut().enumerate().take(circuit.output_size()) {
@@ -164,7 +178,14 @@ pub enum Node {
 
 pub trait Cmux<BE: Backend>
 where
-    Self: GLWEExternalProduct<BE> + GLWESub + GLWEAdd + GLWENormalize<BE>,
+    Self: Sized
+        + GLWEExternalProductInternal<BE>
+        + GLWESub
+        + VecZnxBigAddSmallInplace<BE>
+        + GLWENormalize<BE>
+        + VecZnxDftBytesOf
+        + VecZnxBigNormalize<BE>
+        + VecZnxBigNormalizeTmpBytes,
 {
     fn cmux_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
     where
@@ -172,7 +193,11 @@ where
         A: GLWEInfos,
         B: GGSWInfos,
     {
-        self.glwe_external_product_tmp_bytes(res_infos, a_infos, b_infos)
+        let res_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), b_infos.size());
+        res_dft
+            + self
+                .glwe_external_product_internal_tmp_bytes(res_infos, a_infos, b_infos)
+                .max(self.vec_znx_big_normalize_tmp_bytes())
     }
 
     fn cmux<R, T, F, S>(&self, res: &mut R, t: &T, f: &F, s: &S, scratch: &mut Scratch<BE>)
@@ -180,34 +205,73 @@ where
         R: GLWEToMut,
         T: GLWEToRef,
         F: GLWEToRef,
-        S: GGSWPreparedToRef<BE>,
+        S: GGSWPreparedToRef<BE> + GGSWInfos,
         Scratch<BE>: ScratchTakeCore<BE>,
     {
-        self.glwe_sub(res, t, f);
-        self.glwe_normalize_inplace(res, scratch);
-        self.glwe_external_product_inplace(res, s, scratch);
-        self.glwe_add_inplace(res, f);
-        self.glwe_normalize_inplace(res, scratch);
+        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
+        let s: &GGSWPrepared<&[u8], BE> = &s.to_ref();
+        let f: GLWE<&[u8]> = f.to_ref();
+
+        let res_base2k: usize = res.base2k().into();
+        let ggsw_base2k: usize = s.base2k().into();
+
+        self.glwe_sub(res, t, &f);
+        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), s.size()); // Todo optimise
+        let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_external_product_internal(res_dft, res, s, scratch_1);
+        for j in 0..(res.rank() + 1).into() {
+            self.vec_znx_big_add_small_inplace(&mut res_big, j, f.data(), j);
+            self.vec_znx_big_normalize(
+                res_base2k,
+                res.data_mut(),
+                j,
+                ggsw_base2k,
+                &res_big,
+                j,
+                scratch_1,
+            );
+        }
     }
 
     fn cmux_inplace<R, A, S>(&self, res: &mut R, a: &A, s: &S, scratch: &mut Scratch<BE>)
     where
         R: GLWEToMut,
         A: GLWEToRef,
-        S: GGSWPreparedToRef<BE>,
+        S: GGSWPreparedToRef<BE> + GGSWInfos,
         Scratch<BE>: ScratchTakeCore<BE>,
     {
-        self.glwe_sub_inplace(res, a);
-        self.glwe_normalize_inplace(res, scratch);
-        self.glwe_external_product_inplace(res, s, scratch);
-        self.glwe_add_inplace(res, a);
-        self.glwe_normalize_inplace(res, scratch);
+        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
+        let s: &GGSWPrepared<&[u8], BE> = &s.to_ref();
+        let a: GLWE<&[u8]> = a.to_ref();
+
+        let res_base2k: usize = res.base2k().into();
+        let ggsw_base2k: usize = s.base2k().into();
+
+        self.glwe_sub_inplace(res, &a);
+        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), s.size()); // Todo optimise
+        let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_external_product_internal(res_dft, res, s, scratch_1);
+        for j in 0..(res.rank() + 1).into() {
+            self.vec_znx_big_add_small_inplace(&mut res_big, j, a.data(), j);
+            self.vec_znx_big_normalize(
+                res_base2k,
+                res.data_mut(),
+                j,
+                ggsw_base2k,
+                &res_big,
+                j,
+                scratch_1,
+            );
+        }
     }
 }
 
 impl<BE: Backend> Cmux<BE> for Module<BE>
 where
-    Self: GLWEExternalProduct<BE> + GLWESub + GLWEAdd + GLWENormalize<BE>,
+    Self: Sized
+        + GLWEExternalProductInternal<BE>
+        + GLWESub
+        + VecZnxBigAddSmallInplace<BE>
+        + GLWENormalize<BE>
+        + VecZnxDftBytesOf
+        + VecZnxBigNormalize<BE>
+        + VecZnxBigNormalizeTmpBytes,
     Scratch<BE>: ScratchTakeCore<BE>,
 {
 }
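Net effect on cmux: the old implementation made five full passes over the ciphertext (sub, normalize, external product with its internal normalization, add, normalize), while the new one makes three (sub, the internal external product, and one fused add-plus-normalize per column), which is where the speedup in the commit title comes from.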