Improve cmux speed

Pro7ech
2025-11-07 17:56:33 +01:00
parent 836df871fe
commit 75842cd80a
4 changed files with 270 additions and 222 deletions
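In short: cmux previously computed t - f, normalized, ran a full external product (which itself normalizes back into the coefficient domain), added f, and normalized again. The new code stops the external product in the unnormalized "big" domain, folds the + f into that domain, and normalizes once per column. A condensed before/after sketch of the cmux body, using only calls that appear in this diff (generics, reference conversions and scratch setup are elided):

    // Before: two extra normalization passes around the external product.
    self.glwe_sub(res, t, f);
    self.glwe_normalize_inplace(res, scratch);
    self.glwe_external_product_inplace(res, s, scratch);
    self.glwe_add_inplace(res, f);
    self.glwe_normalize_inplace(res, scratch);

    // After: the external product stays in the big domain, f is added there,
    // and a single normalization per column writes the result back.
    self.glwe_sub(res, t, &f);
    let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), s.size());
    let mut res_big = self.glwe_external_product_internal(res_dft, res, s, scratch_1);
    for j in 0..(res.rank() + 1).into() {
        self.vec_znx_big_add_small_inplace(&mut res_big, j, f.data(), j);
        self.vec_znx_big_normalize(res_base2k, res.data_mut(), j, ggsw_base2k, &res_big, j, scratch_1);
    }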

View File

@@ -1,5 +1,5 @@
 use poulpy_hal::{
-    api::ScratchAvailable,
+    api::{ModuleN, ScratchAvailable},
     layouts::{Backend, DataMut, Module, Scratch, ZnxZero},
 };
@@ -13,7 +13,7 @@ use crate::{
 pub trait GGSWExternalProduct<BE: Backend>
 where
-    Self: GLWEExternalProduct<BE>,
+    Self: GLWEExternalProduct<BE> + ModuleN,
 {
     fn ggsw_external_product_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
     where

View File

@@ -1,9 +1,10 @@
 use poulpy_hal::{
     api::{
-        ModuleN, ScratchTakeBasic, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize,
-        VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
+        ModuleN, ScratchTakeBasic, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftApply, VecZnxDftBytesOf,
+        VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
+        VmpApplyDftToDftTmpBytes,
     },
-    layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig},
+    layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft},
 };

 use crate::{
@@ -30,7 +31,7 @@ impl<DataSelf: DataMut> GLWE<DataSelf> {
     pub fn external_product<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>)
     where
         A: GLWEToRef,
-        B: GGSWPreparedToRef<BE>,
+        B: GGSWPreparedToRef<BE> + GGSWInfos,
         M: GLWEExternalProduct<BE>,
         Scratch<BE>: ScratchTakeCore<BE>,
     {
@@ -39,7 +40,7 @@ impl<DataSelf: DataMut> GLWE<DataSelf> {
     pub fn external_product_inplace<A, M, BE: Backend>(&mut self, module: &M, a: &A, scratch: &mut Scratch<BE>)
     where
-        A: GGSWPreparedToRef<BE>,
+        A: GGSWPreparedToRef<BE> + GGSWInfos,
         M: GLWEExternalProduct<BE>,
         Scratch<BE>: ScratchTakeCore<BE>,
     {
@@ -47,19 +48,30 @@ impl<DataSelf: DataMut> GLWE<DataSelf> {
     }
 }

-pub trait GLWEExternalProduct<BE: Backend>
-where
-    Self: Sized
-        + ModuleN
-        + VecZnxDftBytesOf
-        + VmpApplyDftToDftTmpBytes
-        + VecZnxNormalizeTmpBytes
-        + VecZnxDftApply<BE>
-        + VmpApplyDftToDft<BE>
-        + VmpApplyDftToDftAdd<BE>
-        + VecZnxIdftApplyConsume<BE>
-        + VecZnxBigNormalize<BE>
-        + VecZnxNormalize<BE>,
+pub trait GLWEExternalProduct<BE: Backend> {
+    fn glwe_external_product_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
+    where
+        R: GLWEInfos,
+        A: GLWEInfos,
+        B: GGSWInfos;
+
+    fn glwe_external_product_inplace<R, D>(&self, res: &mut R, a: &D, scratch: &mut Scratch<BE>)
+    where
+        R: GLWEToMut,
+        D: GGSWPreparedToRef<BE> + GGSWInfos,
+        Scratch<BE>: ScratchTakeCore<BE>;
+
+    fn glwe_external_product<R, A, D>(&self, res: &mut R, lhs: &A, rhs: &D, scratch: &mut Scratch<BE>)
+    where
+        R: GLWEToMut,
+        A: GLWEToRef,
+        D: GGSWPreparedToRef<BE> + GGSWInfos,
+        Scratch<BE>: ScratchTakeCore<BE>;
+}
+
+impl<BE: Backend> GLWEExternalProduct<BE> for Module<BE>
+where
+    Self: GLWEExternalProductInternal<BE> + VecZnxDftBytesOf + VecZnxBigNormalize<BE> + VecZnxBigNormalizeTmpBytes,
 {
     fn glwe_external_product_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
     where
@@ -67,36 +79,17 @@ where
         A: GLWEInfos,
         B: GGSWInfos,
     {
-        let in_size: usize = a_infos
-            .k()
-            .div_ceil(b_infos.base2k())
-            .div_ceil(b_infos.dsize().into()) as usize;
-        let out_size: usize = res_infos.size();
-        let ggsw_size: usize = b_infos.size();
-
-        let res_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), ggsw_size);
-        let a_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), in_size);
-        let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes(
-            out_size,
-            in_size,
-            in_size,                     // rows
-            (b_infos.rank() + 1).into(), // cols in
-            (b_infos.rank() + 1).into(), // cols out
-            ggsw_size,
-        );
-        let normalize_big: usize = self.vec_znx_normalize_tmp_bytes();
-
-        if a_infos.base2k() == b_infos.base2k() {
-            res_dft + a_dft + (vmp | normalize_big)
-        } else {
-            let normalize_conv: usize = VecZnx::bytes_of(self.n(), (b_infos.rank() + 1).into(), in_size);
-            res_dft + ((a_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big)
-        }
+        let res_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), b_infos.size());
+        res_dft
+            + self
+                .glwe_external_product_internal_tmp_bytes(res_infos, a_infos, b_infos)
+                .max(self.vec_znx_big_normalize_tmp_bytes())
     }

     fn glwe_external_product_inplace<R, D>(&self, res: &mut R, a: &D, scratch: &mut Scratch<BE>)
     where
         R: GLWEToMut,
-        D: GGSWPreparedToRef<BE>,
+        D: GGSWPreparedToRef<BE> + GGSWInfos,
         Scratch<BE>: ScratchTakeCore<BE>,
     {
         let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
@@ -114,81 +107,9 @@ where
             assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, res, rhs));
         }

-        let cols: usize = (rhs.rank() + 1).into();
-        let dsize: usize = rhs.dsize().into();
-        let a_size: usize = (res.size() * basek_in).div_ceil(basek_ggsw);
-
-        let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, rhs.size()); // Todo optimise
-        let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize));
-        a_dft.data_mut().fill(0);
-
-        if basek_in == basek_ggsw {
-            for di in 0..dsize {
-                // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
-                a_dft.set_size((res.size() + di) / dsize);
-
-                // Small optimization for dsize > 2
-                // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
-                // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
-                // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products.
-                // It is possible to further ignore the last dsize-1 limbs, but this introduce
-                // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
-                // noise is kept with respect to the ideal functionality.
-                res_dft.set_size(rhs.size() - ((dsize - di) as isize - 2).max(0) as usize);
-
-                for j in 0..cols {
-                    self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &res.data, j);
-                }
-
-                if di == 0 {
-                    self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2);
-                } else {
-                    self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2);
-                }
-            }
-        } else {
-            let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n(), cols, a_size);
-
-            for j in 0..cols {
-                self.vec_znx_normalize(
-                    basek_ggsw,
-                    &mut a_conv,
-                    j,
-                    basek_in,
-                    &res.data,
-                    j,
-                    scratch_3,
-                );
-            }
-
-            for di in 0..dsize {
-                // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
-                a_dft.set_size((res.size() + di) / dsize);
-
-                // Small optimization for dsize > 2
-                // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
-                // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
-                // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products.
-                // It is possible to further ignore the last dsize-1 limbs, but this introduce
-                // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
-                // noise is kept with respect to the ideal functionality.
-                res_dft.set_size(rhs.size() - ((dsize - di) as isize - 2).max(0) as usize);
-
-                for j in 0..cols {
-                    self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &res.data, j);
-                }
-
-                if di == 0 {
-                    self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2);
-                } else {
-                    self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2);
-                }
-            }
-        }
-
-        let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);
-
-        for j in 0..cols {
+        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), a.size()); // Todo optimise
+        let res_big = self.glwe_external_product_internal(res_dft, res, a, scratch_1);
+
+        for j in 0..(res.rank() + 1).into() {
             self.vec_znx_big_normalize(
                 basek_in,
                 &mut res.data,
@@ -213,7 +134,6 @@ where
         let rhs: &GGSWPrepared<&[u8], BE> = &rhs.to_ref();
-        let basek_in: usize = lhs.base2k().into();
         let basek_ggsw: usize = rhs.base2k().into();
         let basek_out: usize = res.base2k().into();
@@ -228,96 +148,45 @@ where
             assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, lhs, rhs));
         }

-        let cols: usize = (rhs.rank() + 1).into();
-        let dsize: usize = rhs.dsize().into();
-        let a_size: usize = (lhs.size() * basek_in).div_ceil(basek_ggsw);
-
-        let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, rhs.size()); // Todo optimise
-        let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize));
-        a_dft.data_mut().fill(0);
-
-        if basek_in == basek_ggsw {
-            for di in 0..dsize {
-                // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
-                a_dft.set_size((lhs.size() + di) / dsize);
-
-                // Small optimization for dsize > 2
-                // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
-                // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
-                // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products.
-                // It is possible to further ignore the last dsize-1 limbs, but this introduce
-                // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
-                // noise is kept with respect to the ideal functionality.
-                res_dft.set_size(rhs.size() - ((dsize - di) as isize - 2).max(0) as usize);
-
-                for j in 0..cols {
-                    self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &lhs.data, j);
-                }
-
-                if di == 0 {
-                    self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2);
-                } else {
-                    self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2);
-                }
-            }
-        } else {
-            let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n(), cols, a_size);
-
-            for j in 0..cols {
-                self.vec_znx_normalize(
-                    basek_ggsw,
-                    &mut a_conv,
-                    j,
-                    basek_in,
-                    &lhs.data,
-                    j,
-                    scratch_3,
-                );
-            }
-
-            for di in 0..dsize {
-                // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
-                a_dft.set_size((a_size + di) / dsize);
-
-                // Small optimization for dsize > 2
-                // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
-                // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
-                // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products.
-                // It is possible to further ignore the last dsize-1 limbs, but this introduce
-                // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
-                // noise is kept with respect to the ideal functionality.
-                res_dft.set_size(rhs.size() - ((dsize - di) as isize - 2).max(0) as usize);
-
-                for j in 0..cols {
-                    self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &a_conv, j);
-                }
-
-                if di == 0 {
-                    self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_3);
-                } else {
-                    self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_3);
-                }
-            }
-        }
-
-        let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);
-
-        (0..cols).for_each(|i| {
+        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), rhs.size()); // Todo optimise
+        let res_big = self.glwe_external_product_internal(res_dft, lhs, rhs, scratch_1);
+
+        for j in 0..(res.rank() + 1).into() {
             self.vec_znx_big_normalize(
                 basek_out,
-                res.data_mut(),
-                i,
+                &mut res.data,
+                j,
                 basek_ggsw,
                 &res_big,
-                i,
+                j,
                 scratch_1,
             );
-        });
+        }
     }
 }

-impl<BE: Backend> GLWEExternalProduct<BE> for Module<BE> where
+pub trait GLWEExternalProductInternal<BE: Backend> {
+    fn glwe_external_product_internal_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
+    where
+        R: GLWEInfos,
+        A: GLWEInfos,
+        B: GGSWInfos;
+
+    fn glwe_external_product_internal<DR, A, G>(
+        &self,
+        res_dft: VecZnxDft<DR, BE>,
+        a: &A,
+        ggsw: &G,
+        scratch: &mut Scratch<BE>,
+    ) -> VecZnxBig<DR, BE>
+    where
+        DR: DataMut,
+        A: GLWEToRef,
+        G: GGSWPreparedToRef<BE>,
+        Scratch<BE>: ScratchTakeCore<BE>;
+}
+
+impl<BE: Backend> GLWEExternalProductInternal<BE> for Module<BE>
+where
     Self: ModuleN
         + VecZnxDftBytesOf
         + VmpApplyDftToDftTmpBytes
impl<BE: Backend> GLWEExternalProduct<BE> for Module<BE> where pub trait GLWEExternalProductInternal<BE: Backend> {
fn glwe_external_product_internal_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
where
R: GLWEInfos,
A: GLWEInfos,
B: GGSWInfos;
fn glwe_external_product_internal<DR, A, G>(
&self,
res_dft: VecZnxDft<DR, BE>,
a: &A,
ggsw: &G,
scratch: &mut Scratch<BE>,
) -> VecZnxBig<DR, BE>
where
DR: DataMut,
A: GLWEToRef,
G: GGSWPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>;
}
impl<BE: Backend> GLWEExternalProductInternal<BE> for Module<BE>
where
Self: ModuleN Self: ModuleN
+ VecZnxDftBytesOf + VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes + VmpApplyDftToDftTmpBytes
@@ -330,6 +199,121 @@ impl<BE: Backend> GLWEExternalProduct<BE> for Module<BE> where
         + VecZnxNormalize<BE>
         + VecZnxDftBytesOf
         + VmpApplyDftToDftTmpBytes
-        + VecZnxNormalizeTmpBytes
+        + VecZnxNormalizeTmpBytes,
 {
+    fn glwe_external_product_internal_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
+    where
+        R: GLWEInfos,
+        A: GLWEInfos,
+        B: GGSWInfos,
+    {
+        let in_size: usize = a_infos
+            .k()
+            .div_ceil(b_infos.base2k())
+            .div_ceil(b_infos.dsize().into()) as usize;
+        let out_size: usize = res_infos.size();
+        let ggsw_size: usize = b_infos.size();
+
+        let a_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), in_size);
+        let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes(
+            out_size,
+            in_size,
+            in_size,                     // rows
+            (b_infos.rank() + 1).into(), // cols in
+            (b_infos.rank() + 1).into(), // cols out
+            ggsw_size,
+        );
+        let normalize_big: usize = self.vec_znx_normalize_tmp_bytes();
+
+        if a_infos.base2k() == b_infos.base2k() {
+            a_dft + (vmp | normalize_big)
+        } else {
+            let normalize_conv: usize = VecZnx::bytes_of(self.n(), (b_infos.rank() + 1).into(), in_size);
+            (a_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big
+        }
+    }
+
+    fn glwe_external_product_internal<DR, A, G>(
+        &self,
+        mut res_dft: VecZnxDft<DR, BE>,
+        a: &A,
+        ggsw: &G,
+        scratch: &mut Scratch<BE>,
+    ) -> VecZnxBig<DR, BE>
+    where
+        DR: DataMut,
+        A: GLWEToRef,
+        G: GGSWPreparedToRef<BE>,
+        Scratch<BE>: ScratchTakeCore<BE>,
+    {
+        let a: &GLWE<&[u8]> = &a.to_ref();
+        let ggsw: &GGSWPrepared<&[u8], BE> = &ggsw.to_ref();
+
+        let basek_in: usize = a.base2k().into();
+        let basek_ggsw: usize = ggsw.base2k().into();
+
+        let cols: usize = (ggsw.rank() + 1).into();
+        let dsize: usize = ggsw.dsize().into();
+        let a_size: usize = (a.size() * basek_in).div_ceil(basek_ggsw);
+
+        let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize));
+        a_dft.data_mut().fill(0);
+
+        if basek_in == basek_ggsw {
+            for di in 0..dsize {
+                // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
+                a_dft.set_size((a.size() + di) / dsize);
+
+                // Small optimization for dsize > 2
+                // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
+                // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
+                // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products.
+                // It is possible to further ignore the last dsize-1 limbs, but this introduce
+                // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
+                // noise is kept with respect to the ideal functionality.
+                res_dft.set_size(ggsw.size() - ((dsize - di) as isize - 2).max(0) as usize);
+
+                for j in 0..cols {
+                    self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &a.data, j);
+                }
+
+                if di == 0 {
+                    self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &ggsw.data, scratch_1);
+                } else {
+                    self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &ggsw.data, di, scratch_1);
+                }
+            }
+        } else {
+            let (mut a_conv, scratch_3) = scratch_1.take_vec_znx(self.n(), cols, a_size);
+
+            for j in 0..cols {
+                self.vec_znx_normalize(basek_ggsw, &mut a_conv, j, basek_in, &a.data, j, scratch_3);
+            }
+
+            for di in 0..dsize {
+                // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
+                a_dft.set_size((a.size() + di) / dsize);
+
+                // Small optimization for dsize > 2
+                // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
+                // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
+                // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products.
+                // It is possible to further ignore the last dsize-1 limbs, but this introduce
+                // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
+                // noise is kept with respect to the ideal functionality.
+                res_dft.set_size(ggsw.size() - ((dsize - di) as isize - 2).max(0) as usize);
+
+                for j in 0..cols {
+                    self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &a.data, j);
+                }
+
+                if di == 0 {
+                    self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &ggsw.data, scratch_1);
+                } else {
+                    self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &ggsw.data, di, scratch_1);
+                }
+            }
+        }
+
+        self.vec_znx_idft_apply_consume(res_dft)
+    }
 }
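The res_dft.set_size(...) line in the loop above implements the preceding comment about dropping low limbs of the aggregated VMP products: only the earliest digits can shed limbs, and never more than dsize - 2 of them. A small standalone check of the formula, with dsize = 3 and ggsw_size = 6 chosen purely for illustration:

    fn main() {
        // Illustrative values only; they are not taken from the library.
        let dsize: usize = 3;
        let ggsw_size: usize = 6;
        for di in 0..dsize {
            // Same expression as res_dft.set_size(...) above.
            let kept = ggsw_size - ((dsize - di) as isize - 2).max(0) as usize;
            println!("di = {di}: keep {kept} of {ggsw_size} limbs");
        }
        // Output: di = 0 keeps 5 of 6 limbs (dsize - 2 = 1 limb dropped),
        // while di = 1 and di = 2 keep all 6.
    }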

View File

@@ -382,8 +382,8 @@ impl<D: DataMut, T: UnsignedInteger> FheUint<D, T> {
         let (mut out_bits, scratch_1) = scratch.take_glwe_slice(T::BITS as usize, self);

-        for i in 0..T::BITS as usize {
-            module.cmux(&mut out_bits[i], &one, &zero, &other.get_bit(i), scratch_1);
+        for (i, bits) in out_bits.iter_mut().enumerate().take(T::BITS as usize) {
+            module.cmux(bits, &one, &zero, &other.get_bit(i), scratch_1);
         }

         self.pack(module, out_bits, keys, scratch_1);
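The loop rewrite above is a plain index-loop to iterator conversion: the slice is iterated mutably while the index stays available for other.get_bit(i), instead of indexing out_bits[i] inside the body. A minimal standalone illustration of the same pattern, with toy types (nothing here is from the crate):

    fn main() {
        let mut out_bits = vec![0u8; 4];
        let other = [1u8, 0, 1, 1];

        // Mutable iteration with the index still in scope, replacing
        // `for i in 0..n { out_bits[i] = ...; }`.
        for (i, bit) in out_bits.iter_mut().enumerate().take(other.len()) {
            *bit = other[i];
        }

        assert_eq!(out_bits, [1, 0, 1, 1]);
    }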

View File

@@ -2,9 +2,13 @@ use core::panic;
 use itertools::Itertools;
 use poulpy_core::{
-    GLWEAdd, GLWECopy, GLWEExternalProduct, GLWENormalize, GLWESub, ScratchTakeCore,
-    layouts::{GGSWInfos, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef},
+    GLWECopy, GLWEExternalProductInternal, GLWENormalize, GLWESub, ScratchTakeCore,
+    layouts::{GGSWInfos, GGSWPrepared, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef},
 };
-use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch, ZnxZero};
+use poulpy_hal::{
+    api::{ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftBytesOf},
+    layouts::{Backend, DataMut, Module, Scratch, VecZnxBig, ZnxZero},
+};

 use crate::tfhe::bdd_arithmetic::{GetGGSWBit, UnsignedInteger};
@@ -70,8 +74,18 @@ where
     {
         #[cfg(debug_assertions)]
         {
-            assert!(inputs.bit_size() >= circuit.input_size(), "inputs.bit_size(): {} < circuit.input_size():{}", inputs.bit_size(), circuit.input_size());
-            assert!(out.len() >= circuit.output_size(), "out.len(): {} < circuit.output_size(): {}", out.len(), circuit.output_size());
+            assert!(
+                inputs.bit_size() >= circuit.input_size(),
+                "inputs.bit_size(): {} < circuit.input_size():{}",
+                inputs.bit_size(),
+                circuit.input_size()
+            );
+            assert!(
+                out.len() >= circuit.output_size(),
+                "out.len(): {} < circuit.output_size(): {}",
+                out.len(),
+                circuit.output_size()
+            );
         }

         for (i, out_i) in out.iter_mut().enumerate().take(circuit.output_size()) {
@@ -164,7 +178,14 @@ pub enum Node {
 pub trait Cmux<BE: Backend>
 where
-    Self: GLWEExternalProduct<BE> + GLWESub + GLWEAdd + GLWENormalize<BE>,
+    Self: Sized
+        + GLWEExternalProductInternal<BE>
+        + GLWESub
+        + VecZnxBigAddSmallInplace<BE>
+        + GLWENormalize<BE>
+        + VecZnxDftBytesOf
+        + VecZnxBigNormalize<BE>
+        + VecZnxBigNormalizeTmpBytes,
 {
     fn cmux_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
     where
@@ -172,7 +193,11 @@ where
         A: GLWEInfos,
         B: GGSWInfos,
     {
-        self.glwe_external_product_tmp_bytes(res_infos, a_infos, b_infos)
+        let res_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), b_infos.size());
+        res_dft
+            + self
+                .glwe_external_product_internal_tmp_bytes(res_infos, a_infos, b_infos)
+                .max(self.vec_znx_big_normalize_tmp_bytes())
     }

     fn cmux<R, T, F, S>(&self, res: &mut R, t: &T, f: &F, s: &S, scratch: &mut Scratch<BE>)
@@ -180,34 +205,73 @@ where
         R: GLWEToMut,
         T: GLWEToRef,
         F: GLWEToRef,
-        S: GGSWPreparedToRef<BE>,
+        S: GGSWPreparedToRef<BE> + GGSWInfos,
         Scratch<BE>: ScratchTakeCore<BE>,
     {
-        self.glwe_sub(res, t, f);
-        self.glwe_normalize_inplace(res, scratch);
-        self.glwe_external_product_inplace(res, s, scratch);
-        self.glwe_add_inplace(res, f);
-        self.glwe_normalize_inplace(res, scratch);
+        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
+        let s: &GGSWPrepared<&[u8], BE> = &s.to_ref();
+        let f: GLWE<&[u8]> = f.to_ref();
+
+        let res_base2k: usize = res.base2k().into();
+        let ggsw_base2k: usize = s.base2k().into();
+
+        self.glwe_sub(res, t, &f);
+
+        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), s.size()); // Todo optimise
+        let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_external_product_internal(res_dft, res, s, scratch_1);
+
+        for j in 0..(res.rank() + 1).into() {
+            self.vec_znx_big_add_small_inplace(&mut res_big, j, f.data(), j);
+            self.vec_znx_big_normalize(
+                res_base2k,
+                res.data_mut(),
+                j,
+                ggsw_base2k,
+                &res_big,
+                j,
+                scratch_1,
+            );
+        }
     }

     fn cmux_inplace<R, A, S>(&self, res: &mut R, a: &A, s: &S, scratch: &mut Scratch<BE>)
     where
         R: GLWEToMut,
         A: GLWEToRef,
-        S: GGSWPreparedToRef<BE>,
+        S: GGSWPreparedToRef<BE> + GGSWInfos,
         Scratch<BE>: ScratchTakeCore<BE>,
     {
-        self.glwe_sub_inplace(res, a);
-        self.glwe_normalize_inplace(res, scratch);
-        self.glwe_external_product_inplace(res, s, scratch);
-        self.glwe_add_inplace(res, a);
-        self.glwe_normalize_inplace(res, scratch);
+        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
+        let s: &GGSWPrepared<&[u8], BE> = &s.to_ref();
+        let a: GLWE<&[u8]> = a.to_ref();
+
+        let res_base2k: usize = res.base2k().into();
+        let ggsw_base2k: usize = s.base2k().into();
+
+        self.glwe_sub_inplace(res, &a);
+
+        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), s.size()); // Todo optimise
+        let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_external_product_internal(res_dft, res, s, scratch_1);
+
+        for j in 0..(res.rank() + 1).into() {
+            self.vec_znx_big_add_small_inplace(&mut res_big, j, a.data(), j);
+            self.vec_znx_big_normalize(
+                res_base2k,
+                res.data_mut(),
+                j,
+                ggsw_base2k,
+                &res_big,
+                j,
+                scratch_1,
+            );
+        }
     }
 }

 impl<BE: Backend> Cmux<BE> for Module<BE>
 where
-    Self: GLWEExternalProduct<BE> + GLWESub + GLWEAdd + GLWENormalize<BE>,
+    Self: Sized
+        + GLWEExternalProductInternal<BE>
+        + GLWESub
+        + VecZnxBigAddSmallInplace<BE>
+        + GLWENormalize<BE>
+        + VecZnxDftBytesOf
+        + VecZnxBigNormalize<BE>
+        + VecZnxBigNormalizeTmpBytes,
     Scratch<BE>: ScratchTakeCore<BE>,
 {
 }
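The cmux_tmp_bytes formula above (and the matching glwe_external_product_tmp_bytes) follows the usual peak-scratch accounting for a buffer that must outlive two phases: res_dft stays allocated from the internal external product through the final normalization, while the temporaries of those two phases are never live at the same time, so only the larger of the two is added on top. A tiny self-contained illustration of that accounting (names and sizes are made up):

    // Peak scratch = persistent buffer + max(phase-1 temporaries, phase-2 temporaries),
    // mirroring `res_dft + internal_tmp.max(normalize_tmp)` in cmux_tmp_bytes above.
    fn peak_scratch_bytes(persistent: usize, phase1_tmp: usize, phase2_tmp: usize) -> usize {
        persistent + phase1_tmp.max(phase2_tmp)
    }

    fn main() {
        let res_dft_bytes = 8 * 1024; // hypothetical size of the persistent DFT buffer
        let internal_tmp = 24 * 1024; // hypothetical external-product scratch
        let normalize_tmp = 4 * 1024; // hypothetical normalization scratch
        assert_eq!(peak_scratch_bytes(res_dft_bytes, internal_tmp, normalize_tmp), 32 * 1024);
    }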