Support for bivariate convolution & normalization with offset (#126)

* Add bivariate convolution
* Add pairwise convolution + tests + benches
* Add take_cnv_pvec_[left/right] to Scratch & update CHANGELOG.md
* Cross-base2k normalization with positive offset
* Fix clippy lints & CI doctest AVX compile error
* Streamline bounds derivation for normalization
* Working cross-base2k normalization with pos/neg offset
* Update normalization API & tests (see the sketch after this list)
* Add glwe tensoring test
* Add relinearization + preliminary test
* Fix GGLWEToGGSW key infos
* Add (X,Y) convolution by const (1, Y) poly
* Faster normalization test + add bench for cnv_by_const
* Update changelog
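
The sketch referenced above: a minimal, self-contained model of what "normalization with offset" computes. It assumes (my assumption, not the crate's documentation) that a value is reduced into centered base-2^k digits and that the offset acts as a signed pre-shift in bits, consistent with the `-(base2k as i64)..=(base2k as i64)` range the updated test loops over further down. The crate's `vec_znx_normalize` works limb-wise on polynomials; this toy is scalar only.

```rust
/// Toy model: decompose `x * 2^offset` into `size` centered base-2^k digits
/// d[0..size] (least significant first), each in [-2^(k-1), 2^(k-1)), so that
/// x * 2^offset == sum_i d[i] * 2^(k*i) + (carry beyond the last digit).
fn normalize_centered(mut x: i128, k: u32, offset: i32, size: usize) -> Vec<i64> {
    // Hypothetical offset semantics: a signed bit shift applied first.
    if offset >= 0 {
        x <<= offset as u32;
    } else {
        x >>= (-offset) as u32; // arithmetic shift: floors toward -infinity
    }
    let base: i128 = 1 << k;
    let half: i128 = base >> 1;
    let mut digits = vec![0i64; size];
    for d in digits.iter_mut() {
        let mut r = x & (base - 1); // low k bits, in [0, 2^k)
        if r >= half {
            r -= base; // recenter into [-2^(k-1), 2^(k-1))
        }
        *d = r as i64;
        x = (x - r) >> k; // exact: x - r is divisible by 2^k
    }
    digits
}

fn main() {
    // 1000 = -8 + (-1)*16 + 4*256, with every digit in [-8, 8).
    assert_eq!(normalize_centered(1000, 4, 0, 4), vec![-8, -1, 4, 0]);
    // offset = 4 first multiplies the input by 2^4, shifting the digits up.
    assert_eq!(normalize_centered(1000, 4, 4, 4), vec![0, -8, -1, 4]);
}
```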
Jean-Philippe Bossuat authored on 2025-12-21 16:56:42 +01:00; committed by GitHub
commit 4e90e08a71 (parent 76424d0ab5)
219 changed files with 6571 additions and 5041 deletions


@@ -7,8 +7,9 @@ use crate::{
 VecZnxFillNormal, VecZnxFillUniform, VecZnxLsh, VecZnxLshInplace, VecZnxLshTmpBytes, VecZnxMergeRings,
 VecZnxMergeRingsTmpBytes, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxMulXpMinusOneInplaceTmpBytes,
 VecZnxNegate, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate,
-VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxRsh, VecZnxRshInplace, VecZnxSplitRing, VecZnxSplitRingTmpBytes,
-VecZnxSub, VecZnxSubInplace, VecZnxSubNegateInplace, VecZnxSubScalar, VecZnxSubScalarInplace, VecZnxSwitchRing,
+VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxRsh, VecZnxRshInplace, VecZnxRshTmpBytes, VecZnxSplitRing,
+VecZnxSplitRingTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSubNegateInplace, VecZnxSubScalar, VecZnxSubScalarInplace,
+VecZnxSwitchRing,
 },
 layouts::{Backend, DigestU64, FillUniform, Module, ScalarZnx, ScratchOwned, VecZnx, ZnxInfos, ZnxView, ZnxViewMut},
 reference::znx::znx_copy_ref,
@@ -341,10 +342,7 @@ where
 let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_merge_rings_tmp_bytes());
 for a_size in [1, 2, 3, 4] {
-let mut a: [VecZnx<Vec<u8>>; 2] = [
-    VecZnx::alloc(n >> 1, cols, a_size),
-    VecZnx::alloc(n >> 1, cols, a_size),
-];
+let mut a: [VecZnx<Vec<u8>>; 2] = [VecZnx::alloc(n >> 1, cols, a_size), VecZnx::alloc(n >> 1, cols, a_size)];
 a.iter_mut().for_each(|ai| {
 ai.fill_uniform(base2k, &mut source);
@@ -549,26 +547,20 @@ where
 let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
 let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
-// Set d to garbage
-res_ref.fill_uniform(base2k, &mut source);
-res_test.fill_uniform(base2k, &mut source);
-// Reference
-for i in 0..cols {
-    module_ref.vec_znx_normalize(base2k, &mut res_ref, i, base2k, &a, i, scratch_ref.borrow());
-    module_test.vec_znx_normalize(
-        base2k,
-        &mut res_test,
-        i,
-        base2k,
-        &a,
-        i,
-        scratch_test.borrow(),
-    );
-}
-assert_eq!(a.digest_u64(), a_digest);
-assert_eq!(res_ref, res_test);
+for res_offset in -(base2k as i64)..=(base2k as i64) {
+    // Set d to garbage
+    res_ref.fill_uniform(base2k, &mut source);
+    res_test.fill_uniform(base2k, &mut source);
+    // Reference
+    for i in 0..cols {
+        module_ref.vec_znx_normalize(&mut res_ref, base2k, res_offset, i, &a, base2k, i, scratch_ref.borrow());
+        module_test.vec_znx_normalize(&mut res_test, base2k, res_offset, i, &a, base2k, i, scratch_test.borrow());
+    }
+    assert_eq!(a.digest_u64(), a_digest);
+    assert_eq!(res_ref, res_test);
+}
 }
 }
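
Reading the argument order off the new call sites above: the destination now comes first, followed by its own base-2^k and a signed `res_offset`, then the source with its base-2^k. A stand-alone sketch of that shape; the parameter names and types here are my guesses, and only the order and count of the arguments come from the hunk:

```rust
// Stub types so this compiles on its own; the real ones live in the crate.
pub struct VecZnx;
pub struct Scratch;

// Hypothetical trait mirroring the call sites above; `res_offset: i64`
// matches the test's `-(base2k as i64)..=(base2k as i64)` loop variable.
pub trait VecZnxNormalizeSketch {
    #[allow(clippy::too_many_arguments)]
    fn vec_znx_normalize(
        &self,
        res: &mut VecZnx,
        res_base2k: usize,
        res_offset: i64,
        res_col: usize,
        a: &VecZnx,
        a_base2k: usize,
        a_col: usize,
        scratch: &mut Scratch,
    );
}
```

Separate base-2^k parameters for `res` and `a` would line up with the "cross-base2k normalization" bullets in the commit message, though the test above happens to pass the same `base2k` for both.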
@@ -718,10 +710,7 @@ where
 })
 } else {
 let std: f64 = a.stats(base2k, col_i).std();
-assert!(
-    (std - one_12_sqrt).abs() < 0.01,
-    "std={std} ~!= {one_12_sqrt}",
-);
+assert!((std - one_12_sqrt).abs() < 0.01, "std={std} ~!= {one_12_sqrt}",);
 }
 })
 });
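
For context on the constant being checked: if `fill_uniform` yields coefficients uniform on a centered unit interval (an assumption on my part; only the assertion is visible in this hunk), the expected standard deviation is

$$\operatorname{Var}(U) = \int_{-1/2}^{1/2} u^2 \, du = \frac{1}{12}, \qquad \sigma_U = \frac{1}{\sqrt{12}} \approx 0.2887,$$

which is the `one_12_sqrt` the assertion compares against, within a 0.01 tolerance.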
@@ -783,11 +772,7 @@ where
 })
 } else {
 let std: f64 = a.stats(base2k, col_i).std() * k_f64;
-assert!(
-    (std - sigma * sqrt2).abs() < 0.1,
-    "std={std} ~!= {}",
-    sigma * sqrt2
-);
+assert!((std - sigma * sqrt2).abs() < 0.1, "std={std} ~!= {}", sigma * sqrt2);
 }
 })
 });
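
Similarly, the `sigma * sqrt2` target is what one gets when two independent N(0, σ²) samples are summed, since variances add (that this is what the test does is my inference; the setup lies outside the hunk):

$$\operatorname{Var}(X+Y) = \sigma^2 + \sigma^2 = 2\sigma^2 \quad\Longrightarrow\quad \operatorname{std}(X+Y) = \sigma\sqrt{2},$$

checked here with a looser 0.1 tolerance after rescaling by `k_f64`.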
@@ -872,9 +857,9 @@ where
 pub fn test_vec_znx_rsh<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
 where
-    Module<BR>: VecZnxRsh<BR> + VecZnxLshTmpBytes,
+    Module<BR>: VecZnxRsh<BR> + VecZnxRshTmpBytes,
     ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
-    Module<BT>: VecZnxRsh<BT> + VecZnxLshTmpBytes,
+    Module<BT>: VecZnxRsh<BT> + VecZnxRshTmpBytes,
     ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
 {
 assert_eq!(module_ref.n(), module_test.n());
@@ -882,8 +867,8 @@ where
 let mut source: Source = Source::new([0u8; 32]);
 let cols: usize = 2;
-let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_lsh_tmp_bytes());
-let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_lsh_tmp_bytes());
+let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_rsh_tmp_bytes());
+let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_rsh_tmp_bytes());
 for a_size in [1, 2, 3, 4] {
 let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
@@ -914,9 +899,9 @@ where
 pub fn test_vec_znx_rsh_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
 where
-    Module<BR>: VecZnxRshInplace<BR> + VecZnxLshTmpBytes,
+    Module<BR>: VecZnxRshInplace<BR> + VecZnxRshTmpBytes,
     ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
-    Module<BT>: VecZnxRshInplace<BT> + VecZnxLshTmpBytes,
+    Module<BT>: VecZnxRshInplace<BT> + VecZnxRshTmpBytes,
     ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
 {
 assert_eq!(module_ref.n(), module_test.n());
@@ -924,8 +909,8 @@ where
 let mut source: Source = Source::new([0u8; 32]);
 let cols: usize = 2;
-let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_lsh_tmp_bytes());
-let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_lsh_tmp_bytes());
+let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_rsh_tmp_bytes());
+let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_rsh_tmp_bytes());
 for res_size in [1, 2, 3, 4] {
 for k in 0..base2k * res_size {
@@ -966,15 +951,11 @@ where
 let a_digest = a.digest_u64();
 for res_size in [1, 2, 3, 4] {
-let mut res_ref: [VecZnx<Vec<u8>>; 2] = [
-    VecZnx::alloc(n >> 1, cols, res_size),
-    VecZnx::alloc(n >> 1, cols, res_size),
-];
+let mut res_ref: [VecZnx<Vec<u8>>; 2] =
+    [VecZnx::alloc(n >> 1, cols, res_size), VecZnx::alloc(n >> 1, cols, res_size)];
-let mut res_test: [VecZnx<Vec<u8>>; 2] = [
-    VecZnx::alloc(n >> 1, cols, res_size),
-    VecZnx::alloc(n >> 1, cols, res_size),
-];
+let mut res_test: [VecZnx<Vec<u8>>; 2] =
+    [VecZnx::alloc(n >> 1, cols, res_size), VecZnx::alloc(n >> 1, cols, res_size)];
 res_ref.iter_mut().for_each(|ri| {
 ri.fill_uniform(base2k, &mut source);