diff --git a/poulpy-core/src/external_product/ggsw.rs b/poulpy-core/src/external_product/ggsw.rs
index f9659fe..d33055d 100644
--- a/poulpy-core/src/external_product/ggsw.rs
+++ b/poulpy-core/src/external_product/ggsw.rs
@@ -1,5 +1,5 @@
 use poulpy_hal::{
-    api::ScratchAvailable,
+    api::{ModuleN, ScratchAvailable},
     layouts::{Backend, DataMut, Module, Scratch, ZnxZero},
 };
 
@@ -13,7 +13,7 @@ use crate::{
 pub trait GGSWExternalProduct
 where
-    Self: GLWEExternalProduct,
+    Self: GLWEExternalProduct + ModuleN,
 {
     fn ggsw_external_product_tmp_bytes(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
     where
diff --git a/poulpy-core/src/external_product/glwe.rs b/poulpy-core/src/external_product/glwe.rs
index a1cd5ee..303499d 100644
--- a/poulpy-core/src/external_product/glwe.rs
+++ b/poulpy-core/src/external_product/glwe.rs
@@ -1,9 +1,10 @@
 use poulpy_hal::{
     api::{
-        ModuleN, ScratchTakeBasic, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize,
-        VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
+        ModuleN, ScratchTakeBasic, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftApply, VecZnxDftBytesOf,
+        VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
+        VmpApplyDftToDftTmpBytes,
     },
-    layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig},
+    layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft},
 };
 
 use crate::{
@@ -30,7 +31,7 @@ impl GLWE {
     pub fn external_product(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch)
     where
         A: GLWEToRef,
-        B: GGSWPreparedToRef,
+        B: GGSWPreparedToRef + GGSWInfos,
         M: GLWEExternalProduct,
         Scratch: ScratchTakeCore,
     {
@@ -39,7 +40,7 @@ impl GLWE {
     pub fn external_product_inplace(&mut self, module: &M, a: &A, scratch: &mut Scratch)
     where
-        A: GGSWPreparedToRef,
+        A: GGSWPreparedToRef + GGSWInfos,
         M: GLWEExternalProduct,
         Scratch: ScratchTakeCore,
     {
@@ -47,19 +48,30 @@
     }
 }
 
-pub trait GLWEExternalProduct
+pub trait GLWEExternalProduct {
+    fn glwe_external_product_tmp_bytes(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
+    where
+        R: GLWEInfos,
+        A: GLWEInfos,
+        B: GGSWInfos;
+
+    fn glwe_external_product_inplace(&self, res: &mut R, a: &D, scratch: &mut Scratch)
+    where
+        R: GLWEToMut,
+        D: GGSWPreparedToRef + GGSWInfos,
+        Scratch: ScratchTakeCore;
+
+    fn glwe_external_product(&self, res: &mut R, lhs: &A, rhs: &D, scratch: &mut Scratch)
+    where
+        R: GLWEToMut,
+        A: GLWEToRef,
+        D: GGSWPreparedToRef + GGSWInfos,
+        Scratch: ScratchTakeCore;
+}
+
+impl GLWEExternalProduct for Module
 where
-    Self: Sized
-        + ModuleN
-        + VecZnxDftBytesOf
-        + VmpApplyDftToDftTmpBytes
-        + VecZnxNormalizeTmpBytes
-        + VecZnxDftApply
-        + VmpApplyDftToDft
-        + VmpApplyDftToDftAdd
-        + VecZnxIdftApplyConsume
-        + VecZnxBigNormalize
-        + VecZnxNormalize,
+    Self: GLWEExternalProductInternal + VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxBigNormalizeTmpBytes,
 {
     fn glwe_external_product_tmp_bytes(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
     where
@@ -67,36 +79,17 @@ where
         A: GLWEInfos,
         B: GGSWInfos,
     {
-        let in_size: usize = a_infos
-            .k()
-            .div_ceil(b_infos.base2k())
-            .div_ceil(b_infos.dsize().into()) as usize;
-        let out_size: usize = res_infos.size();
-        let ggsw_size: usize = b_infos.size();
-        let res_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), ggsw_size);
-        let a_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), in_size);
-        let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes(
-            out_size,
-            in_size,
-            in_size,                     // rows
-            (b_infos.rank() + 1).into(), // cols in
-            (b_infos.rank() + 1).into(), // cols out
-            ggsw_size,
-        );
-        let normalize_big: usize = self.vec_znx_normalize_tmp_bytes();
-
-        if a_infos.base2k() == b_infos.base2k() {
-            res_dft + a_dft + (vmp | normalize_big)
-        } else {
-            let normalize_conv: usize = VecZnx::bytes_of(self.n(), (b_infos.rank() + 1).into(), in_size);
-            res_dft + ((a_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big)
-        }
+        let res_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), b_infos.size());
+        res_dft
+            + self
+                .glwe_external_product_internal_tmp_bytes(res_infos, a_infos, b_infos)
+                .max(self.vec_znx_big_normalize_tmp_bytes())
     }
 
     fn glwe_external_product_inplace(&self, res: &mut R, a: &D, scratch: &mut Scratch)
     where
         R: GLWEToMut,
-        D: GGSWPreparedToRef,
+        D: GGSWPreparedToRef + GGSWInfos,
         Scratch: ScratchTakeCore,
     {
         let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
@@ -114,81 +107,9 @@ where
             assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, res, rhs));
         }
 
-        let cols: usize = (rhs.rank() + 1).into();
-        let dsize: usize = rhs.dsize().into();
-        let a_size: usize = (res.size() * basek_in).div_ceil(basek_ggsw);
-
-        let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, rhs.size()); // Todo optimise
-        let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize));
-        a_dft.data_mut().fill(0);
-
-        if basek_in == basek_ggsw {
-            for di in 0..dsize {
-                // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
-                a_dft.set_size((res.size() + di) / dsize);
-
-                // Small optimization for dsize > 2
-                // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
-                // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
-                // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products.
-                // It is possible to further ignore the last dsize-1 limbs, but this introduce
-                // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
-                // noise is kept with respect to the ideal functionality.
-                res_dft.set_size(rhs.size() - ((dsize - di) as isize - 2).max(0) as usize);
-
-                for j in 0..cols {
-                    self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &res.data, j);
-                }
-
-                if di == 0 {
-                    self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2);
-                } else {
-                    self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2);
-                }
-            }
-        } else {
-            let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n(), cols, a_size);
-
-            for j in 0..cols {
-                self.vec_znx_normalize(
-                    basek_ggsw,
-                    &mut a_conv,
-                    j,
-                    basek_in,
-                    &res.data,
-                    j,
-                    scratch_3,
-                );
-            }
-
-            for di in 0..dsize {
-                // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
-                a_dft.set_size((res.size() + di) / dsize);
-
-                // Small optimization for dsize > 2
-                // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
-                // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
-                // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products.
-                // It is possible to further ignore the last dsize-1 limbs, but this introduce
-                // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
-                // noise is kept with respect to the ideal functionality.
-                res_dft.set_size(rhs.size() - ((dsize - di) as isize - 2).max(0) as usize);
-
-                for j in 0..cols {
-                    self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &res.data, j);
-                }
-
-                if di == 0 {
-                    self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2);
-                } else {
-                    self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2);
-                }
-            }
-        }
-
-        let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);
-
-        for j in 0..cols {
+        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), a.size()); // Todo optimise
+        let res_big = self.glwe_external_product_internal(res_dft, res, a, scratch_1);
+        for j in 0..(res.rank() + 1).into() {
             self.vec_znx_big_normalize(
                 basek_in,
                 &mut res.data,
@@ -213,7 +134,6 @@ where
         let rhs: &GGSWPrepared<&[u8], BE> = &rhs.to_ref();
 
-        let basek_in: usize = lhs.base2k().into();
         let basek_ggsw: usize = rhs.base2k().into();
         let basek_out: usize = res.base2k().into();
 
@@ -228,96 +148,45 @@
             assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, lhs, rhs));
         }
 
-        let cols: usize = (rhs.rank() + 1).into();
-        let dsize: usize = rhs.dsize().into();
+        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), rhs.size()); // Todo optimise
+        let res_big = self.glwe_external_product_internal(res_dft, lhs, rhs, scratch_1);
 
-        let a_size: usize = (lhs.size() * basek_in).div_ceil(basek_ggsw);
-
-        let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, rhs.size()); // Todo optimise
-        let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize));
-        a_dft.data_mut().fill(0);
-
-        if basek_in == basek_ggsw {
-            for di in 0..dsize {
-                // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
-                a_dft.set_size((lhs.size() + di) / dsize);
-
-                // Small optimization for dsize > 2
-                // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
-                // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
-                // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products.
-                // It is possible to further ignore the last dsize-1 limbs, but this introduce
-                // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
-                // noise is kept with respect to the ideal functionality.
-                res_dft.set_size(rhs.size() - ((dsize - di) as isize - 2).max(0) as usize);
-
-                for j in 0..cols {
-                    self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &lhs.data, j);
-                }
-
-                if di == 0 {
-                    self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2);
-                } else {
-                    self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2);
-                }
-            }
-        } else {
-            let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n(), cols, a_size);
-
-            for j in 0..cols {
-                self.vec_znx_normalize(
-                    basek_ggsw,
-                    &mut a_conv,
-                    j,
-                    basek_in,
-                    &lhs.data,
-                    j,
-                    scratch_3,
-                );
-            }
-
-            for di in 0..dsize {
-                // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
-                a_dft.set_size((a_size + di) / dsize);
-
-                // Small optimization for dsize > 2
-                // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
-                // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
-                // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products.
-                // It is possible to further ignore the last dsize-1 limbs, but this introduce
-                // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
-                // noise is kept with respect to the ideal functionality.
-                res_dft.set_size(rhs.size() - ((dsize - di) as isize - 2).max(0) as usize);
-
-                for j in 0..cols {
-                    self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &a_conv, j);
-                }
-
-                if di == 0 {
-                    self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_3);
-                } else {
-                    self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_3);
-                }
-            }
-        }
-
-        let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);
-
-        (0..cols).for_each(|i| {
+        for j in 0..(res.rank() + 1).into() {
             self.vec_znx_big_normalize(
                 basek_out,
-                res.data_mut(),
-                i,
+                &mut res.data,
+                j,
                 basek_ggsw,
                 &res_big,
-                i,
+                j,
                 scratch_1,
             );
-        });
+        }
     }
 }
 
-impl GLWEExternalProduct for Module where
+pub trait GLWEExternalProductInternal {
+    fn glwe_external_product_internal_tmp_bytes(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
+    where
+        R: GLWEInfos,
+        A: GLWEInfos,
+        B: GGSWInfos;
+    fn glwe_external_product_internal(
+        &self,
+        res_dft: VecZnxDft,
+        a: &A,
+        ggsw: &G,
+        scratch: &mut Scratch,
+    ) -> VecZnxBig
+    where
+        DR: DataMut,
+        A: GLWEToRef,
+        G: GGSWPreparedToRef,
+        Scratch: ScratchTakeCore;
+}
+
+impl GLWEExternalProductInternal for Module
+where
     Self: ModuleN
         + VecZnxDftBytesOf
         + VmpApplyDftToDftTmpBytes
@@ -330,6 +199,121 @@ impl GLWEExternalProduct for Module where
         + VecZnxNormalize
         + VecZnxDftBytesOf
         + VmpApplyDftToDftTmpBytes
-        + VecZnxNormalizeTmpBytes
+        + VecZnxNormalizeTmpBytes,
 {
+    fn glwe_external_product_internal_tmp_bytes(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
+    where
+        R: GLWEInfos,
+        A: GLWEInfos,
+        B: GGSWInfos,
+    {
+        let in_size: usize = a_infos
+            .k()
+            .div_ceil(b_infos.base2k())
+            .div_ceil(b_infos.dsize().into()) as usize;
+        let out_size: usize = res_infos.size();
+        let ggsw_size: usize = b_infos.size();
+        let a_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), in_size);
+        let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes(
+            out_size,
+            in_size,
+            in_size,                     // rows
+            (b_infos.rank() + 1).into(), // cols in
+            (b_infos.rank() + 1).into(), // cols out
+            ggsw_size,
+        );
+        let normalize_big: usize = self.vec_znx_normalize_tmp_bytes();
+
+        if a_infos.base2k() == b_infos.base2k() {
+            a_dft + (vmp | normalize_big)
+        } else {
+            let normalize_conv: usize = VecZnx::bytes_of(self.n(), (b_infos.rank() + 1).into(), in_size);
+            (a_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big
+        }
+    }
+
+    fn glwe_external_product_internal(
+        &self,
+        mut res_dft: VecZnxDft,
+        a: &A,
+        ggsw: &G,
+        scratch: &mut Scratch,
+    ) -> VecZnxBig
+    where
+        DR: DataMut,
+        A: GLWEToRef,
+        G: GGSWPreparedToRef,
+        Scratch: ScratchTakeCore,
+    {
+        let a: &GLWE<&[u8]> = &a.to_ref();
+        let ggsw: &GGSWPrepared<&[u8], BE> = &ggsw.to_ref();
+
+        let basek_in: usize = a.base2k().into();
+        let basek_ggsw: usize = ggsw.base2k().into();
+
+        let cols: usize = (ggsw.rank() + 1).into();
+        let dsize: usize = ggsw.dsize().into();
+        let a_size: usize = (a.size() * basek_in).div_ceil(basek_ggsw);
+
+        let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize));
+        a_dft.data_mut().fill(0);
+
+        if basek_in == basek_ggsw {
+            for di in 0..dsize {
+                // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
+                a_dft.set_size((a.size() + di) / dsize);
+
+                // Small optimization for dsize > 2
+                // VMP produces some error e, and since we aggregate vmp * 2^{di * B}, then
+                // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
+                // As such we can safely ignore the last dsize-2 limbs of the sum of vmp products.
+                // It is possible to further ignore the last dsize-1 limbs, but this introduces
+                // ~0.5 to 1 bit of additional noise, and is thus not chosen here to ensure that the same
+                // noise is kept with respect to the ideal functionality.
+                res_dft.set_size(ggsw.size() - ((dsize - di) as isize - 2).max(0) as usize);
+
+                for j in 0..cols {
+                    self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &a.data, j);
+                }
+
+                if di == 0 {
+                    self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &ggsw.data, scratch_1);
+                } else {
+                    self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &ggsw.data, di, scratch_1);
+                }
+            }
+        } else {
+            let (mut a_conv, scratch_3) = scratch_1.take_vec_znx(self.n(), cols, a_size);
+
+            for j in 0..cols {
+                self.vec_znx_normalize(basek_ggsw, &mut a_conv, j, basek_in, &a.data, j, scratch_3);
+            }
+
+            for di in 0..dsize {
+                // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize)
+                a_dft.set_size((a_size + di) / dsize);
+
+                // Small optimization for dsize > 2
+                // VMP produces some error e, and since we aggregate vmp * 2^{di * B}, then
+                // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
+                // As such we can safely ignore the last dsize-2 limbs of the sum of vmp products.
+                // It is possible to further ignore the last dsize-1 limbs, but this introduces
+                // ~0.5 to 1 bit of additional noise, and is thus not chosen here to ensure that the same
+                // noise is kept with respect to the ideal functionality.
+                res_dft.set_size(ggsw.size() - ((dsize - di) as isize - 2).max(0) as usize);
+
+                for j in 0..cols {
+                    self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &a_conv, j);
+                }
+
+                if di == 0 {
+                    self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &ggsw.data, scratch_3);
+                } else {
+                    self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &ggsw.data, di, scratch_3);
+                }
+            }
+        }
+
+        self.vec_znx_idft_apply_consume(res_dft)
+    }
 }
diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint.rs
index 6bd8be1..819af68 100644
--- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint.rs
+++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint.rs
@@ -382,8 +382,8 @@ impl FheUint {
         let (mut out_bits, scratch_1) = scratch.take_glwe_slice(T::BITS as usize, self);
 
-        for i in 0..T::BITS as usize {
-            module.cmux(&mut out_bits[i], &one, &zero, &other.get_bit(i), scratch_1);
+        for (i, bits) in out_bits.iter_mut().enumerate().take(T::BITS as usize) {
+            module.cmux(bits, &one, &zero, &other.get_bit(i), scratch_1);
         }
 
         self.pack(module, out_bits, keys, scratch_1);
diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs
index dbd1862..618304f 100644
--- a/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs
+++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs
@@ -2,9 +2,13 @@ use core::panic;
 
 use itertools::Itertools;
 
 use poulpy_core::{
-    GLWEAdd, GLWECopy, GLWEExternalProduct, GLWENormalize, GLWESub, ScratchTakeCore, layouts::{GGSWInfos, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef}
+    GLWECopy, GLWEExternalProductInternal, GLWENormalize, GLWESub, ScratchTakeCore,
+    layouts::{GGSWInfos, GGSWPrepared, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef},
+};
+use poulpy_hal::{
+    api::{ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftBytesOf},
+    layouts::{Backend, DataMut, Module, Scratch, VecZnxBig, ZnxZero},
 };
-use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch, ZnxZero};
 
 use crate::tfhe::bdd_arithmetic::{GetGGSWBit, UnsignedInteger};
 
@@ -70,8 +74,18 @@ where
 {
     #[cfg(debug_assertions)]
     {
-        assert!(inputs.bit_size() >= circuit.input_size(), "inputs.bit_size(): {} < circuit.input_size():{}", inputs.bit_size(), circuit.input_size());
-        assert!(out.len() >= circuit.output_size(), "out.len(): {} < circuit.output_size(): {}", out.len(), circuit.output_size());
+        assert!(
+            inputs.bit_size() >= circuit.input_size(),
+            "inputs.bit_size(): {} < circuit.input_size():{}",
+            inputs.bit_size(),
+            circuit.input_size()
+        );
+        assert!(
+            out.len() >= circuit.output_size(),
+            "out.len(): {} < circuit.output_size(): {}",
+            out.len(),
+            circuit.output_size()
+        );
     }
 
     for (i, out_i) in out.iter_mut().enumerate().take(circuit.output_size()) {
@@ -164,7 +178,14 @@ pub enum Node {
 
 pub trait Cmux
 where
-    Self: GLWEExternalProduct + GLWESub + GLWEAdd + GLWENormalize,
+    Self: Sized
+        + GLWEExternalProductInternal
+        + GLWESub
+        + VecZnxBigAddSmallInplace
+        + GLWENormalize
+        + VecZnxDftBytesOf
+        + VecZnxBigNormalize
+        + VecZnxBigNormalizeTmpBytes,
 {
     fn cmux_tmp_bytes(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
     where
@@ -172,7 +193,11 @@ where
        A: GLWEInfos,
        B: GGSWInfos,
     {
-        self.glwe_external_product_tmp_bytes(res_infos, a_infos, b_infos)
+        let res_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), b_infos.size());
+        res_dft
+            + self
+                .glwe_external_product_internal_tmp_bytes(res_infos, a_infos, b_infos)
+                .max(self.vec_znx_big_normalize_tmp_bytes())
     }
 
     fn cmux(&self, res: &mut R, t: &T, f: &F, s: &S, scratch: &mut Scratch)
@@ -180,34 +205,73 @@ where
         R: GLWEToMut,
         T: GLWEToRef,
         F: GLWEToRef,
-        S: GGSWPreparedToRef,
+        S: GGSWPreparedToRef + GGSWInfos,
         Scratch: ScratchTakeCore,
     {
-        self.glwe_sub(res, t, f);
-        self.glwe_normalize_inplace(res, scratch);
-        self.glwe_external_product_inplace(res, s, scratch);
-        self.glwe_add_inplace(res, f);
-        self.glwe_normalize_inplace(res, scratch);
+        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
+        let s: &GGSWPrepared<&[u8], BE> = &s.to_ref();
+        let f: GLWE<&[u8]> = f.to_ref();
+
+        let res_base2k: usize = res.base2k().into();
+        let ggsw_base2k: usize = s.base2k().into();
+
+        self.glwe_sub(res, t, &f);
+        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), s.size()); // Todo optimise
+        let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_external_product_internal(res_dft, res, s, scratch_1);
+        for j in 0..(res.rank() + 1).into() {
+            self.vec_znx_big_add_small_inplace(&mut res_big, j, f.data(), j);
+            self.vec_znx_big_normalize(
+                res_base2k,
+                res.data_mut(),
+                j,
+                ggsw_base2k,
+                &res_big,
+                j,
+                scratch_1,
+            );
+        }
     }
 
     fn cmux_inplace(&self, res: &mut R, a: &A, s: &S, scratch: &mut Scratch)
     where
         R: GLWEToMut,
         A: GLWEToRef,
-        S: GGSWPreparedToRef,
+        S: GGSWPreparedToRef + GGSWInfos,
         Scratch: ScratchTakeCore,
     {
-        self.glwe_sub_inplace(res, a);
-        self.glwe_normalize_inplace(res, scratch);
-        self.glwe_external_product_inplace(res, s, scratch);
-        self.glwe_add_inplace(res, a);
-        self.glwe_normalize_inplace(res, scratch);
+        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
+        let s: &GGSWPrepared<&[u8], BE> = &s.to_ref();
+        let a: GLWE<&[u8]> = a.to_ref();
+        let res_base2k: usize = res.base2k().into();
+        let ggsw_base2k: usize = s.base2k().into();
+        self.glwe_sub_inplace(res, &a);
+        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), s.size()); // Todo optimise
+        let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_external_product_internal(res_dft, res, s, scratch_1);
+        for j in 0..(res.rank() + 1).into() {
+            self.vec_znx_big_add_small_inplace(&mut res_big, j, a.data(), j);
+            self.vec_znx_big_normalize(
+                res_base2k,
+                res.data_mut(),
+                j,
+                ggsw_base2k,
+                &res_big,
+                j,
+                scratch_1,
+            );
+        }
     }
 }
 
 impl Cmux for Module where
-    Self: GLWEExternalProduct + GLWESub + GLWEAdd + GLWENormalize,
+    Self: Sized
+        + GLWEExternalProductInternal
+        + GLWESub
+        + VecZnxBigAddSmallInplace
+        + GLWENormalize
+        + VecZnxDftBytesOf
+        + VecZnxBigNormalize
+        + VecZnxBigNormalizeTmpBytes,
     Scratch: ScratchTakeCore,
 {
 }
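
For intuition, the rewritten cmux/cmux_inplace bodies above compute the usual selector identity cmux(s, t, f) = f + s * (t - f): glwe_sub forms t - f, the external product multiplies that difference by the GGSW-encrypted selector bit s, and vec_znx_big_add_small_inplace folds f back into the big-coefficient accumulator so only one final normalization is needed. A minimal plaintext sketch of that identity (plain integers, no ciphertexts; illustrative only, not part of the patch):

fn cmux_plain(s: i64, t: i64, f: i64) -> i64 {
    // s is the selector bit: 0 keeps f, 1 keeps t.
    debug_assert!(s == 0 || s == 1);
    f + s * (t - f)
}

fn main() {
    assert_eq!(cmux_plain(0, 7, -3), -3); // selector 0 keeps the "false" branch f
    assert_eq!(cmux_plain(1, 7, -3), 7); // selector 1 keeps the "true" branch t
}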