This commit is contained in:
Pro7ech
2025-10-14 23:45:00 +02:00
parent 779e02acc4
commit a5df85170d
2 changed files with 16 additions and 16 deletions

View File

@@ -15,7 +15,7 @@ impl GGSW<Vec<u8>> {
module: &Module<B>, module: &Module<B>,
out_infos: &OUT, out_infos: &OUT,
in_infos: &IN, in_infos: &IN,
apply_infos: &GGSW, ggsw_infos: &GGSW,
) -> usize ) -> usize
where where
OUT: GGSWInfos, OUT: GGSWInfos,
@@ -27,21 +27,21 @@ impl GGSW<Vec<u8>> {
module, module,
&out_infos.glwe_layout(), &out_infos.glwe_layout(),
&in_infos.glwe_layout(), &in_infos.glwe_layout(),
apply_infos, ggsw_infos,
) )
} }
pub fn external_product_inplace_scratch_space<B: Backend, OUT, GGSW>( pub fn external_product_inplace_scratch_space<B: Backend, OUT, GGSW>(
module: &Module<B>, module: &Module<B>,
out_infos: &OUT, out_infos: &OUT,
apply_infos: &GGSW, ggsw_infos: &GGSW,
) -> usize ) -> usize
where where
OUT: GGSWInfos, OUT: GGSWInfos,
GGSW: GGSWInfos, GGSW: GGSWInfos,
Module<B>: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes, Module<B>: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
{ {
GLWE::external_product_inplace_scratch_space(module, &out_infos.glwe_layout(), apply_infos) GLWE::external_product_inplace_scratch_space(module, &out_infos.glwe_layout(), ggsw_infos)
} }
} }

View File

@@ -19,7 +19,7 @@ impl<DataSelf: DataMut> GLWE<DataSelf> {
module: Module<B>, module: Module<B>,
out_infos: &OUT, out_infos: &OUT,
in_infos: &IN, in_infos: &IN,
apply_infos: &GGSW, ggsw_infos: &GGSW,
) -> usize ) -> usize
where where
OUT: GLWEInfos, OUT: GLWEInfos,
@@ -27,7 +27,7 @@ impl<DataSelf: DataMut> GLWE<DataSelf> {
GGSW: GGSWInfos, GGSW: GGSWInfos,
Module<B>: GLWEExternalProduct<B>, Module<B>: GLWEExternalProduct<B>,
{ {
module.glwe_external_product_scratch_space(out_infos, in_infos, apply_infos) module.glwe_external_product_scratch_space(out_infos, in_infos, ggsw_infos)
} }
pub fn external_product<L, R, B: Backend>(&mut self, module: &Module<B>, lhs: &L, rhs: &R, scratch: &mut Scratch<B>) pub fn external_product<L, R, B: Backend>(&mut self, module: &Module<B>, lhs: &L, rhs: &R, scratch: &mut Scratch<B>)
@@ -67,7 +67,7 @@ where
+ VecZnxNormalizeTmpBytes, + VecZnxNormalizeTmpBytes,
{ {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn glwe_external_product_scratch_space<OUT, IN, GGSW>(&self, out_infos: &OUT, in_infos: &IN, apply_infos: &GGSW) -> usize fn glwe_external_product_scratch_space<OUT, IN, GGSW>(&self, out_infos: &OUT, in_infos: &IN, ggsw_infos: &GGSW) -> usize
where where
OUT: GLWEInfos, OUT: GLWEInfos,
IN: GLWEInfos, IN: GLWEInfos,
@@ -75,26 +75,26 @@ where
{ {
let in_size: usize = in_infos let in_size: usize = in_infos
.k() .k()
.div_ceil(apply_infos.base2k()) .div_ceil(ggsw_infos.base2k())
.div_ceil(apply_infos.dsize().into()) as usize; .div_ceil(ggsw_infos.dsize().into()) as usize;
let out_size: usize = out_infos.size(); let out_size: usize = out_infos.size();
let ggsw_size: usize = apply_infos.size(); let ggsw_size: usize = ggsw_infos.size();
let res_dft: usize = self.bytes_of_vec_znx_dft((apply_infos.rank() + 1).into(), ggsw_size); let res_dft: usize = self.bytes_of_vec_znx_dft((ggsw_infos.rank() + 1).into(), ggsw_size);
let a_dft: usize = self.bytes_of_vec_znx_dft((apply_infos.rank() + 1).into(), in_size); let a_dft: usize = self.bytes_of_vec_znx_dft((ggsw_infos.rank() + 1).into(), in_size);
let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes( let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes(
out_size, out_size,
in_size, in_size,
in_size, // rows in_size, // rows
(apply_infos.rank() + 1).into(), // cols in (ggsw_infos.rank() + 1).into(), // cols in
(apply_infos.rank() + 1).into(), // cols out (ggsw_infos.rank() + 1).into(), // cols out
ggsw_size, ggsw_size,
); );
let normalize_big: usize = self.vec_znx_normalize_tmp_bytes(); let normalize_big: usize = self.vec_znx_normalize_tmp_bytes();
if in_infos.base2k() == apply_infos.base2k() { if in_infos.base2k() == ggsw_infos.base2k() {
res_dft + a_dft + (vmp | normalize_big) res_dft + a_dft + (vmp | normalize_big)
} else { } else {
let normalize_conv: usize = VecZnx::bytes_of(self.n().into(), (apply_infos.rank() + 1).into(), in_size); let normalize_conv: usize = VecZnx::bytes_of(self.n().into(), (ggsw_infos.rank() + 1).into(), in_size);
res_dft + ((a_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big) res_dft + ((a_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big)
} }
} }