mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
Support for bivariate convolution & normalization with offset (#126)
* Add bivariate-convolution * Add pair-wise convolution + tests + benches * Add take_cnv_pvec_[left/right] to Scratch & updated CHANGELOG.md * cross-base2k normalization with positive offset * clippy & fix CI doctest avx compile error * more streamlined bounds derivation for normalization * Working cross-base2k normalization with pos/neg offset * Update normalization API & tests * Add glwe tensoring test * Add relinearization + preliminary test * Fix GGLWEToGGSW key infos * Add (X,Y) convolution by const (1, Y) poly * Faster normalization test + add bench for cnv_by_const * Update changelog
This commit is contained in:
committed by
GitHub
parent
76424d0ab5
commit
4e90e08a71
@@ -7,8 +7,9 @@ use crate::{
|
||||
VecZnxFillNormal, VecZnxFillUniform, VecZnxLsh, VecZnxLshInplace, VecZnxLshTmpBytes, VecZnxMergeRings,
|
||||
VecZnxMergeRingsTmpBytes, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxMulXpMinusOneInplaceTmpBytes,
|
||||
VecZnxNegate, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate,
|
||||
VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxRsh, VecZnxRshInplace, VecZnxSplitRing, VecZnxSplitRingTmpBytes,
|
||||
VecZnxSub, VecZnxSubInplace, VecZnxSubNegateInplace, VecZnxSubScalar, VecZnxSubScalarInplace, VecZnxSwitchRing,
|
||||
VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxRsh, VecZnxRshInplace, VecZnxRshTmpBytes, VecZnxSplitRing,
|
||||
VecZnxSplitRingTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSubNegateInplace, VecZnxSubScalar, VecZnxSubScalarInplace,
|
||||
VecZnxSwitchRing,
|
||||
},
|
||||
layouts::{Backend, DigestU64, FillUniform, Module, ScalarZnx, ScratchOwned, VecZnx, ZnxInfos, ZnxView, ZnxViewMut},
|
||||
reference::znx::znx_copy_ref,
|
||||
@@ -341,10 +342,7 @@ where
|
||||
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_merge_rings_tmp_bytes());
|
||||
|
||||
for a_size in [1, 2, 3, 4] {
|
||||
let mut a: [VecZnx<Vec<u8>>; 2] = [
|
||||
VecZnx::alloc(n >> 1, cols, a_size),
|
||||
VecZnx::alloc(n >> 1, cols, a_size),
|
||||
];
|
||||
let mut a: [VecZnx<Vec<u8>>; 2] = [VecZnx::alloc(n >> 1, cols, a_size), VecZnx::alloc(n >> 1, cols, a_size)];
|
||||
|
||||
a.iter_mut().for_each(|ai| {
|
||||
ai.fill_uniform(base2k, &mut source);
|
||||
@@ -549,26 +547,20 @@ where
|
||||
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
|
||||
// Set d to garbage
|
||||
res_ref.fill_uniform(base2k, &mut source);
|
||||
res_test.fill_uniform(base2k, &mut source);
|
||||
for res_offset in -(base2k as i64)..=(base2k as i64) {
|
||||
// Set d to garbage
|
||||
res_ref.fill_uniform(base2k, &mut source);
|
||||
res_test.fill_uniform(base2k, &mut source);
|
||||
|
||||
// Reference
|
||||
for i in 0..cols {
|
||||
module_ref.vec_znx_normalize(base2k, &mut res_ref, i, base2k, &a, i, scratch_ref.borrow());
|
||||
module_test.vec_znx_normalize(
|
||||
base2k,
|
||||
&mut res_test,
|
||||
i,
|
||||
base2k,
|
||||
&a,
|
||||
i,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
// Reference
|
||||
for i in 0..cols {
|
||||
module_ref.vec_znx_normalize(&mut res_ref, base2k, res_offset, i, &a, base2k, i, scratch_ref.borrow());
|
||||
module_test.vec_znx_normalize(&mut res_test, base2k, res_offset, i, &a, base2k, i, scratch_test.borrow());
|
||||
}
|
||||
|
||||
assert_eq!(a.digest_u64(), a_digest);
|
||||
assert_eq!(res_ref, res_test);
|
||||
}
|
||||
|
||||
assert_eq!(a.digest_u64(), a_digest);
|
||||
assert_eq!(res_ref, res_test);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -718,10 +710,7 @@ where
|
||||
})
|
||||
} else {
|
||||
let std: f64 = a.stats(base2k, col_i).std();
|
||||
assert!(
|
||||
(std - one_12_sqrt).abs() < 0.01,
|
||||
"std={std} ~!= {one_12_sqrt}",
|
||||
);
|
||||
assert!((std - one_12_sqrt).abs() < 0.01, "std={std} ~!= {one_12_sqrt}",);
|
||||
}
|
||||
})
|
||||
});
|
||||
@@ -783,11 +772,7 @@ where
|
||||
})
|
||||
} else {
|
||||
let std: f64 = a.stats(base2k, col_i).std() * k_f64;
|
||||
assert!(
|
||||
(std - sigma * sqrt2).abs() < 0.1,
|
||||
"std={std} ~!= {}",
|
||||
sigma * sqrt2
|
||||
);
|
||||
assert!((std - sigma * sqrt2).abs() < 0.1, "std={std} ~!= {}", sigma * sqrt2);
|
||||
}
|
||||
})
|
||||
});
|
||||
@@ -872,9 +857,9 @@ where
|
||||
|
||||
pub fn test_vec_znx_rsh<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
|
||||
where
|
||||
Module<BR>: VecZnxRsh<BR> + VecZnxLshTmpBytes,
|
||||
Module<BR>: VecZnxRsh<BR> + VecZnxRshTmpBytes,
|
||||
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
|
||||
Module<BT>: VecZnxRsh<BT> + VecZnxLshTmpBytes,
|
||||
Module<BT>: VecZnxRsh<BT> + VecZnxRshTmpBytes,
|
||||
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
|
||||
{
|
||||
assert_eq!(module_ref.n(), module_test.n());
|
||||
@@ -882,8 +867,8 @@ where
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let cols: usize = 2;
|
||||
|
||||
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_lsh_tmp_bytes());
|
||||
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_lsh_tmp_bytes());
|
||||
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_rsh_tmp_bytes());
|
||||
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_rsh_tmp_bytes());
|
||||
|
||||
for a_size in [1, 2, 3, 4] {
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
|
||||
@@ -914,9 +899,9 @@ where
|
||||
|
||||
pub fn test_vec_znx_rsh_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
|
||||
where
|
||||
Module<BR>: VecZnxRshInplace<BR> + VecZnxLshTmpBytes,
|
||||
Module<BR>: VecZnxRshInplace<BR> + VecZnxRshTmpBytes,
|
||||
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
|
||||
Module<BT>: VecZnxRshInplace<BT> + VecZnxLshTmpBytes,
|
||||
Module<BT>: VecZnxRshInplace<BT> + VecZnxRshTmpBytes,
|
||||
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
|
||||
{
|
||||
assert_eq!(module_ref.n(), module_test.n());
|
||||
@@ -924,8 +909,8 @@ where
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let cols: usize = 2;
|
||||
|
||||
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_lsh_tmp_bytes());
|
||||
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_lsh_tmp_bytes());
|
||||
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_rsh_tmp_bytes());
|
||||
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_rsh_tmp_bytes());
|
||||
|
||||
for res_size in [1, 2, 3, 4] {
|
||||
for k in 0..base2k * res_size {
|
||||
@@ -966,15 +951,11 @@ where
|
||||
let a_digest = a.digest_u64();
|
||||
|
||||
for res_size in [1, 2, 3, 4] {
|
||||
let mut res_ref: [VecZnx<Vec<u8>>; 2] = [
|
||||
VecZnx::alloc(n >> 1, cols, res_size),
|
||||
VecZnx::alloc(n >> 1, cols, res_size),
|
||||
];
|
||||
let mut res_ref: [VecZnx<Vec<u8>>; 2] =
|
||||
[VecZnx::alloc(n >> 1, cols, res_size), VecZnx::alloc(n >> 1, cols, res_size)];
|
||||
|
||||
let mut res_test: [VecZnx<Vec<u8>>; 2] = [
|
||||
VecZnx::alloc(n >> 1, cols, res_size),
|
||||
VecZnx::alloc(n >> 1, cols, res_size),
|
||||
];
|
||||
let mut res_test: [VecZnx<Vec<u8>>; 2] =
|
||||
[VecZnx::alloc(n >> 1, cols, res_size), VecZnx::alloc(n >> 1, cols, res_size)];
|
||||
|
||||
res_ref.iter_mut().for_each(|ri| {
|
||||
ri.fill_uniform(base2k, &mut source);
|
||||
|
||||
Reference in New Issue
Block a user