Support for bivariate convolution & normalization with offset (#126)

* Add bivariate-convolution
* Add pair-wise convolution + tests + benches
* Add take_cnv_pvec_[left/right] to Scratch & updated CHANGELOG.md
* cross-base2k normalization with positive offset
* clippy & fix CI doctest avx compile error
* more streamlined bounds derivation for normalization
* Working cross-base2k normalization with pos/neg offset
* Update normalization API & tests
* Add glwe tensoring test
* Add relinearization + preliminary test
* Fix GGLWEToGGSW key infos
* Add (X,Y) convolution by const (1, Y) poly
* Faster normalization test + add bench for cnv_by_const
* Update changelog
This commit is contained in:
Jean-Philippe Bossuat
2025-12-21 16:56:42 +01:00
committed by GitHub
parent 76424d0ab5
commit 4e90e08a71
219 changed files with 6571 additions and 5041 deletions

View File

@@ -120,13 +120,13 @@ where
R: GGSWInfos,
A: GGLWEInfos,
{
let base2k_tsk: usize = tsk_infos.base2k().into();
let tsk_base2k: usize = tsk_infos.base2k().into();
let rank: usize = res_infos.rank().into();
let cols: usize = rank + 1;
let res_size: usize = res_infos.size();
let a_size: usize = res_infos.max_k().as_usize().div_ceil(base2k_tsk);
let a_size: usize = res_infos.max_k().as_usize().div_ceil(tsk_base2k);
let a_0: usize = VecZnx::bytes_of(self.n(), 1, a_size);
let a_dft: usize = self.bytes_of_vec_znx_dft(cols - 1, a_size);
@@ -146,15 +146,15 @@ where
let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref();
let base2k_res: usize = res.base2k().into();
let base2k_tsk: usize = tsk.base2k().into();
let res_base2k: usize = res.base2k().into();
let tsk_base2k: usize = tsk.base2k().into();
assert!(scratch.available() >= self.ggsw_expand_rows_tmp_bytes(res, tsk));
let rank: usize = res.rank().into();
let cols: usize = rank + 1;
let res_conv_size: usize = res.max_k().as_usize().div_ceil(base2k_tsk);
let res_conv_size: usize = res.max_k().as_usize().div_ceil(tsk_base2k);
let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols - 1, res_conv_size);
let (mut a_0, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, res_conv_size);
@@ -163,33 +163,17 @@ where
for row in 0..res.dnum().as_usize() {
let glwe_mi_1: &GLWE<&[u8]> = &res.at(row, 0);
if base2k_res == base2k_tsk {
if res_base2k == tsk_base2k {
for col_i in 0..cols - 1 {
self.vec_znx_dft_apply(1, 0, &mut a_dft, col_i, glwe_mi_1.data(), col_i + 1);
}
self.vec_znx_copy(&mut a_0, 0, glwe_mi_1.data(), 0);
} else {
for i in 0..cols - 1 {
self.vec_znx_normalize(
base2k_tsk,
&mut a_0,
0,
base2k_res,
glwe_mi_1.data(),
i + 1,
scratch_2,
);
self.vec_znx_normalize(&mut a_0, tsk_base2k, 0, 0, glwe_mi_1.data(), res_base2k, i + 1, scratch_2);
self.vec_znx_dft_apply(1, 0, &mut a_dft, i, &a_0, 0);
}
self.vec_znx_normalize(
base2k_tsk,
&mut a_0,
0,
base2k_res,
glwe_mi_1.data(),
0,
scratch_2,
);
self.vec_znx_normalize(&mut a_0, tsk_base2k, 0, 0, glwe_mi_1.data(), res_base2k, 0, scratch_2);
}
ggsw_expand_rows_internal(self, row, res, &a_0, &a_dft, tsk, scratch_2)
@@ -267,13 +251,16 @@ fn ggsw_expand_rows_internal<M, R, C, A, T, BE: Backend>(
// (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2)
module.vec_znx_big_add_small_inplace(&mut res_big, col, a_0, 0);
let res_base2k: usize = res.base2k().as_usize();
for j in 0..cols {
module.vec_znx_big_normalize(
res.base2k().as_usize(),
res.at_mut(row, col).data_mut(),
res_base2k,
0,
j,
tsk.base2k().as_usize(),
&res_big,
tsk.base2k().as_usize(),
j,
scratch_1,
);

View File

@@ -56,12 +56,8 @@ where
rank: Rank(1),
};
GLWE::bytes_of(
self.n().into(),
lwe_infos.base2k(),
lwe_infos.k(),
1u32.into(),
) + GLWE::bytes_of_from_infos(glwe_infos)
GLWE::bytes_of(self.n().into(), lwe_infos.base2k(), lwe_infos.k(), 1u32.into())
+ GLWE::bytes_of_from_infos(glwe_infos)
+ self.glwe_keyswitch_tmp_bytes(&res_infos, glwe_infos, key_infos)
}

View File

@@ -73,11 +73,12 @@ where
}
self.vec_znx_normalize(
ksk.base2k().into(),
&mut glwe.data,
ksk.base2k().into(),
0,
0,
lwe.base2k().into(),
&a_conv,
lwe.base2k().into(),
0,
scratch_2,
);
@@ -89,11 +90,12 @@ where
}
self.vec_znx_normalize(
ksk.base2k().into(),
&mut glwe.data,
ksk.base2k().into(),
0,
1,
lwe.base2k().into(),
&a_conv,
lwe.base2k().into(),
0,
scratch_2,
);