Improve cmux speed

This commit is contained in:
Pro7ech
2025-11-07 17:56:33 +01:00
parent 836df871fe
commit 75842cd80a
4 changed files with 270 additions and 222 deletions

View File

@@ -1,5 +1,5 @@
use poulpy_hal::{
api::ScratchAvailable,
api::{ModuleN, ScratchAvailable},
layouts::{Backend, DataMut, Module, Scratch, ZnxZero},
};
@@ -13,7 +13,7 @@ use crate::{
pub trait GGSWExternalProduct<BE: Backend>
where
Self: GLWEExternalProduct<BE>,
Self: GLWEExternalProduct<BE> + ModuleN,
{
fn ggsw_external_product_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
where

View File

@@ -1,9 +1,10 @@
use poulpy_hal::{
api::{
ModuleN, ScratchTakeBasic, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize,
VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
ModuleN, ScratchTakeBasic, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftApply, VecZnxDftBytesOf,
VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
VmpApplyDftToDftTmpBytes,
},
layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig},
layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft},
};
use crate::{
@@ -30,7 +31,7 @@ impl<DataSelf: DataMut> GLWE<DataSelf> {
pub fn external_product<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>)
where
A: GLWEToRef,
B: GGSWPreparedToRef<BE>,
B: GGSWPreparedToRef<BE> + GGSWInfos,
M: GLWEExternalProduct<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
@@ -39,7 +40,7 @@ impl<DataSelf: DataMut> GLWE<DataSelf> {
pub fn external_product_inplace<A, M, BE: Backend>(&mut self, module: &M, a: &A, scratch: &mut Scratch<BE>)
where
A: GGSWPreparedToRef<BE>,
A: GGSWPreparedToRef<BE> + GGSWInfos,
M: GLWEExternalProduct<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
@@ -47,19 +48,30 @@ impl<DataSelf: DataMut> GLWE<DataSelf> {
}
}
pub trait GLWEExternalProduct<BE: Backend>
pub trait GLWEExternalProduct<BE: Backend> {
fn glwe_external_product_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
where
R: GLWEInfos,
A: GLWEInfos,
B: GGSWInfos;
fn glwe_external_product_inplace<R, D>(&self, res: &mut R, a: &D, scratch: &mut Scratch<BE>)
where
R: GLWEToMut,
D: GGSWPreparedToRef<BE> + GGSWInfos,
Scratch<BE>: ScratchTakeCore<BE>;
fn glwe_external_product<R, A, D>(&self, res: &mut R, lhs: &A, rhs: &D, scratch: &mut Scratch<BE>)
where
R: GLWEToMut,
A: GLWEToRef,
D: GGSWPreparedToRef<BE> + GGSWInfos,
Scratch<BE>: ScratchTakeCore<BE>;
}
impl<BE: Backend> GLWEExternalProduct<BE> for Module<BE>
where
Self: Sized
+ ModuleN
+ VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxNormalizeTmpBytes
+ VecZnxDftApply<BE>
+ VmpApplyDftToDft<BE>
+ VmpApplyDftToDftAdd<BE>
+ VecZnxIdftApplyConsume<BE>
+ VecZnxBigNormalize<BE>
+ VecZnxNormalize<BE>,
Self: GLWEExternalProductInternal<BE> + VecZnxDftBytesOf + VecZnxBigNormalize<BE> + VecZnxBigNormalizeTmpBytes,
{
fn glwe_external_product_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
where
@@ -67,36 +79,17 @@ where
A: GLWEInfos,
B: GGSWInfos,
{
let in_size: usize = a_infos
.k()
.div_ceil(b_infos.base2k())
.div_ceil(b_infos.dsize().into()) as usize;
let out_size: usize = res_infos.size();
let ggsw_size: usize = b_infos.size();
let res_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), ggsw_size);
let a_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), in_size);
let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes(
out_size,
in_size,
in_size, // rows
(b_infos.rank() + 1).into(), // cols in
(b_infos.rank() + 1).into(), // cols out
ggsw_size,
);
let normalize_big: usize = self.vec_znx_normalize_tmp_bytes();
if a_infos.base2k() == b_infos.base2k() {
res_dft + a_dft + (vmp | normalize_big)
} else {
let normalize_conv: usize = VecZnx::bytes_of(self.n(), (b_infos.rank() + 1).into(), in_size);
res_dft + ((a_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big)
}
let res_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), b_infos.size());
res_dft
+ self
.glwe_external_product_internal_tmp_bytes(res_infos, a_infos, b_infos)
.max(self.vec_znx_big_normalize_tmp_bytes())
}
fn glwe_external_product_inplace<R, D>(&self, res: &mut R, a: &D, scratch: &mut Scratch<BE>)
where
R: GLWEToMut,
D: GGSWPreparedToRef<BE>,
D: GGSWPreparedToRef<BE> + GGSWInfos,
Scratch<BE>: ScratchTakeCore<BE>,
{
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
@@ -114,81 +107,9 @@ where
assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, res, rhs));
}
let cols: usize = (rhs.rank() + 1).into();
let dsize: usize = rhs.dsize().into();
let a_size: usize = (res.size() * basek_in).div_ceil(basek_ggsw);
let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, rhs.size()); // Todo optimise
let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize));
a_dft.data_mut().fill(0);
if basek_in == basek_ggsw {
for di in 0..dsize {
// (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
a_dft.set_size((res.size() + di) / dsize);
// Small optimization for dsize > 2
// VMP produces some error e, and since we aggregate vmp * 2^{di * B}, we
// also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
// As such we can safely ignore the last dsize-2 limbs of the sum of vmp products.
// It is possible to further ignore the last dsize-1 limbs, but this introduces
// ~0.5 to 1 bit of additional noise, and is thus not done here, to ensure that the
// same noise is kept with respect to the ideal functionality.
res_dft.set_size(rhs.size() - ((dsize - di) as isize - 2).max(0) as usize);
for j in 0..cols {
self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &res.data, j);
}
if di == 0 {
self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2);
} else {
self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2);
}
}
} else {
let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n(), cols, a_size);
for j in 0..cols {
self.vec_znx_normalize(
basek_ggsw,
&mut a_conv,
j,
basek_in,
&res.data,
j,
scratch_3,
);
}
for di in 0..dsize {
// (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
a_dft.set_size((res.size() + di) / dsize);
// Small optimization for dsize > 2
// VMP produces some error e, and since we aggregate vmp * 2^{di * B}, we
// also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
// As such we can safely ignore the last dsize-2 limbs of the sum of vmp products.
// It is possible to further ignore the last dsize-1 limbs, but this introduces
// ~0.5 to 1 bit of additional noise, and is thus not done here, to ensure that the
// same noise is kept with respect to the ideal functionality.
res_dft.set_size(rhs.size() - ((dsize - di) as isize - 2).max(0) as usize);
for j in 0..cols {
self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &res.data, j);
}
if di == 0 {
self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2);
} else {
self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2);
}
}
}
let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);
for j in 0..cols {
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), a.size()); // Todo optimise
let res_big = self.glwe_external_product_internal(res_dft, res, a, scratch_1);
for j in 0..(res.rank() + 1).into() {
self.vec_znx_big_normalize(
basek_in,
&mut res.data,
@@ -213,7 +134,6 @@ where
let rhs: &GGSWPrepared<&[u8], BE> = &rhs.to_ref();
let basek_in: usize = lhs.base2k().into();
let basek_ggsw: usize = rhs.base2k().into();
let basek_out: usize = res.base2k().into();
@@ -228,96 +148,45 @@ where
assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, lhs, rhs));
}
let cols: usize = (rhs.rank() + 1).into();
let dsize: usize = rhs.dsize().into();
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), rhs.size()); // Todo optimise
let res_big = self.glwe_external_product_internal(res_dft, lhs, rhs, scratch_1);
let a_size: usize = (lhs.size() * basek_in).div_ceil(basek_ggsw);
let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, rhs.size()); // Todo optimise
let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize));
a_dft.data_mut().fill(0);
if basek_in == basek_ggsw {
for di in 0..dsize {
// (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
a_dft.set_size((lhs.size() + di) / dsize);
// Small optimization for dsize > 2
// VMP produces some error e, and since we aggregate vmp * 2^{di * B}, we
// also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
// As such we can safely ignore the last dsize-2 limbs of the sum of vmp products.
// It is possible to further ignore the last dsize-1 limbs, but this introduces
// ~0.5 to 1 bit of additional noise, and is thus not done here, to ensure that the
// same noise is kept with respect to the ideal functionality.
res_dft.set_size(rhs.size() - ((dsize - di) as isize - 2).max(0) as usize);
for j in 0..cols {
self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &lhs.data, j);
}
if di == 0 {
self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2);
} else {
self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2);
}
}
} else {
let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n(), cols, a_size);
for j in 0..cols {
self.vec_znx_normalize(
basek_ggsw,
&mut a_conv,
j,
basek_in,
&lhs.data,
j,
scratch_3,
);
}
for di in 0..dsize {
// (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
a_dft.set_size((a_size + di) / dsize);
// Small optimization for dsize > 2
// VMP produces some error e, and since we aggregate vmp * 2^{di * B}, we
// also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
// As such we can safely ignore the last dsize-2 limbs of the sum of vmp products.
// It is possible to further ignore the last dsize-1 limbs, but this introduces
// ~0.5 to 1 bit of additional noise, and is thus not done here, to ensure that the
// same noise is kept with respect to the ideal functionality.
res_dft.set_size(rhs.size() - ((dsize - di) as isize - 2).max(0) as usize);
for j in 0..cols {
self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &a_conv, j);
}
if di == 0 {
self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_3);
} else {
self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_3);
}
}
}
let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);
(0..cols).for_each(|i| {
for j in 0..(res.rank() + 1).into() {
self.vec_znx_big_normalize(
basek_out,
res.data_mut(),
i,
&mut res.data,
j,
basek_ggsw,
&res_big,
i,
j,
scratch_1,
);
});
}
}
}
impl<BE: Backend> GLWEExternalProduct<BE> for Module<BE> where
pub trait GLWEExternalProductInternal<BE: Backend> {
fn glwe_external_product_internal_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
where
R: GLWEInfos,
A: GLWEInfos,
B: GGSWInfos;
fn glwe_external_product_internal<DR, A, G>(
&self,
res_dft: VecZnxDft<DR, BE>,
a: &A,
ggsw: &G,
scratch: &mut Scratch<BE>,
) -> VecZnxBig<DR, BE>
where
DR: DataMut,
A: GLWEToRef,
G: GGSWPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>;
}
impl<BE: Backend> GLWEExternalProductInternal<BE> for Module<BE>
where
Self: ModuleN
+ VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
@@ -330,6 +199,121 @@ impl<BE: Backend> GLWEExternalProduct<BE> for Module<BE> where
+ VecZnxNormalize<BE>
+ VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxNormalizeTmpBytes
+ VecZnxNormalizeTmpBytes,
{
fn glwe_external_product_internal_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
where
R: GLWEInfos,
A: GLWEInfos,
B: GGSWInfos,
{
let in_size: usize = a_infos
.k()
.div_ceil(b_infos.base2k())
.div_ceil(b_infos.dsize().into()) as usize;
let out_size: usize = res_infos.size();
let ggsw_size: usize = b_infos.size();
let a_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), in_size);
let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes(
out_size,
in_size,
in_size, // rows
(b_infos.rank() + 1).into(), // cols in
(b_infos.rank() + 1).into(), // cols out
ggsw_size,
);
let normalize_big: usize = self.vec_znx_normalize_tmp_bytes();
if a_infos.base2k() == b_infos.base2k() {
a_dft + (vmp | normalize_big)
} else {
let normalize_conv: usize = VecZnx::bytes_of(self.n(), (b_infos.rank() + 1).into(), in_size);
(a_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big
}
}
fn glwe_external_product_internal<DR, A, G>(
&self,
mut res_dft: VecZnxDft<DR, BE>,
a: &A,
ggsw: &G,
scratch: &mut Scratch<BE>,
) -> VecZnxBig<DR, BE>
where
DR: DataMut,
A: GLWEToRef,
G: GGSWPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
let a: &GLWE<&[u8]> = &a.to_ref();
let ggsw: &GGSWPrepared<&[u8], BE> = &ggsw.to_ref();
let basek_in: usize = a.base2k().into();
let basek_ggsw: usize = ggsw.base2k().into();
let cols: usize = (ggsw.rank() + 1).into();
let dsize: usize = ggsw.dsize().into();
let a_size: usize = (a.size() * basek_in).div_ceil(basek_ggsw);
let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize));
a_dft.data_mut().fill(0);
if basek_in == basek_ggsw {
for di in 0..dsize {
// (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
a_dft.set_size((a.size() + di) / dsize);
// Small optimization for dsize > 2
// VMP produces some error e, and since we aggregate vmp * 2^{di * B}, we
// also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
// As such we can safely ignore the last dsize-2 limbs of the sum of vmp products.
// It is possible to further ignore the last dsize-1 limbs, but this introduces
// ~0.5 to 1 bit of additional noise, and is thus not done here, to ensure that the
// same noise is kept with respect to the ideal functionality.
res_dft.set_size(ggsw.size() - ((dsize - di) as isize - 2).max(0) as usize);
for j in 0..cols {
self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &a.data, j);
}
if di == 0 {
self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &ggsw.data, scratch_1);
} else {
self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &ggsw.data, di, scratch_1);
}
}
} else {
let (mut a_conv, scratch_3) = scratch_1.take_vec_znx(self.n(), cols, a_size);
for j in 0..cols {
self.vec_znx_normalize(basek_ggsw, &mut a_conv, j, basek_in, &a.data, j, scratch_3);
}
for di in 0..dsize {
// (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
a_dft.set_size((a.size() + di) / dsize);
// Small optimization for dsize > 2
// VMP produces some error e, and since we aggregate vmp * 2^{di * B}, we
// also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
// As such we can safely ignore the last dsize-2 limbs of the sum of vmp products.
// It is possible to further ignore the last dsize-1 limbs, but this introduces
// ~0.5 to 1 bit of additional noise, and is thus not done here, to ensure that the
// same noise is kept with respect to the ideal functionality.
res_dft.set_size(ggsw.size() - ((dsize - di) as isize - 2).max(0) as usize);
for j in 0..cols {
self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &a.data, j);
}
if di == 0 {
self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &ggsw.data, scratch_1);
} else {
self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &ggsw.data, di, scratch_1);
}
}
}
self.vec_znx_idft_apply_consume(res_dft)
}
}
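
Editor's note: the refactor above splits the external product in two stages. `glwe_external_product_internal` performs the base-2^k decomposition, the DFT conversions and the VMP products, and returns the accumulated result in the un-normalized big domain; the caller supplies the `res_dft` buffer from scratch and then normalizes the result itself, or fuses further big-domain operations before normalizing, as cmux does in the last file. Below is a minimal caller sketch, not part of the commit: the helper name `external_product_via_internal` is illustrative, and the imports and trait bounds are assumed to match the public poulpy_core / poulpy_hal paths used elsewhere in this diff.

use poulpy_core::{
    GLWEExternalProductInternal, ScratchTakeCore,
    layouts::{GGSWInfos, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef},
};
use poulpy_hal::{
    api::{ScratchTakeBasic, VecZnxBigNormalize, VecZnxDftBytesOf},
    layouts::{Backend, Module, Scratch},
};

fn external_product_via_internal<BE: Backend, R, A, G>(
    module: &Module<BE>,
    res: &mut R,
    a: &A,
    ggsw: &G,
    scratch: &mut Scratch<BE>,
) where
    Module<BE>: GLWEExternalProductInternal<BE> + VecZnxDftBytesOf + VecZnxBigNormalize<BE>,
    R: GLWEToMut,
    A: GLWEToRef,
    G: GGSWPreparedToRef<BE> + GGSWInfos,
    Scratch<BE>: ScratchTakeCore<BE>,
{
    let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
    let ggsw = &ggsw.to_ref();
    let res_base2k: usize = res.base2k().into();
    let ggsw_base2k: usize = ggsw.base2k().into();

    // One DFT accumulator sized for the GGSW, taken from scratch.
    let (res_dft, scratch_1) = scratch.take_vec_znx_dft(module, (res.rank() + 1).into(), ggsw.size());

    // Decomposition + VMP products; the result stays in the big (un-normalized) domain.
    let res_big = module.glwe_external_product_internal(res_dft, a, ggsw, scratch_1);

    // Single carry-propagation / normalization pass back into the GLWE layout.
    for j in 0..(res.rank() + 1).into() {
        module.vec_znx_big_normalize(res_base2k, &mut res.data, j, ggsw_base2k, &res_big, j, scratch_1);
    }
}

Keeping the result in the big domain until the very end is what lets cmux (last file of this commit) replace a separate addition and second normalization with one fused add-and-normalize pass.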

View File

@@ -382,8 +382,8 @@ impl<D: DataMut, T: UnsignedInteger> FheUint<D, T> {
let (mut out_bits, scratch_1) = scratch.take_glwe_slice(T::BITS as usize, self);
for i in 0..T::BITS as usize {
module.cmux(&mut out_bits[i], &one, &zero, &other.get_bit(i), scratch_1);
for (i, bits) in out_bits.iter_mut().enumerate().take(T::BITS as usize) {
module.cmux(bits, &one, &zero, &other.get_bit(i), scratch_1);
}
self.pack(module, out_bits, keys, scratch_1);

View File

@@ -2,9 +2,13 @@ use core::panic;
use itertools::Itertools;
use poulpy_core::{
GLWEAdd, GLWECopy, GLWEExternalProduct, GLWENormalize, GLWESub, ScratchTakeCore, layouts::{GGSWInfos, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef}
GLWECopy, GLWEExternalProductInternal, GLWENormalize, GLWESub, ScratchTakeCore,
layouts::{GGSWInfos, GGSWPrepared, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef},
};
use poulpy_hal::{
api::{ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftBytesOf},
layouts::{Backend, DataMut, Module, Scratch, VecZnxBig, ZnxZero},
};
use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch, ZnxZero};
use crate::tfhe::bdd_arithmetic::{GetGGSWBit, UnsignedInteger};
@@ -70,8 +74,18 @@ where
{
#[cfg(debug_assertions)]
{
assert!(inputs.bit_size() >= circuit.input_size(), "inputs.bit_size(): {} < circuit.input_size():{}", inputs.bit_size(), circuit.input_size());
assert!(out.len() >= circuit.output_size(), "out.len(): {} < circuit.output_size(): {}", out.len(), circuit.output_size());
assert!(
inputs.bit_size() >= circuit.input_size(),
"inputs.bit_size(): {} < circuit.input_size():{}",
inputs.bit_size(),
circuit.input_size()
);
assert!(
out.len() >= circuit.output_size(),
"out.len(): {} < circuit.output_size(): {}",
out.len(),
circuit.output_size()
);
}
for (i, out_i) in out.iter_mut().enumerate().take(circuit.output_size()) {
@@ -164,7 +178,14 @@ pub enum Node {
pub trait Cmux<BE: Backend>
where
Self: GLWEExternalProduct<BE> + GLWESub + GLWEAdd + GLWENormalize<BE>,
Self: Sized
+ GLWEExternalProductInternal<BE>
+ GLWESub
+ VecZnxBigAddSmallInplace<BE>
+ GLWENormalize<BE>
+ VecZnxDftBytesOf
+ VecZnxBigNormalize<BE>
+ VecZnxBigNormalizeTmpBytes,
{
fn cmux_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
where
@@ -172,7 +193,11 @@ where
A: GLWEInfos,
B: GGSWInfos,
{
self.glwe_external_product_tmp_bytes(res_infos, a_infos, b_infos)
let res_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), b_infos.size());
res_dft
+ self
.glwe_external_product_internal_tmp_bytes(res_infos, a_infos, b_infos)
.max(self.vec_znx_big_normalize_tmp_bytes())
}
fn cmux<R, T, F, S>(&self, res: &mut R, t: &T, f: &F, s: &S, scratch: &mut Scratch<BE>)
@@ -180,34 +205,73 @@ where
R: GLWEToMut,
T: GLWEToRef,
F: GLWEToRef,
S: GGSWPreparedToRef<BE>,
S: GGSWPreparedToRef<BE> + GGSWInfos,
Scratch<BE>: ScratchTakeCore<BE>,
{
self.glwe_sub(res, t, f);
self.glwe_normalize_inplace(res, scratch);
self.glwe_external_product_inplace(res, s, scratch);
self.glwe_add_inplace(res, f);
self.glwe_normalize_inplace(res, scratch);
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
let s: &GGSWPrepared<&[u8], BE> = &s.to_ref();
let f: GLWE<&[u8]> = f.to_ref();
let res_base2k: usize = res.base2k().into();
let ggsw_base2k: usize = s.base2k().into();
self.glwe_sub(res, t, &f);
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), s.size()); // Todo optimise
let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_external_product_internal(res_dft, res, s, scratch_1);
for j in 0..(res.rank() + 1).into() {
self.vec_znx_big_add_small_inplace(&mut res_big, j, f.data(), j);
self.vec_znx_big_normalize(
res_base2k,
res.data_mut(),
j,
ggsw_base2k,
&res_big,
j,
scratch_1,
);
}
}
fn cmux_inplace<R, A, S>(&self, res: &mut R, a: &A, s: &S, scratch: &mut Scratch<BE>)
where
R: GLWEToMut,
A: GLWEToRef,
S: GGSWPreparedToRef<BE>,
S: GGSWPreparedToRef<BE> + GGSWInfos,
Scratch<BE>: ScratchTakeCore<BE>,
{
self.glwe_sub_inplace(res, a);
self.glwe_normalize_inplace(res, scratch);
self.glwe_external_product_inplace(res, s, scratch);
self.glwe_add_inplace(res, a);
self.glwe_normalize_inplace(res, scratch);
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
let s: &GGSWPrepared<&[u8], BE> = &s.to_ref();
let a: GLWE<&[u8]> = a.to_ref();
let res_base2k: usize = res.base2k().into();
let ggsw_base2k: usize = s.base2k().into();
self.glwe_sub_inplace(res, &a);
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), s.size()); // Todo optimise
let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_external_product_internal(res_dft, res, s, scratch_1);
for j in 0..(res.rank() + 1).into() {
self.vec_znx_big_add_small_inplace(&mut res_big, j, a.data(), j);
self.vec_znx_big_normalize(
res_base2k,
res.data_mut(),
j,
ggsw_base2k,
&res_big,
j,
scratch_1,
);
}
}
}
impl<BE: Backend> Cmux<BE> for Module<BE>
where
Self: GLWEExternalProduct<BE> + GLWESub + GLWEAdd + GLWENormalize<BE>,
Self: Sized
+ GLWEExternalProductInternal<BE>
+ GLWESub
+ VecZnxBigAddSmallInplace<BE>
+ GLWENormalize<BE>
+ VecZnxDftBytesOf
+ VecZnxBigNormalize<BE>
+ VecZnxBigNormalizeTmpBytes,
Scratch<BE>: ScratchTakeCore<BE>,
{
}