Added more serialization tests + generalize methods to any n

This commit is contained in:
Pro7ech
2025-08-13 15:28:52 +02:00
parent 068470783e
commit 940742ce6c
117 changed files with 3658 additions and 2577 deletions

View File

@@ -3,7 +3,7 @@ use std::fmt::Debug;
use sampling::source::Source;
use crate::hal::{
api::{FillUniform, ZnxZero},
api::{FillUniform, Reset},
layouts::{ReaderFrom, WriterTo},
};
@@ -12,7 +12,7 @@ use crate::hal::{
/// - `T` must implement I/O traits, zeroing, cloning, and random filling.
pub fn test_reader_writer_interface<T>(mut original: T)
where
T: WriterTo + ReaderFrom + PartialEq + Eq + Debug + Clone + ZnxZero + FillUniform,
T: WriterTo + ReaderFrom + PartialEq + Eq + Debug + Clone + Reset + FillUniform,
{
// Fill original with uniform random data
let mut source = Source::new([0u8; 32]);
@@ -24,7 +24,7 @@ where
// Prepare receiver: same shape, but zeroed
let mut receiver = original.clone();
receiver.zero();
receiver.reset();
// Deserialize from buffer
let mut reader: &[u8] = &buffer;
@@ -45,7 +45,7 @@ fn scalar_znx_serialize() {
#[test]
fn vec_znx_serialize() {
let original: crate::hal::layouts::VecZnx<Vec<u8>> = crate::hal::layouts::VecZnx::alloc::<i64>(1024, 3, 4);
let original: crate::hal::layouts::VecZnx<Vec<u8>> = crate::hal::layouts::VecZnx::alloc(1024, 3, 4);
test_reader_writer_interface(original);
}

View File

@@ -2,25 +2,23 @@ use itertools::izip;
use sampling::source::Source;
use crate::hal::{
api::{
VecZnxAddNormal, VecZnxAlloc, VecZnxDecodeVeci64, VecZnxEncodeVeci64, VecZnxFillUniform, VecZnxStd, ZnxInfos, ZnxView,
ZnxViewMut,
},
api::{VecZnxAddNormal, VecZnxDecodeVeci64, VecZnxEncodeVeci64, VecZnxFillUniform, VecZnxStd, ZnxInfos, ZnxView, ZnxViewMut},
layouts::{Backend, Module, VecZnx},
};
pub fn test_vec_znx_fill_uniform<B: Backend>(module: &Module<B>)
where
Module<B>: VecZnxFillUniform + VecZnxStd + VecZnxAlloc,
Module<B>: VecZnxFillUniform + VecZnxStd,
{
let n: usize = module.n();
let basek: usize = 17;
let size: usize = 5;
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let zero: Vec<i64> = vec![0; module.n()];
let zero: Vec<i64> = vec![0; n];
let one_12_sqrt: f64 = 0.28867513459481287;
(0..cols).for_each(|col_i| {
let mut a: VecZnx<_> = module.vec_znx_alloc(cols, size);
let mut a: VecZnx<_> = VecZnx::alloc(n, cols, size);
module.vec_znx_fill_uniform(basek, &mut a, col_i, size * basek, &mut source);
(0..cols).for_each(|col_j| {
if col_j != col_i {
@@ -42,8 +40,9 @@ where
pub fn test_vec_znx_add_normal<B: Backend>(module: &Module<B>)
where
Module<B>: VecZnxAddNormal + VecZnxStd + VecZnxAlloc,
Module<B>: VecZnxAddNormal + VecZnxStd,
{
let n: usize = module.n();
let basek: usize = 17;
let k: usize = 2 * 17;
let size: usize = 5;
@@ -51,10 +50,10 @@ where
let bound: f64 = 6.0 * sigma;
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let zero: Vec<i64> = vec![0; module.n()];
let zero: Vec<i64> = vec![0; n];
let k_f64: f64 = (1u64 << k as u64) as f64;
(0..cols).for_each(|col_i| {
let mut a: VecZnx<_> = module.vec_znx_alloc(cols, size);
let mut a: VecZnx<_> = VecZnx::alloc(n, cols, size);
module.vec_znx_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
(0..cols).for_each(|col_j| {
if col_j != col_i {
@@ -71,21 +70,22 @@ where
pub fn test_vec_znx_encode_vec_i64_lo_norm<B: Backend>(module: &Module<B>)
where
Module<B>: VecZnxEncodeVeci64 + VecZnxDecodeVeci64 + VecZnxAlloc,
Module<B>: VecZnxEncodeVeci64 + VecZnxDecodeVeci64,
{
let n: usize = module.n();
let basek: usize = 17;
let size: usize = 5;
let k: usize = size * basek - 5;
let mut a: VecZnx<_> = module.vec_znx_alloc(2, size);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, 2, size);
let mut source: Source = Source::new([0u8; 32]);
let raw: &mut [i64] = a.raw_mut();
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
(0..a.cols()).for_each(|col_i| {
let mut have: Vec<i64> = vec![i64::default(); module.n()];
let mut have: Vec<i64> = vec![i64::default(); n];
have.iter_mut()
.for_each(|x| *x = (source.next_i64() << 56) >> 56);
module.encode_vec_i64(basek, &mut a, col_i, k, &have, 10);
let mut want: Vec<i64> = vec![i64::default(); module.n()];
let mut want: Vec<i64> = vec![i64::default(); n];
module.decode_vec_i64(basek, &a, col_i, k, &mut want);
izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
});
@@ -93,17 +93,18 @@ where
pub fn test_vec_znx_encode_vec_i64_hi_norm<B: Backend>(module: &Module<B>)
where
Module<B>: VecZnxEncodeVeci64 + VecZnxDecodeVeci64 + VecZnxAlloc,
Module<B>: VecZnxEncodeVeci64 + VecZnxDecodeVeci64,
{
let n: usize = module.n();
let basek: usize = 17;
let size: usize = 5;
for k in [1, basek / 2, size * basek - 5] {
let mut a: VecZnx<_> = module.vec_znx_alloc(2, size);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, 2, size);
let mut source = Source::new([0u8; 32]);
let raw: &mut [i64] = a.raw_mut();
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
(0..a.cols()).for_each(|col_i| {
let mut have: Vec<i64> = vec![i64::default(); module.n()];
let mut have: Vec<i64> = vec![i64::default(); n];
have.iter_mut().for_each(|x| {
if k < 64 {
*x = source.next_u64n(1 << k, (1 << k) - 1) as i64;
@@ -112,7 +113,7 @@ where
}
});
module.encode_vec_i64(basek, &mut a, col_i, k, &have, 63);
let mut want: Vec<i64> = vec![i64::default(); module.n()];
let mut want: Vec<i64> = vec![i64::default(); n];
module.decode_vec_i64(basek, &a, col_i, k, &mut want);
izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
})