mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 21:26:41 +01:00
Support for bivariate convolution & normalization with offset (#126)
* Add bivariate-convolution * Add pair-wise convolution + tests + benches * Add take_cnv_pvec_[left/right] to Scratch & updated CHANGELOG.md * cross-base2k normalization with positive offset * clippy & fix CI doctest avx compile error * more streamlined bounds derivation for normalization * Working cross-base2k normalization with pos/neg offset * Update normalization API & tests * Add glwe tensoring test * Add relinearization + preliminary test * Fix GGLWEToGGSW key infos * Add (X,Y) convolution by const (1, Y) poly * Faster normalization test + add bench for cnv_by_const * Update changelog
This commit is contained in:
committed by
GitHub
parent
76424d0ab5
commit
4e90e08a71
@@ -120,13 +120,13 @@ where
|
||||
R: GGSWInfos,
|
||||
A: GGLWEInfos,
|
||||
{
|
||||
let base2k_tsk: usize = tsk_infos.base2k().into();
|
||||
let tsk_base2k: usize = tsk_infos.base2k().into();
|
||||
|
||||
let rank: usize = res_infos.rank().into();
|
||||
let cols: usize = rank + 1;
|
||||
|
||||
let res_size: usize = res_infos.size();
|
||||
let a_size: usize = res_infos.max_k().as_usize().div_ceil(base2k_tsk);
|
||||
let a_size: usize = res_infos.max_k().as_usize().div_ceil(tsk_base2k);
|
||||
|
||||
let a_0: usize = VecZnx::bytes_of(self.n(), 1, a_size);
|
||||
let a_dft: usize = self.bytes_of_vec_znx_dft(cols - 1, a_size);
|
||||
@@ -146,15 +146,15 @@ where
|
||||
let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
|
||||
let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref();
|
||||
|
||||
let base2k_res: usize = res.base2k().into();
|
||||
let base2k_tsk: usize = tsk.base2k().into();
|
||||
let res_base2k: usize = res.base2k().into();
|
||||
let tsk_base2k: usize = tsk.base2k().into();
|
||||
|
||||
assert!(scratch.available() >= self.ggsw_expand_rows_tmp_bytes(res, tsk));
|
||||
|
||||
let rank: usize = res.rank().into();
|
||||
let cols: usize = rank + 1;
|
||||
|
||||
let res_conv_size: usize = res.max_k().as_usize().div_ceil(base2k_tsk);
|
||||
let res_conv_size: usize = res.max_k().as_usize().div_ceil(tsk_base2k);
|
||||
|
||||
let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols - 1, res_conv_size);
|
||||
let (mut a_0, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, res_conv_size);
|
||||
@@ -163,33 +163,17 @@ where
|
||||
for row in 0..res.dnum().as_usize() {
|
||||
let glwe_mi_1: &GLWE<&[u8]> = &res.at(row, 0);
|
||||
|
||||
if base2k_res == base2k_tsk {
|
||||
if res_base2k == tsk_base2k {
|
||||
for col_i in 0..cols - 1 {
|
||||
self.vec_znx_dft_apply(1, 0, &mut a_dft, col_i, glwe_mi_1.data(), col_i + 1);
|
||||
}
|
||||
self.vec_znx_copy(&mut a_0, 0, glwe_mi_1.data(), 0);
|
||||
} else {
|
||||
for i in 0..cols - 1 {
|
||||
self.vec_znx_normalize(
|
||||
base2k_tsk,
|
||||
&mut a_0,
|
||||
0,
|
||||
base2k_res,
|
||||
glwe_mi_1.data(),
|
||||
i + 1,
|
||||
scratch_2,
|
||||
);
|
||||
self.vec_znx_normalize(&mut a_0, tsk_base2k, 0, 0, glwe_mi_1.data(), res_base2k, i + 1, scratch_2);
|
||||
self.vec_znx_dft_apply(1, 0, &mut a_dft, i, &a_0, 0);
|
||||
}
|
||||
self.vec_znx_normalize(
|
||||
base2k_tsk,
|
||||
&mut a_0,
|
||||
0,
|
||||
base2k_res,
|
||||
glwe_mi_1.data(),
|
||||
0,
|
||||
scratch_2,
|
||||
);
|
||||
self.vec_znx_normalize(&mut a_0, tsk_base2k, 0, 0, glwe_mi_1.data(), res_base2k, 0, scratch_2);
|
||||
}
|
||||
|
||||
ggsw_expand_rows_internal(self, row, res, &a_0, &a_dft, tsk, scratch_2)
|
||||
@@ -267,13 +251,16 @@ fn ggsw_expand_rows_internal<M, R, C, A, T, BE: Backend>(
|
||||
// (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2)
|
||||
module.vec_znx_big_add_small_inplace(&mut res_big, col, a_0, 0);
|
||||
|
||||
let res_base2k: usize = res.base2k().as_usize();
|
||||
|
||||
for j in 0..cols {
|
||||
module.vec_znx_big_normalize(
|
||||
res.base2k().as_usize(),
|
||||
res.at_mut(row, col).data_mut(),
|
||||
res_base2k,
|
||||
0,
|
||||
j,
|
||||
tsk.base2k().as_usize(),
|
||||
&res_big,
|
||||
tsk.base2k().as_usize(),
|
||||
j,
|
||||
scratch_1,
|
||||
);
|
||||
|
||||
@@ -56,12 +56,8 @@ where
|
||||
rank: Rank(1),
|
||||
};
|
||||
|
||||
GLWE::bytes_of(
|
||||
self.n().into(),
|
||||
lwe_infos.base2k(),
|
||||
lwe_infos.k(),
|
||||
1u32.into(),
|
||||
) + GLWE::bytes_of_from_infos(glwe_infos)
|
||||
GLWE::bytes_of(self.n().into(), lwe_infos.base2k(), lwe_infos.k(), 1u32.into())
|
||||
+ GLWE::bytes_of_from_infos(glwe_infos)
|
||||
+ self.glwe_keyswitch_tmp_bytes(&res_infos, glwe_infos, key_infos)
|
||||
}
|
||||
|
||||
|
||||
@@ -73,11 +73,12 @@ where
|
||||
}
|
||||
|
||||
self.vec_znx_normalize(
|
||||
ksk.base2k().into(),
|
||||
&mut glwe.data,
|
||||
ksk.base2k().into(),
|
||||
0,
|
||||
0,
|
||||
lwe.base2k().into(),
|
||||
&a_conv,
|
||||
lwe.base2k().into(),
|
||||
0,
|
||||
scratch_2,
|
||||
);
|
||||
@@ -89,11 +90,12 @@ where
|
||||
}
|
||||
|
||||
self.vec_znx_normalize(
|
||||
ksk.base2k().into(),
|
||||
&mut glwe.data,
|
||||
ksk.base2k().into(),
|
||||
0,
|
||||
1,
|
||||
lwe.base2k().into(),
|
||||
&a_conv,
|
||||
lwe.base2k().into(),
|
||||
0,
|
||||
scratch_2,
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user