Support for bivariate convolution & normalization with offset (#126)

* Add bivariate convolution
* Add pairwise convolution + tests + benches
* Add take_cnv_pvec_[left/right] to Scratch & update CHANGELOG.md
* Add cross-base2k normalization with positive offset
* Fix clippy lints & a CI doctest AVX compile error
* Streamline bounds derivation for normalization
* Get cross-base2k normalization working with positive/negative offsets
* Update normalization API & tests
* Add GLWE tensoring test
* Add relinearization + preliminary test
* Fix GGLWEToGGSW key infos
* Add (X,Y) convolution by a constant (1,Y) poly
* Speed up normalization test + add bench for cnv_by_const
* Update changelog
Author: Jean-Philippe Bossuat
Date: 2025-12-21 16:56:42 +01:00
Committed by: GitHub
Parent: 76424d0ab5
Commit: 4e90e08a71
219 changed files with 6571 additions and 5041 deletions
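The one change repeated across most files below is the normalization signature: the destination now comes first and gains a signed result offset, i.e. vec_znx_big_normalize(res_base2k, &mut res, res_col, a_base2k, &a, a_col, scratch) becomes vec_znx_big_normalize(&mut res, res_base2k, res_offset, res_col, &a, a_base2k, a_col, scratch), and likewise for vec_znx_normalize; the mechanical call-site updates pass 0 where the old calls had no offset. A minimal sketch of the new call shape, assuming the same crate-internal imports the tests below use; the helper name normalize_all_cols is illustrative, not part of the commit:

use crate::{
    api::VecZnxBigNormalize,
    layouts::{Backend, Scratch, VecZnx, VecZnxBig, ZnxInfos},
};

// Hypothetical helper: normalize every column of `big` into `res`, shifting the
// result by `res_offset` (the new third argument; the tests pass 0 for "no offset").
pub fn normalize_all_cols<M, BE: Backend>(
    module: &M,
    base2k: usize,
    res_offset: i64,
    res: &mut VecZnx<Vec<u8>>,
    big: &VecZnxBig<Vec<u8>, BE>,
    scratch: &mut Scratch<BE>,
) where
    M: VecZnxBigNormalize<BE>,
{
    for j in 0..res.cols() {
        // Argument order: (res, res_base2k, res_offset, res_col, a, a_base2k, a_col, scratch).
        module.vec_znx_big_normalize(&mut *res, base2k, res_offset, j, big, base2k, j, &mut *scratch);
    }
}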


@@ -1,19 +1,80 @@
use rand::RngCore;
use crate::{
api::{
CnvPVecAlloc, Convolution, ModuleN, ScratchOwnedAlloc, ScratchOwnedBorrow, ScratchTakeBasic, TakeSlice, VecZnxAdd,
VecZnxBigAlloc, VecZnxBigNormalize, VecZnxCopy, VecZnxDftAlloc, VecZnxDftApply, VecZnxIdftApplyTmpA,
VecZnxNormalizeInplace,
},
layouts::{
Backend, CnvPVecL, CnvPVecR, FillUniform, Scratch, ScratchOwned, VecZnx, VecZnxBig, VecZnxDft, VecZnxToMut, VecZnxToRef,
ZnxInfos, ZnxView, ZnxViewMut, ZnxZero,
},
source::Source,
};
pub fn test_convolution_by_const<M, BE: Backend>(module: &M)
where
M: ModuleN + Convolution<BE> + VecZnxBigNormalize<BE> + VecZnxNormalizeInplace<BE> + VecZnxBigAlloc<BE>,
Scratch<BE>: ScratchTakeBasic,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
{
let mut source: Source = Source::new([0u8; 32]);
let base2k: usize = 12;
let a_cols: usize = 2;
let a_size: usize = 15;
let b_size: usize = 15;
let res_size: usize = a_size + b_size;
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), a_cols, a_size);
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, b_size);
let mut res_want: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, res_size);
let mut res_have: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, res_size);
let mut res_big: VecZnxBig<Vec<u8>, BE> = module.vec_znx_big_alloc(1, res_size);
a.fill_uniform(base2k, &mut source);
let mut b_const = vec![0i64; b_size];
let mask = (1 << base2k) - 1;
for (j, x) in b_const[..1].iter_mut().enumerate() {
let r = source.next_u64() & mask;
*x = ((r << (64 - base2k)) as i64) >> (64 - base2k);
b.at_mut(0, j)[0] = *x;
}
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(module.cnv_by_const_apply_tmp_bytes(res_size, 0, a_size, b_size));
for a_col in 0..a.cols() {
for offset in 0..res_size {
module.cnv_by_const_apply(&mut res_big, offset, 0, &a, a_col, &b_const, scratch.borrow());
module.vec_znx_big_normalize(&mut res_have, base2k, 0, 0, &res_big, base2k, 0, scratch.borrow());
bivariate_convolution_naive(
module,
base2k,
(offset + 1) as i64,
&mut res_want,
0,
&a,
a_col,
&b,
0,
scratch.borrow(),
);
assert_eq!(res_want, res_have);
}
}
}
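As a cross-check on the by-const path above: the test fills only coefficient 0 of the constant, and negacyclic convolution with such a constant polynomial reduces to coefficient-wise scaling. A self-contained sketch in plain i64 arithmetic, using the same loop as negacyclic_convolution_naive at the end of this file:

// Standalone: conv(a, c * X^0) = c * a in Z[X]/(X^4 + 1).
fn main() {
    let a = [3i64, -1, 4, 7];
    let c = 5i64;
    let b = [c, 0, 0, 0];
    let mut res = [0i64; 4];
    let n = res.len();
    for i in 0..n {
        let lim = n - i;
        for j in 0..lim {
            res[i + j] += a[i] * b[j]; // only b[0] is non-zero
        }
        for j in lim..n {
            res[i + j - n] -= a[i] * b[j]; // wrapped terms are all zero here
        }
    }
    assert_eq!(res, [3 * c, -c, 4 * c, 7 * c]);
}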
pub fn test_convolution<M, BE: Backend>(module: &M)
where
M: ModuleN
+ Convolution<BE>
+ CnvPVecAlloc<BE>
+ VecZnxDftAlloc<BE>
+ VecZnxDftApply<BE>
+ VecZnxIdftApplyTmpA<BE>
@@ -27,56 +88,199 @@ where
let base2k: usize = 12;
let a_cols: usize = 2;
let b_cols: usize = 2;
let a_size: usize = 15;
let b_size: usize = 15;
let res_size: usize = a_size + b_size;
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), a_cols, a_size);
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), b_cols, b_size);
let mut res_want: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, res_size);
let mut res_have: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, res_size);
let mut res_dft: VecZnxDft<Vec<u8>, BE> = module.vec_znx_dft_alloc(1, res_size);
let mut res_big: VecZnxBig<Vec<u8>, BE> = module.vec_znx_big_alloc(1, res_size);
a.fill_uniform(base2k, &mut source);
b.fill_uniform(base2k, &mut source);
let mut a_prep: CnvPVecL<Vec<u8>, BE> = module.cnv_pvec_left_alloc(a_cols, a_size);
let mut b_prep: CnvPVecR<Vec<u8>, BE> = module.cnv_pvec_right_alloc(b_cols, b_size);
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(
module
.cnv_apply_dft_tmp_bytes(res_size, 0, a_size, b_size)
.max(module.cnv_prepare_left_tmp_bytes(res_size, a_size))
.max(module.cnv_prepare_right_tmp_bytes(res_size, b_size)),
);
module.cnv_prepare_left(&mut a_prep, &a, scratch.borrow());
module.cnv_prepare_right(&mut b_prep, &b, scratch.borrow());
for a_col in 0..a.cols() {
for b_col in 0..b.cols() {
for offset in 0..res_size {
module.cnv_apply_dft(&mut res_dft, offset, 0, &a_prep, a_col, &b_prep, b_col, scratch.borrow());
module.vec_znx_idft_apply_tmpa(&mut res_big, 0, &mut res_dft, 0);
module.vec_znx_big_normalize(&mut res_have, base2k, 0, 0, &res_big, base2k, 0, scratch.borrow());
bivariate_convolution_naive(
module,
base2k,
(offset + 1) as i64,
&mut res_want,
0,
&a,
a_col,
&b,
b_col,
scratch.borrow(),
);
assert_eq!(res_want, res_have);
}
}
}
}
pub fn test_convolution_pairwise<M, BE: Backend>(module: &M)
where
M: ModuleN
+ Convolution<BE>
+ CnvPVecAlloc<BE>
+ VecZnxDftAlloc<BE>
+ VecZnxDftApply<BE>
+ VecZnxIdftApplyTmpA<BE>
+ VecZnxBigNormalize<BE>
+ VecZnxNormalizeInplace<BE>
+ VecZnxBigAlloc<BE>
+ VecZnxAdd
+ VecZnxCopy,
Scratch<BE>: ScratchTakeBasic,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
{
let mut source: Source = Source::new([0u8; 32]);
let base2k: usize = 12;
let cols = 2;
let a_size: usize = 15;
let b_size: usize = 15;
let res_size: usize = a_size + b_size;
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), cols, a_size);
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), cols, b_size);
let mut tmp_a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, a_size);
let mut tmp_b: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, b_size);
let mut res_want: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, res_size);
let mut res_have: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, res_size);
let mut res_dft: VecZnxDft<Vec<u8>, BE> = module.vec_znx_dft_alloc(1, res_size);
let mut res_big: VecZnxBig<Vec<u8>, BE> = module.vec_znx_big_alloc(1, res_size);
a.fill_uniform(base2k, &mut source);
b.fill_uniform(base2k, &mut source);
let mut a_prep: CnvPVecL<Vec<u8>, BE> = module.cnv_pvec_left_alloc(cols, a_size);
let mut b_prep: CnvPVecR<Vec<u8>, BE> = module.cnv_pvec_right_alloc(cols, b_size);
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(
module
.cnv_pairwise_apply_dft_tmp_bytes(res_size, 0, a_size, b_size)
.max(module.cnv_prepare_left_tmp_bytes(res_size, a_size))
.max(module.cnv_prepare_right_tmp_bytes(res_size, b_size)),
);
module.cnv_prepare_left(&mut a_prep, &a, scratch.borrow());
module.cnv_prepare_right(&mut b_prep, &b, scratch.borrow());
for col_i in 0..cols {
for col_j in 0..cols {
for offset in 0..res_size {
module.cnv_pairwise_apply_dft(&mut res_dft, offset, 0, &a_prep, &b_prep, col_i, col_j, scratch.borrow());
module.vec_znx_idft_apply_tmpa(&mut res_big, 0, &mut res_dft, 0);
module.vec_znx_big_normalize(&mut res_have, base2k, 0, 0, &res_big, base2k, 0, scratch.borrow());
if col_i != col_j {
module.vec_znx_add(&mut tmp_a, 0, &a, col_i, &a, col_j);
module.vec_znx_add(&mut tmp_b, 0, &b, col_i, &b, col_j);
} else {
module.vec_znx_copy(&mut tmp_a, 0, &a, col_i);
module.vec_znx_copy(&mut tmp_b, 0, &b, col_j);
}
bivariate_convolution_naive(
module,
base2k,
(offset + 1) as i64,
&mut res_want,
0,
&tmp_a,
0,
&tmp_b,
0,
scratch.borrow(),
);
assert_eq!(res_want, res_have);
}
}
}
}
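The pairwise test above validates cnv_pairwise_apply_dft against the naive convolution of summed columns; that comparison is justified by bilinearity of negacyclic convolution: conv(a + a', b + b') = conv(a, b) + conv(a, b') + conv(a', b) + conv(a', b'). A standalone check of that identity over i64 coefficients, reusing a copy of the naive routine defined at the end of this file:

fn conv(a: &[i64], b: &[i64]) -> Vec<i64> {
    // Copy of negacyclic_convolution_naive below, returning the result.
    let n = a.len();
    let mut res = vec![0i64; n];
    for i in 0..n {
        let (ai, lim) = (a[i], n - i);
        for j in 0..lim {
            res[i + j] += ai * b[j];
        }
        for j in lim..n {
            res[i + j - n] -= ai * b[j];
        }
    }
    res
}

fn main() {
    let (a0, a1) = ([1i64, 2, 0, -1], [0i64, 3, -2, 1]);
    let (b0, b1) = ([4i64, 0, 1, 0], [-1i64, 2, 0, 5]);
    let sa: Vec<i64> = a0.iter().zip(&a1).map(|(x, y)| x + y).collect();
    let sb: Vec<i64> = b0.iter().zip(&b1).map(|(x, y)| x + y).collect();
    let lhs = conv(&sa, &sb);
    let rhs: Vec<i64> = (0..4)
        .map(|k| conv(&a0, &b0)[k] + conv(&a0, &b1)[k] + conv(&a1, &b0)[k] + conv(&a1, &b1)[k])
        .collect();
    assert_eq!(lhs, rhs); // bilinearity
}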
#[allow(clippy::too_many_arguments)]
pub fn bivariate_convolution_naive<R, A, B, M, BE: Backend>(
module: &M,
base2k: usize,
k: i64,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
b: &B,
b_col: usize,
scratch: &mut Scratch<BE>,
) where
R: VecZnxToMut,
A: VecZnxToRef,
B: VecZnxToRef,
M: VecZnxNormalizeInplace<BE>,
Scratch<BE>: TakeSlice,
{
let res: &mut VecZnx<&mut [u8]> = &mut res.to_mut();
let a: &VecZnx<&[u8]> = &a.to_ref();
let b: &VecZnx<&[u8]> = &b.to_ref();
for j in 0..res.size() {
res.zero_at(res_col, j);
}
for a_limb in 0..a.size() {
for b_limb in 0..b.size() {
let res_scale_abs = k.unsigned_abs() as usize;
let mut res_limb: usize = a_limb + b_limb + 1;
if k <= 0 {
res_limb += res_scale_abs;
if res_limb < res.size() {
negacyclic_convolution_naive_add(res.at_mut(res_col, res_limb), a.at(a_col, a_limb), b.at(b_col, b_limb));
}
} else if res_limb >= res_scale_abs {
res_limb -= res_scale_abs;
if res_limb < res.size() {
negacyclic_convolution_naive_add(res.at_mut(res_col, res_limb), a.at(a_col, a_limb), b.at(b_col, b_limb));
}
}
}
}
module.vec_znx_normalize_inplace(base2k, res, res_col, scratch);
}
fn bivariate_tensoring_naive<R, A, B, M, BE: Backend>(
@@ -154,3 +358,18 @@ fn negacyclic_convolution_naive_add(res: &mut [i64], a: &[i64], b: &[i64]) {
}
}
}
fn negacyclic_convolution_naive(res: &mut [i64], a: &[i64], b: &[i64]) {
let n: usize = res.len();
res.fill(0);
for i in 0..n {
let ai: i64 = a[i];
let lim: usize = n - i;
for j in 0..lim {
res[i + j] += ai * b[j];
}
for j in lim..n {
res[i + j - n] -= ai * b[j];
}
}
}
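A last sanity check for the wraparound sign in negacyclic_convolution_naive: in Z[X]/(X^n + 1) we have X^n = -1, so with n = 4, X^3 * X^2 = X^5 = -X. A standalone sketch:

fn main() {
    let a = [0i64, 0, 0, 1]; // X^3
    let b = [0i64, 0, 1, 0]; // X^2
    let mut res = [0i64; 4];
    let n = res.len();
    for i in 0..n {
        let lim = n - i;
        for j in 0..lim {
            res[i + j] += a[i] * b[j];
        }
        for j in lim..n {
            res[i + j - n] -= a[i] * b[j]; // X^n wraps to -1
        }
    }
    assert_eq!(res, [0, -1, 0, 0]); // -X
}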


@@ -29,10 +29,7 @@ where
receiver.read_from(&mut reader).expect("read_from failed");
// Ensure serialization round-trip correctness
assert_eq!(&original, &receiver, "Deserialized object does not match the original");
}
#[test]


@@ -90,24 +90,8 @@ where
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
for j in 0..cols {
module_ref.vec_znx_big_normalize(&mut res_ref, base2k, 0, j, &res_big_ref, base2k, j, scratch_ref.borrow());
module_test.vec_znx_big_normalize(&mut res_test, base2k, 0, j, &res_big_test, base2k, j, scratch_test.borrow());
}
assert_eq!(res_ref, res_test);
@@ -212,24 +196,8 @@ where
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
for j in 0..cols {
module_ref.vec_znx_big_normalize(&mut res_ref, base2k, 0, j, &res_big_ref, base2k, j, scratch_ref.borrow());
module_test.vec_znx_big_normalize(&mut res_test, base2k, 0, j, &res_big_test, base2k, j, scratch_test.borrow());
}
assert_eq!(res_ref, res_test);
@@ -339,24 +307,8 @@ where
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
for j in 0..cols {
module_ref.vec_znx_big_normalize(&mut res_ref, base2k, 0, j, &res_big_ref, base2k, j, scratch_ref.borrow());
module_test.vec_znx_big_normalize(&mut res_test, base2k, 0, j, &res_big_test, base2k, j, scratch_test.borrow());
}
assert_eq!(res_ref, res_test);
@@ -447,24 +399,8 @@ pub fn test_svp_apply_dft_to_dft_inplace<BR: Backend, BT: Backend>(
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
for j in 0..cols {
module_ref.vec_znx_big_normalize(&mut res_ref, base2k, 0, j, &res_big_ref, base2k, j, scratch_ref.borrow());
module_test.vec_znx_big_normalize(&mut res_test, base2k, 0, j, &res_big_test, base2k, j, scratch_test.borrow());
}
assert_eq!(res_ref, res_test);


@@ -7,8 +7,9 @@ use crate::{
VecZnxFillNormal, VecZnxFillUniform, VecZnxLsh, VecZnxLshInplace, VecZnxLshTmpBytes, VecZnxMergeRings,
VecZnxMergeRingsTmpBytes, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxMulXpMinusOneInplaceTmpBytes,
VecZnxNegate, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate,
VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxRsh, VecZnxRshInplace, VecZnxRshTmpBytes, VecZnxSplitRing,
VecZnxSplitRingTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSubNegateInplace, VecZnxSubScalar, VecZnxSubScalarInplace,
VecZnxSwitchRing,
},
layouts::{Backend, DigestU64, FillUniform, Module, ScalarZnx, ScratchOwned, VecZnx, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::znx_copy_ref,
@@ -341,10 +342,7 @@ where
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_merge_rings_tmp_bytes());
for a_size in [1, 2, 3, 4] {
let mut a: [VecZnx<Vec<u8>>; 2] = [VecZnx::alloc(n >> 1, cols, a_size), VecZnx::alloc(n >> 1, cols, a_size)];
a.iter_mut().for_each(|ai| {
ai.fill_uniform(base2k, &mut source);
@@ -549,26 +547,20 @@ where
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
for res_offset in -(base2k as i64)..=(base2k as i64) {
// Set d to garbage
res_ref.fill_uniform(base2k, &mut source);
res_test.fill_uniform(base2k, &mut source);
// Reference
for i in 0..cols {
module_ref.vec_znx_normalize(&mut res_ref, base2k, res_offset, i, &a, base2k, i, scratch_ref.borrow());
module_test.vec_znx_normalize(&mut res_test, base2k, res_offset, i, &a, base2k, i, scratch_test.borrow());
}
assert_eq!(a.digest_u64(), a_digest);
assert_eq!(res_ref, res_test);
}
}
}
@@ -718,10 +710,7 @@ where
})
} else {
let std: f64 = a.stats(base2k, col_i).std();
assert!((std - one_12_sqrt).abs() < 0.01, "std={std} ~!= {one_12_sqrt}",);
}
})
});
@@ -783,11 +772,7 @@ where
})
} else {
let std: f64 = a.stats(base2k, col_i).std() * k_f64;
assert!((std - sigma * sqrt2).abs() < 0.1, "std={std} ~!= {}", sigma * sqrt2);
}
})
});
@@ -872,9 +857,9 @@ where
pub fn test_vec_znx_rsh<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxRsh<BR> + VecZnxRshTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
Module<BT>: VecZnxRsh<BT> + VecZnxRshTmpBytes,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
@@ -882,8 +867,8 @@ where
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_rsh_tmp_bytes());
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_rsh_tmp_bytes());
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
@@ -914,9 +899,9 @@ where
pub fn test_vec_znx_rsh_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxRshInplace<BR> + VecZnxRshTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
Module<BT>: VecZnxRshInplace<BT> + VecZnxRshTmpBytes,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
@@ -924,8 +909,8 @@ where
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_rsh_tmp_bytes());
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_rsh_tmp_bytes());
for res_size in [1, 2, 3, 4] {
for k in 0..base2k * res_size {
@@ -966,15 +951,11 @@ where
let a_digest = a.digest_u64();
for res_size in [1, 2, 3, 4] {
let mut res_ref: [VecZnx<Vec<u8>>; 2] =
[VecZnx::alloc(n >> 1, cols, res_size), VecZnx::alloc(n >> 1, cols, res_size)];
let mut res_test: [VecZnx<Vec<u8>>; 2] =
[VecZnx::alloc(n >> 1, cols, res_size), VecZnx::alloc(n >> 1, cols, res_size)];
res_ref.iter_mut().for_each(|ri| {
ri.fill_uniform(base2k, &mut source);


@@ -93,20 +93,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
&mut res_small_ref,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
&mut res_small_test,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -188,20 +190,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
&mut res_small_ref,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
&mut res_small_test,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -279,20 +283,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
&mut res_small_ref,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
&mut res_small_test,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -367,20 +373,22 @@ pub fn test_vec_znx_big_add_small_inplace<BR: Backend, BT: Backend>(
for j in 0..cols {
module_ref.vec_znx_big_normalize(
&mut res_small_ref,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
&mut res_small_test,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -459,20 +467,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
&mut res_small_ref,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
&mut res_small_test,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -546,20 +556,22 @@ pub fn test_vec_znx_big_automorphism_inplace<BR: Backend, BT: Backend>(
for j in 0..cols {
module_ref.vec_znx_big_normalize(
&mut res_small_ref,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
&mut res_small_test,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -631,20 +643,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
&mut res_small_ref,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
&mut res_small_test,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -709,20 +723,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
&mut res_small_ref,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
&mut res_small_test,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -782,36 +798,40 @@ where
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
for res_offset in -(base2k as i64)..=(base2k as i64) {
// Set d to garbage
source.fill_bytes(res_ref.data_mut());
source.fill_bytes(res_test.data_mut());
// Reference
for j in 0..cols {
module_ref.vec_znx_big_normalize(
&mut res_ref,
base2k,
res_offset,
j,
&a_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
&mut res_test,
base2k,
res_offset,
j,
&a_test,
base2k,
j,
scratch_test.borrow(),
);
}
assert_eq!(a_ref.digest_u64(), a_ref_digest);
assert_eq!(a_test.digest_u64(), a_test_digest);
assert_eq!(res_ref, res_test);
}
}
}
@@ -891,20 +911,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
&mut res_small_ref,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
&mut res_small_test,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -986,20 +1008,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
&mut res_small_ref,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
&mut res_small_test,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -1083,20 +1107,22 @@ pub fn test_vec_znx_big_sub_negate_inplace<BR: Backend, BT: Backend>(
for j in 0..cols {
module_ref.vec_znx_big_normalize(
&mut res_small_ref,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
&mut res_small_test,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -1180,20 +1206,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
&mut res_small_ref,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
&mut res_small_test,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -1278,20 +1306,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
&mut res_small_ref,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
&mut res_small_test,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -1366,20 +1396,22 @@ pub fn test_vec_znx_big_sub_small_a_inplace<BR: Backend, BT: Backend>(
for j in 0..cols {
module_ref.vec_znx_big_normalize(
&mut res_small_ref,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
&mut res_small_test,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -1427,55 +1459,59 @@ pub fn test_vec_znx_big_sub_small_b_inplace<BR: Backend, BT: Backend>(
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
for res_offset in -(base2k as i64)..=(base2k as i64) {
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
res.fill_uniform(base2k, &mut source);
let mut res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_big_alloc(cols, res_size);
let mut res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_big_alloc(cols, res_size);
for j in 0..cols {
module_ref.vec_znx_big_from_small(&mut res_big_ref, j, &res, j);
module_test.vec_znx_big_from_small(&mut res_big_test, j, &res, j);
}
for i in 0..cols {
module_ref.vec_znx_big_sub_small_negate_inplace(&mut res_big_ref, i, &a, i);
module_test.vec_znx_big_sub_small_negate_inplace(&mut res_big_test, i, &a, i);
}
assert_eq!(a.digest_u64(), a_digest);
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let res_ref_digest: u64 = res_big_ref.digest_u64();
let res_test_digest: u64 = res_big_test.digest_u64();
for j in 0..cols {
module_ref.vec_znx_big_normalize(
&mut res_small_ref,
base2k,
res_offset,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
&mut res_small_test,
base2k,
res_offset,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
assert_eq!(res_big_test.digest_u64(), res_test_digest);
assert_eq!(res_small_ref, res_small_test);
}
}
}


@@ -102,20 +102,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
&mut res_small_ref,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
&mut res_small_test,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -208,20 +210,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
&mut res_small_ref,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
&mut res_small_test,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -311,20 +315,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
&mut res_small_ref,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
&mut res_small_test,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -392,13 +398,7 @@ where
for j in 0..cols {
module_ref.vec_znx_idft_apply(&mut res_big_ref, j, &res_dft_ref, j, scratch_ref.borrow());
module_test.vec_znx_idft_apply(&mut res_big_test, j, &res_dft_test, j, scratch_test.borrow());
}
assert_eq!(res_dft_ref.digest_u64(), res_dft_ref_digest);
@@ -412,20 +412,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
&mut res_small_ref,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
&mut res_small_test,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -502,20 +504,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
&mut res_small_ref,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
&mut res_small_test,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -589,20 +593,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
&mut res_small_ref,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
&mut res_small_test,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -709,20 +715,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
&mut res_small_ref,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
&mut res_small_test,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -815,20 +823,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
&mut res_small_ref,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
&mut res_small_test,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -923,20 +933,22 @@ pub fn test_vec_znx_dft_sub_negate_inplace<BR: Backend, BT: Backend>(
for j in 0..cols {
module_ref.vec_znx_big_normalize(
&mut res_small_ref,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
&mut res_small_test,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);


@@ -90,20 +90,22 @@ where
for j in 0..cols_out {
module_ref.vec_znx_big_normalize(
&mut res_small_ref,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
&mut res_small_test,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -205,18 +207,8 @@ where
source.fill_bytes(res_dft_ref.data_mut());
source.fill_bytes(res_dft_test.data_mut());
module_ref.vmp_apply_dft_to_dft(&mut res_dft_ref, &a_dft_ref, &pmat_ref, scratch_ref.borrow());
module_test.vmp_apply_dft_to_dft(&mut res_dft_test, &a_dft_test, &pmat_test, scratch_test.borrow());
let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
@@ -229,20 +221,22 @@ where
for j in 0..cols_out {
module_ref.vec_znx_big_normalize(
&mut res_small_ref,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
&mut res_small_test,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -379,20 +373,22 @@ where
for j in 0..cols_out {
module_ref.vec_znx_big_normalize(
&mut res_small_ref,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
&mut res_small_test,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);