This commit is contained in:
Pro7ech
2025-10-14 23:45:00 +02:00
parent 779e02acc4
commit a5df85170d
2 changed files with 16 additions and 16 deletions

View File

@@ -15,7 +15,7 @@ impl GGSW<Vec<u8>> {
module: &Module<B>,
out_infos: &OUT,
in_infos: &IN,
apply_infos: &GGSW,
ggsw_infos: &GGSW,
) -> usize
where
OUT: GGSWInfos,
@@ -27,21 +27,21 @@ impl GGSW<Vec<u8>> {
module,
&out_infos.glwe_layout(),
&in_infos.glwe_layout(),
apply_infos,
ggsw_infos,
)
}
pub fn external_product_inplace_scratch_space<B: Backend, OUT, GGSW>(
module: &Module<B>,
out_infos: &OUT,
apply_infos: &GGSW,
ggsw_infos: &GGSW,
) -> usize
where
OUT: GGSWInfos,
GGSW: GGSWInfos,
Module<B>: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes,
{
GLWE::external_product_inplace_scratch_space(module, &out_infos.glwe_layout(), apply_infos)
GLWE::external_product_inplace_scratch_space(module, &out_infos.glwe_layout(), ggsw_infos)
}
}

View File

@@ -19,7 +19,7 @@ impl<DataSelf: DataMut> GLWE<DataSelf> {
module: Module<B>,
out_infos: &OUT,
in_infos: &IN,
apply_infos: &GGSW,
ggsw_infos: &GGSW,
) -> usize
where
OUT: GLWEInfos,
@@ -27,7 +27,7 @@ impl<DataSelf: DataMut> GLWE<DataSelf> {
GGSW: GGSWInfos,
Module<B>: GLWEExternalProduct<B>,
{
module.glwe_external_product_scratch_space(out_infos, in_infos, apply_infos)
module.glwe_external_product_scratch_space(out_infos, in_infos, ggsw_infos)
}
pub fn external_product<L, R, B: Backend>(&mut self, module: &Module<B>, lhs: &L, rhs: &R, scratch: &mut Scratch<B>)
@@ -67,7 +67,7 @@ where
+ VecZnxNormalizeTmpBytes,
{
#[allow(clippy::too_many_arguments)]
fn glwe_external_product_scratch_space<OUT, IN, GGSW>(&self, out_infos: &OUT, in_infos: &IN, apply_infos: &GGSW) -> usize
fn glwe_external_product_scratch_space<OUT, IN, GGSW>(&self, out_infos: &OUT, in_infos: &IN, ggsw_infos: &GGSW) -> usize
where
OUT: GLWEInfos,
IN: GLWEInfos,
@@ -75,26 +75,26 @@ where
{
let in_size: usize = in_infos
.k()
.div_ceil(apply_infos.base2k())
.div_ceil(apply_infos.dsize().into()) as usize;
.div_ceil(ggsw_infos.base2k())
.div_ceil(ggsw_infos.dsize().into()) as usize;
let out_size: usize = out_infos.size();
let ggsw_size: usize = apply_infos.size();
let res_dft: usize = self.bytes_of_vec_znx_dft((apply_infos.rank() + 1).into(), ggsw_size);
let a_dft: usize = self.bytes_of_vec_znx_dft((apply_infos.rank() + 1).into(), in_size);
let ggsw_size: usize = ggsw_infos.size();
let res_dft: usize = self.bytes_of_vec_znx_dft((ggsw_infos.rank() + 1).into(), ggsw_size);
let a_dft: usize = self.bytes_of_vec_znx_dft((ggsw_infos.rank() + 1).into(), in_size);
let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes(
out_size,
in_size,
in_size, // rows
(apply_infos.rank() + 1).into(), // cols in
(apply_infos.rank() + 1).into(), // cols out
(ggsw_infos.rank() + 1).into(), // cols in
(ggsw_infos.rank() + 1).into(), // cols out
ggsw_size,
);
let normalize_big: usize = self.vec_znx_normalize_tmp_bytes();
if in_infos.base2k() == apply_infos.base2k() {
if in_infos.base2k() == ggsw_infos.base2k() {
res_dft + a_dft + (vmp | normalize_big)
} else {
let normalize_conv: usize = VecZnx::bytes_of(self.n().into(), (apply_infos.rank() + 1).into(), in_size);
let normalize_conv: usize = VecZnx::bytes_of(self.n().into(), (ggsw_infos.rank() + 1).into(), in_size);
res_dft + ((a_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big)
}
}