Add cross-basek normalization (#90)

* added cross_basek_normalization

* updated method signatures to take layouts

* fixed cross-base normalization

fix #91
fix #93
Jean-Philippe Bossuat
2025-09-30 14:40:10 +02:00
committed by GitHub
parent 4da790ea6a
commit 37e13b965c
216 changed files with 12481 additions and 7745 deletions
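The central API change: normalization now takes the base-2^k of the output and of the input as separate arguments, so a vector can be renormalized from one base-2^k layout into another (hence "cross-base"); the `basek` parameter is renamed `base2k` throughout, and `VecZnxSubABInplace`/`VecZnxSubBAInplace` become `VecZnxSubInplace`/`VecZnxSubNegateInplace`. A minimal sketch of the new `vec_znx_normalize` call shape, mirroring the updated test below; the helper name `renormalize_cross_base` and its exact bounds are illustrative assumptions, not crate API:

// Hypothetical helper (sketch): renormalize every column of `a`, decomposed
// in base 2^a_base2k, into `res`, decomposed in base 2^res_base2k.
fn renormalize_cross_base<B: Backend>(
    module: &Module<B>,
    res_base2k: usize,
    res: &mut VecZnx<Vec<u8>>,
    a_base2k: usize,
    a: &VecZnx<Vec<u8>>,
    scratch: &mut ScratchOwned<B>,
) where
    Module<B>: VecZnxNormalize<B> + VecZnxNormalizeTmpBytes,
    ScratchOwned<B>: ScratchOwnedBorrow<B>,
{
    for i in 0..a.cols() {
        // New signature: output base first, then the input base, as exercised in the tests.
        module.vec_znx_normalize(res_base2k, res, i, a_base2k, a, i, scratch.borrow());
    }
}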


@@ -8,38 +8,18 @@ use crate::{
VecZnxMergeRingsTmpBytes, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxMulXpMinusOneInplaceTmpBytes,
VecZnxNegate, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate,
VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxRsh, VecZnxRshInplace, VecZnxSplitRing, VecZnxSplitRingTmpBytes,
-VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace, VecZnxSubScalar, VecZnxSubScalarInplace, VecZnxSwitchRing,
+VecZnxSub, VecZnxSubInplace, VecZnxSubNegateInplace, VecZnxSubScalar, VecZnxSubScalarInplace, VecZnxSwitchRing,
},
layouts::{Backend, DigestU64, FillUniform, Module, ScalarZnx, ScratchOwned, VecZnx, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::znx_copy_ref,
source::Source,
};
-pub fn test_vec_znx_encode_vec_i64_lo_norm() {
+pub fn test_vec_znx_encode_vec_i64() {
let n: usize = 32;
-let basek: usize = 17;
+let base2k: usize = 17;
let size: usize = 5;
-let k: usize = size * basek - 5;
-let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, 2, size);
-let mut source: Source = Source::new([0u8; 32]);
-let raw: &mut [i64] = a.raw_mut();
-raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
-(0..a.cols()).for_each(|col_i| {
-let mut have: Vec<i64> = vec![i64::default(); n];
-have.iter_mut()
-.for_each(|x| *x = (source.next_i64() << 56) >> 56);
-a.encode_vec_i64(basek, col_i, k, &have, 10);
-let mut want: Vec<i64> = vec![i64::default(); n];
-a.decode_vec_i64(basek, col_i, k, &mut want);
-assert_eq!(have, want, "{:?} != {:?}", &have, &want);
-});
-}
-pub fn test_vec_znx_encode_vec_i64_hi_norm() {
-let n: usize = 32;
-let basek: usize = 17;
-let size: usize = 5;
-for k in [1, basek / 2, size * basek - 5] {
+for k in [1, base2k / 2, size * base2k - 5] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, 2, size);
let mut source = Source::new([0u8; 32]);
let raw: &mut [i64] = a.raw_mut();
@@ -53,15 +33,15 @@ pub fn test_vec_znx_encode_vec_i64_hi_norm() {
*x = source.next_i64();
}
});
-a.encode_vec_i64(basek, col_i, k, &have, 63);
+a.encode_vec_i64(base2k, col_i, k, &have);
let mut want: Vec<i64> = vec![i64::default(); n];
-a.decode_vec_i64(basek, col_i, k, &mut want);
+a.decode_vec_i64(base2k, col_i, k, &mut want);
assert_eq!(have, want, "{:?} != {:?}", &have, &want);
})
}
}
-pub fn test_vec_znx_add_scalar<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
+pub fn test_vec_znx_add_scalar<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxAddScalar,
Module<BT>: VecZnxAddScalar,
@@ -74,12 +54,12 @@ where
let cols: usize = 2;
let mut a: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, cols);
-a.fill_uniform(basek, &mut source);
+a.fill_uniform(base2k, &mut source);
let a_digest = a.digest_u64();
for a_size in [1, 2, 3, 4] {
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
-b.fill_uniform(basek, &mut source);
+b.fill_uniform(base2k, &mut source);
let b_digest: u64 = b.digest_u64();
for res_size in [1, 2, 3, 4] {
@@ -87,8 +67,8 @@ where
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
// Set d to garbage
-rest_ref.fill_uniform(basek, &mut source);
-res_test.fill_uniform(basek, &mut source);
+rest_ref.fill_uniform(base2k, &mut source);
+res_test.fill_uniform(base2k, &mut source);
// Reference
for i in 0..cols {
@@ -103,7 +83,7 @@ where
}
}
-pub fn test_vec_znx_add_scalar_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
+pub fn test_vec_znx_add_scalar_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxAddScalarInplace,
Module<BT>: VecZnxAddScalarInplace,
@@ -116,14 +96,14 @@ where
let cols: usize = 2;
let mut b: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, cols);
-b.fill_uniform(basek, &mut source);
+b.fill_uniform(base2k, &mut source);
let b_digest: u64 = b.digest_u64();
for res_size in [1, 2, 3, 4] {
let mut rest_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
-rest_ref.fill_uniform(basek, &mut source);
+rest_ref.fill_uniform(base2k, &mut source);
res_test.raw_mut().copy_from_slice(rest_ref.raw());
for i in 0..cols {
@@ -135,7 +115,7 @@ where
assert_eq!(rest_ref, res_test);
}
}
-pub fn test_vec_znx_add<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
+pub fn test_vec_znx_add<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxAdd,
Module<BT>: VecZnxAdd,
@@ -148,13 +128,13 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
-a.fill_uniform(basek, &mut source);
+a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for b_size in [1, 2, 3, 4] {
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, b_size);
-b.fill_uniform(basek, &mut source);
+b.fill_uniform(base2k, &mut source);
let b_digest: u64 = b.digest_u64();
@@ -163,8 +143,8 @@ where
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
// Set d to garbage
-res_ref.fill_uniform(basek, &mut source);
-res_test.fill_uniform(basek, &mut source);
+res_ref.fill_uniform(base2k, &mut source);
+res_test.fill_uniform(base2k, &mut source);
// Reference
for i in 0..cols {
@@ -181,7 +161,7 @@ where
}
}
-pub fn test_vec_znx_add_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
+pub fn test_vec_znx_add_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxAddInplace,
Module<BT>: VecZnxAddInplace,
@@ -194,14 +174,14 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
-a.fill_uniform(basek, &mut source);
+a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
-res_ref.fill_uniform(basek, &mut source);
+res_ref.fill_uniform(base2k, &mut source);
res_test.raw_mut().copy_from_slice(res_ref.raw());
for i in 0..cols {
@@ -215,7 +195,7 @@ where
}
}
-pub fn test_vec_znx_automorphism<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
+pub fn test_vec_znx_automorphism<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxAutomorphism,
Module<BT>: VecZnxAutomorphism,
@@ -228,7 +208,7 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
-a.fill_uniform(basek, &mut source);
+a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
@@ -261,7 +241,7 @@ where
}
pub fn test_vec_znx_automorphism_inplace<BR: Backend, BT: Backend>(
-basek: usize,
+base2k: usize,
module_ref: &Module<BR>,
module_test: &Module<BT>,
) where
@@ -284,7 +264,7 @@ pub fn test_vec_znx_automorphism_inplace<BR: Backend, BT: Backend>(
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
// Fill a with random i64
-res_ref.fill_uniform(basek, &mut source);
+res_ref.fill_uniform(base2k, &mut source);
znx_copy_ref(res_test.raw_mut(), res_ref.raw());
let p: i64 = -7;
@@ -309,7 +289,7 @@ pub fn test_vec_znx_automorphism_inplace<BR: Backend, BT: Backend>(
}
}
-pub fn test_vec_znx_copy<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
+pub fn test_vec_znx_copy<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxCopy,
Module<BT>: VecZnxCopy,
@@ -322,7 +302,7 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
-a.fill_uniform(basek, &mut source);
+a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
@@ -330,8 +310,8 @@ where
let mut res_1: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
// Set d to garbage
-res_0.fill_uniform(basek, &mut source);
-res_1.fill_uniform(basek, &mut source);
+res_0.fill_uniform(base2k, &mut source);
+res_1.fill_uniform(base2k, &mut source);
// Reference
for i in 0..cols {
@@ -345,7 +325,7 @@ where
}
}
-pub fn test_vec_znx_merge_rings<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
+pub fn test_vec_znx_merge_rings<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxMergeRings<BR> + ModuleNew<BR> + VecZnxMergeRingsTmpBytes,
Module<BT>: VecZnxMergeRings<BT> + ModuleNew<BT> + VecZnxMergeRingsTmpBytes,
@@ -367,7 +347,7 @@ where
];
a.iter_mut().for_each(|ai| {
-ai.fill_uniform(basek, &mut source);
+ai.fill_uniform(base2k, &mut source);
});
let a_digests: [u64; 2] = [a[0].digest_u64(), a[1].digest_u64()];
@@ -376,8 +356,8 @@ where
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
-res_ref.fill_uniform(basek, &mut source);
-res_test.fill_uniform(basek, &mut source);
+res_ref.fill_uniform(base2k, &mut source);
+res_test.fill_uniform(base2k, &mut source);
for i in 0..cols {
module_ref.vec_znx_merge_rings(&mut res_test, i, &a, i, scratch_ref.borrow());
@@ -390,7 +370,7 @@ where
}
}
-pub fn test_vec_znx_mul_xp_minus_one<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
+pub fn test_vec_znx_mul_xp_minus_one<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxMulXpMinusOne,
Module<BT>: VecZnxMulXpMinusOne,
@@ -403,7 +383,7 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
-a.fill_uniform(basek, &mut source);
+a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
@@ -437,7 +417,7 @@ where
}
pub fn test_vec_znx_mul_xp_minus_one_inplace<BR: Backend, BT: Backend>(
-basek: usize,
+base2k: usize,
module_ref: &Module<BR>,
module_test: &Module<BT>,
) where
@@ -460,7 +440,7 @@ pub fn test_vec_znx_mul_xp_minus_one_inplace<BR: Backend, BT: Backend>(
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
// Fill a with random i64
-res_ref.fill_uniform(basek, &mut source);
+res_ref.fill_uniform(base2k, &mut source);
znx_copy_ref(res_test.raw_mut(), res_ref.raw());
let p: i64 = -7;
@@ -483,7 +463,7 @@ pub fn test_vec_znx_mul_xp_minus_one_inplace<BR: Backend, BT: Backend>(
}
}
-pub fn test_vec_znx_negate<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
+pub fn test_vec_znx_negate<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxNegate,
Module<BT>: VecZnxNegate,
@@ -496,14 +476,14 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
-a.fill_uniform(basek, &mut source);
+a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
-res_ref.fill_uniform(basek, &mut source);
+res_ref.fill_uniform(base2k, &mut source);
res_test.raw_mut().copy_from_slice(res_ref.raw());
for i in 0..cols {
@@ -517,7 +497,7 @@ where
}
}
-pub fn test_vec_znx_negate_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
+pub fn test_vec_znx_negate_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxNegateInplace,
Module<BT>: VecZnxNegateInplace,
@@ -532,7 +512,7 @@ where
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
-res_ref.fill_uniform(basek, &mut source);
+res_ref.fill_uniform(base2k, &mut source);
res_test.raw_mut().copy_from_slice(res_ref.raw());
for i in 0..cols {
@@ -544,7 +524,7 @@ where
}
}
-pub fn test_vec_znx_normalize<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
+pub fn test_vec_znx_normalize<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxNormalize<BR> + VecZnxNormalizeTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
@@ -562,7 +542,7 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
-a.fill_uniform(basek, &mut source);
+a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
@@ -570,13 +550,21 @@ where
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
// Set d to garbage
-res_ref.fill_uniform(basek, &mut source);
-res_test.fill_uniform(basek, &mut source);
+res_ref.fill_uniform(base2k, &mut source);
+res_test.fill_uniform(base2k, &mut source);
// Reference
for i in 0..cols {
-module_ref.vec_znx_normalize(basek, &mut res_ref, i, &a, i, scratch_ref.borrow());
-module_test.vec_znx_normalize(basek, &mut res_test, i, &a, i, scratch_test.borrow());
+module_ref.vec_znx_normalize(base2k, &mut res_ref, i, base2k, &a, i, scratch_ref.borrow());
+module_test.vec_znx_normalize(
+base2k,
+&mut res_test,
+i,
+base2k,
+&a,
+i,
+scratch_test.borrow(),
+);
}
assert_eq!(a.digest_u64(), a_digest);
@@ -585,7 +573,7 @@ where
}
}
-pub fn test_vec_znx_normalize_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
+pub fn test_vec_znx_normalize_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxNormalizeInplace<BR> + VecZnxNormalizeTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
@@ -605,20 +593,20 @@ where
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
-res_ref.fill_uniform(basek, &mut source);
+res_ref.fill_uniform(base2k, &mut source);
res_test.raw_mut().copy_from_slice(res_ref.raw());
// Reference
for i in 0..cols {
-module_ref.vec_znx_normalize_inplace(basek, &mut res_ref, i, scratch_ref.borrow());
-module_test.vec_znx_normalize_inplace(basek, &mut res_test, i, scratch_test.borrow());
+module_ref.vec_znx_normalize_inplace(base2k, &mut res_ref, i, scratch_ref.borrow());
+module_test.vec_znx_normalize_inplace(base2k, &mut res_test, i, scratch_test.borrow());
}
assert_eq!(res_ref, res_test);
}
}
-pub fn test_vec_znx_rotate<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
+pub fn test_vec_znx_rotate<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxRotate,
Module<BT>: VecZnxRotate,
@@ -631,7 +619,7 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
-a.fill_uniform(basek, &mut source);
+a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
@@ -663,7 +651,7 @@ where
}
}
-pub fn test_vec_znx_rotate_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
+pub fn test_vec_znx_rotate_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxRotateInplace<BR> + VecZnxRotateInplaceTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
@@ -684,7 +672,7 @@ where
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
// Fill a with random i64
-res_ref.fill_uniform(basek, &mut source);
+res_ref.fill_uniform(base2k, &mut source);
znx_copy_ref(res_test.raw_mut(), res_ref.raw());
let p: i64 = -5;
@@ -714,7 +702,7 @@ where
Module<B>: VecZnxFillUniform,
{
let n: usize = module.n();
-let basek: usize = 17;
+let base2k: usize = 17;
let size: usize = 5;
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
@@ -722,19 +710,17 @@ where
let one_12_sqrt: f64 = 0.28867513459481287;
(0..cols).for_each(|col_i| {
let mut a: VecZnx<_> = VecZnx::alloc(n, cols, size);
-module.vec_znx_fill_uniform(basek, &mut a, col_i, &mut source);
+module.vec_znx_fill_uniform(base2k, &mut a, col_i, &mut source);
(0..cols).for_each(|col_j| {
if col_j != col_i {
(0..size).for_each(|limb_i| {
assert_eq!(a.at(col_j, limb_i), zero);
})
} else {
-let std: f64 = a.std(basek, col_i);
+let std: f64 = a.std(base2k, col_i);
assert!(
(std - one_12_sqrt).abs() < 0.01,
-"std={} ~!= {}",
-std,
-one_12_sqrt
+"std={std} ~!= {one_12_sqrt}",
);
}
})
@@ -746,7 +732,7 @@ where
Module<B>: VecZnxFillNormal,
{
let n: usize = module.n();
-let basek: usize = 17;
+let base2k: usize = 17;
let k: usize = 2 * 17;
let size: usize = 5;
let sigma: f64 = 3.2;
@@ -757,15 +743,15 @@ where
let k_f64: f64 = (1u64 << k as u64) as f64;
(0..cols).for_each(|col_i| {
let mut a: VecZnx<_> = VecZnx::alloc(n, cols, size);
-module.vec_znx_fill_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
+module.vec_znx_fill_normal(base2k, &mut a, col_i, k, &mut source, sigma, bound);
(0..cols).for_each(|col_j| {
if col_j != col_i {
(0..size).for_each(|limb_i| {
assert_eq!(a.at(col_j, limb_i), zero);
})
} else {
-let std: f64 = a.std(basek, col_i) * k_f64;
-assert!((std - sigma).abs() < 0.1, "std={} ~!= {}", std, sigma);
+let std: f64 = a.std(base2k, col_i) * k_f64;
+assert!((std - sigma).abs() < 0.1, "std={std} ~!= {sigma}");
}
})
});
@@ -776,7 +762,7 @@ where
Module<B>: VecZnxFillNormal + VecZnxAddNormal,
{
let n: usize = module.n();
-let basek: usize = 17;
+let base2k: usize = 17;
let k: usize = 2 * 17;
let size: usize = 5;
let sigma: f64 = 3.2;
@@ -788,19 +774,18 @@ where
let sqrt2: f64 = SQRT_2;
(0..cols).for_each(|col_i| {
let mut a: VecZnx<_> = VecZnx::alloc(n, cols, size);
-module.vec_znx_fill_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
-module.vec_znx_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
+module.vec_znx_fill_normal(base2k, &mut a, col_i, k, &mut source, sigma, bound);
+module.vec_znx_add_normal(base2k, &mut a, col_i, k, &mut source, sigma, bound);
(0..cols).for_each(|col_j| {
if col_j != col_i {
(0..size).for_each(|limb_i| {
assert_eq!(a.at(col_j, limb_i), zero);
})
} else {
-let std: f64 = a.std(basek, col_i) * k_f64;
+let std: f64 = a.std(base2k, col_i) * k_f64;
assert!(
(std - sigma * sqrt2).abs() < 0.1,
-"std={} ~!= {}",
-std,
+"std={std} ~!= {}",
sigma * sqrt2
);
}
@@ -808,7 +793,7 @@ where
});
}
-pub fn test_vec_znx_lsh<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
+pub fn test_vec_znx_lsh<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxLsh<BR> + VecZnxLshTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
@@ -826,22 +811,22 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
-a.fill_uniform(basek, &mut source);
+a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
-for k in 0..res_size * basek {
+for k in 0..res_size * base2k {
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
// Set d to garbage
-res_ref.fill_uniform(basek, &mut source);
-res_test.fill_uniform(basek, &mut source);
+res_ref.fill_uniform(base2k, &mut source);
+res_test.fill_uniform(base2k, &mut source);
// Reference
for i in 0..cols {
-module_ref.vec_znx_lsh(basek, k, &mut res_ref, i, &a, i, scratch_ref.borrow());
-module_test.vec_znx_lsh(basek, k, &mut res_test, i, &a, i, scratch_test.borrow());
+module_ref.vec_znx_lsh(base2k, k, &mut res_ref, i, &a, i, scratch_ref.borrow());
+module_test.vec_znx_lsh(base2k, k, &mut res_test, i, &a, i, scratch_test.borrow());
}
assert_eq!(a.digest_u64(), a_digest);
@@ -851,7 +836,7 @@ where
}
}
-pub fn test_vec_znx_lsh_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
+pub fn test_vec_znx_lsh_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxLshInplace<BR> + VecZnxLshTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
@@ -868,16 +853,16 @@ where
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_lsh_tmp_bytes());
for res_size in [1, 2, 3, 4] {
-for k in 0..basek * res_size {
+for k in 0..base2k * res_size {
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
-res_ref.fill_uniform(basek, &mut source);
+res_ref.fill_uniform(base2k, &mut source);
res_test.raw_mut().copy_from_slice(res_ref.raw());
for i in 0..cols {
-module_ref.vec_znx_lsh_inplace(basek, k, &mut res_ref, i, scratch_ref.borrow());
-module_test.vec_znx_lsh_inplace(basek, k, &mut res_test, i, scratch_test.borrow());
+module_ref.vec_znx_lsh_inplace(base2k, k, &mut res_ref, i, scratch_ref.borrow());
+module_test.vec_znx_lsh_inplace(base2k, k, &mut res_test, i, scratch_test.borrow());
}
assert_eq!(res_ref, res_test);
@@ -885,7 +870,7 @@ where
}
}
-pub fn test_vec_znx_rsh<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
+pub fn test_vec_znx_rsh<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxRsh<BR> + VecZnxLshTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
@@ -902,22 +887,22 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
-a.fill_uniform(basek, &mut source);
+a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
-for k in 0..res_size * basek {
+for k in 0..res_size * base2k {
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
// Set d to garbage
-res_ref.fill_uniform(basek, &mut source);
-res_test.fill_uniform(basek, &mut source);
+res_ref.fill_uniform(base2k, &mut source);
+res_test.fill_uniform(base2k, &mut source);
// Reference
for i in 0..cols {
-module_ref.vec_znx_rsh(basek, k, &mut res_ref, i, &a, i, scratch_ref.borrow());
-module_test.vec_znx_rsh(basek, k, &mut res_test, i, &a, i, scratch_test.borrow());
+module_ref.vec_znx_rsh(base2k, k, &mut res_ref, i, &a, i, scratch_ref.borrow());
+module_test.vec_znx_rsh(base2k, k, &mut res_test, i, &a, i, scratch_test.borrow());
}
assert_eq!(a.digest_u64(), a_digest);
@@ -927,7 +912,7 @@ where
}
}
-pub fn test_vec_znx_rsh_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
+pub fn test_vec_znx_rsh_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxRshInplace<BR> + VecZnxLshTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
@@ -943,16 +928,16 @@ where
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_lsh_tmp_bytes());
for res_size in [1, 2, 3, 4] {
-for k in 0..basek * res_size {
+for k in 0..base2k * res_size {
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
-res_ref.fill_uniform(basek, &mut source);
+res_ref.fill_uniform(base2k, &mut source);
res_test.raw_mut().copy_from_slice(res_ref.raw());
for i in 0..cols {
-module_ref.vec_znx_rsh_inplace(basek, k, &mut res_ref, i, scratch_ref.borrow());
-module_test.vec_znx_rsh_inplace(basek, k, &mut res_test, i, scratch_test.borrow());
+module_ref.vec_znx_rsh_inplace(base2k, k, &mut res_ref, i, scratch_ref.borrow());
+module_test.vec_znx_rsh_inplace(base2k, k, &mut res_test, i, scratch_test.borrow());
}
assert_eq!(res_ref, res_test);
@@ -960,7 +945,7 @@ where
}
}
-pub fn test_vec_znx_split_ring<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
+pub fn test_vec_znx_split_ring<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxSplitRing<BR> + ModuleNew<BR> + VecZnxSplitRingTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
@@ -977,7 +962,7 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
-a.fill_uniform(basek, &mut source);
+a.fill_uniform(base2k, &mut source);
let a_digest = a.digest_u64();
for res_size in [1, 2, 3, 4] {
@@ -992,11 +977,11 @@ where
];
res_ref.iter_mut().for_each(|ri| {
-ri.fill_uniform(basek, &mut source);
+ri.fill_uniform(base2k, &mut source);
});
res_test.iter_mut().for_each(|ri| {
-ri.fill_uniform(basek, &mut source);
+ri.fill_uniform(base2k, &mut source);
});
for i in 0..cols {
@@ -1013,7 +998,7 @@ where
}
}
-pub fn test_vec_znx_sub_scalar<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
+pub fn test_vec_znx_sub_scalar<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxSubScalar,
Module<BT>: VecZnxSubScalar,
@@ -1025,12 +1010,12 @@ where
let cols: usize = 2;
let mut a: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, cols);
-a.fill_uniform(basek, &mut source);
+a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for b_size in [1, 2, 3, 4] {
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, b_size);
-b.fill_uniform(basek, &mut source);
+b.fill_uniform(base2k, &mut source);
let b_digest: u64 = b.digest_u64();
for res_size in [1, 2, 3, 4] {
@@ -1038,8 +1023,8 @@ where
let mut res_1: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
// Set d to garbage
-res_0.fill_uniform(basek, &mut source);
-res_1.fill_uniform(basek, &mut source);
+res_0.fill_uniform(base2k, &mut source);
+res_1.fill_uniform(base2k, &mut source);
// Reference
for i in 0..cols {
@@ -1054,7 +1039,7 @@ where
}
}
-pub fn test_vec_znx_sub_scalar_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
+pub fn test_vec_znx_sub_scalar_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxSubScalarInplace,
Module<BT>: VecZnxSubScalarInplace,
@@ -1066,14 +1051,14 @@ where
let cols: usize = 2;
let mut a: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, cols);
-a.fill_uniform(basek, &mut source);
+a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
let mut res_0: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_1: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
-res_0.fill_uniform(basek, &mut source);
+res_0.fill_uniform(base2k, &mut source);
res_1.raw_mut().copy_from_slice(res_0.raw());
for i in 0..cols {
@@ -1086,7 +1071,7 @@ where
}
}
-pub fn test_vec_znx_sub<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
+pub fn test_vec_znx_sub<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxSub,
Module<BT>: VecZnxSub,
@@ -1099,12 +1084,12 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
-a.fill_uniform(basek, &mut source);
+a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for b_size in [1, 2, 3, 4] {
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, b_size);
-b.fill_uniform(basek, &mut source);
+b.fill_uniform(base2k, &mut source);
let b_digest: u64 = b.digest_u64();
for res_size in [1, 2, 3, 4] {
@@ -1112,8 +1097,8 @@ where
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
// Set d to garbage
-res_ref.fill_uniform(basek, &mut source);
-res_test.fill_uniform(basek, &mut source);
+res_ref.fill_uniform(base2k, &mut source);
+res_test.fill_uniform(base2k, &mut source);
// Reference
for i in 0..cols {
@@ -1130,10 +1115,10 @@ where
}
}
-pub fn test_vec_znx_sub_ab_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
+pub fn test_vec_znx_sub_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
-Module<BR>: VecZnxSubABInplace,
-Module<BT>: VecZnxSubABInplace,
+Module<BR>: VecZnxSubInplace,
+Module<BT>: VecZnxSubInplace,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
@@ -1143,19 +1128,19 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
-a.fill_uniform(basek, &mut source);
+a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
-res_ref.fill_uniform(basek, &mut source);
+res_ref.fill_uniform(base2k, &mut source);
res_test.raw_mut().copy_from_slice(res_ref.raw());
for i in 0..cols {
-module_test.vec_znx_sub_ab_inplace(&mut res_ref, i, &a, i);
-module_ref.vec_znx_sub_ab_inplace(&mut res_test, i, &a, i);
+module_test.vec_znx_sub_inplace(&mut res_ref, i, &a, i);
+module_ref.vec_znx_sub_inplace(&mut res_test, i, &a, i);
}
assert_eq!(a.digest_u64(), a_digest);
@@ -1164,10 +1149,10 @@ where
}
}
-pub fn test_vec_znx_sub_ba_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
+pub fn test_vec_znx_sub_negate_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
-Module<BR>: VecZnxSubBAInplace,
-Module<BT>: VecZnxSubBAInplace,
+Module<BR>: VecZnxSubNegateInplace,
+Module<BT>: VecZnxSubNegateInplace,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
@@ -1177,19 +1162,19 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
-a.fill_uniform(basek, &mut source);
+a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
-res_ref.fill_uniform(basek, &mut source);
+res_ref.fill_uniform(base2k, &mut source);
res_test.raw_mut().copy_from_slice(res_ref.raw());
for i in 0..cols {
-module_test.vec_znx_sub_ba_inplace(&mut res_ref, i, &a, i);
-module_ref.vec_znx_sub_ba_inplace(&mut res_test, i, &a, i);
+module_test.vec_znx_sub_negate_inplace(&mut res_ref, i, &a, i);
+module_ref.vec_znx_sub_negate_inplace(&mut res_test, i, &a, i);
}
assert_eq!(a.digest_u64(), a_digest);
@@ -1198,7 +1183,7 @@ where
}
}
-pub fn test_vec_znx_switch_ring<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
+pub fn test_vec_znx_switch_ring<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxSwitchRing,
Module<BT>: VecZnxSwitchRing,
@@ -1213,7 +1198,7 @@ where
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
// Fill a with random i64
-a.fill_uniform(basek, &mut source);
+a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
@@ -1221,8 +1206,8 @@ where
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n << 1, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n << 1, cols, res_size);
-res_ref.fill_uniform(basek, &mut source);
-res_test.fill_uniform(basek, &mut source);
+res_ref.fill_uniform(base2k, &mut source);
+res_test.fill_uniform(base2k, &mut source);
// Normalize on c
for i in 0..cols {
@@ -1238,8 +1223,8 @@ where
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n >> 1, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n >> 1, cols, res_size);
-res_ref.fill_uniform(basek, &mut source);
-res_test.fill_uniform(basek, &mut source);
+res_ref.fill_uniform(base2k, &mut source);
+res_test.fill_uniform(base2k, &mut source);
// Normalize on c
for i in 0..cols {