Ref. + AVX code & generic tests + benches (#85)

Jean-Philippe Bossuat
2025-09-15 16:16:11 +02:00
committed by GitHub
parent 99b9e3e10e
commit 56dbd29c59
286 changed files with 27797 additions and 7270 deletions

View File

@@ -0,0 +1,68 @@
pub mod serialization;
pub mod svp;
pub mod vec_znx;
pub mod vec_znx_big;
pub mod vec_znx_dft;
pub mod vmp;
#[macro_export]
macro_rules! backend_test_suite {
(
mod $modname:ident,
backend = $backend:ty,
size = $size:expr,
tests = {
$( $(#[$attr:meta])* $test_name:ident => $impl:path ),+ $(,)?
}
) => {
mod $modname {
use poulpy_hal::{api::ModuleNew, layouts::Module};
use once_cell::sync::Lazy;
static MODULE: Lazy<Module<$backend>> =
Lazy::new(|| Module::<$backend>::new($size));
$(
$(#[$attr])*
#[test]
fn $test_name() {
($impl)(&*MODULE);
}
)+
}
};
}
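// Illustrative invocation only; the backend type and test paths below are hypothetical
// placeholders, not items defined in this crate. Each `$impl` must be a function taking
// `&Module<Backend>`:
//
// backend_test_suite! {
//     mod my_backend_suite,
//     backend = MyBackend,
//     size = 1 << 10,
//     tests = {
//         test_foo => crate::tests::test_foo,
//         #[ignore] test_bar => crate::tests::test_bar,
//     }
// }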
#[macro_export]
macro_rules! cross_backend_test_suite {
(
mod $modname:ident,
backend_ref = $backend_ref:ty,
backend_test = $backend_test:ty,
size = $size:expr,
basek = $basek:expr,
tests = {
$( $(#[$attr:meta])* $test_name:ident => $impl:path ),+ $(,)?
}
) => {
mod $modname {
use poulpy_hal::{api::ModuleNew, layouts::Module};
use once_cell::sync::Lazy;
static MODULE_REF: Lazy<Module<$backend_ref>> =
Lazy::new(|| Module::<$backend_ref>::new($size));
static MODULE_TEST: Lazy<Module<$backend_test>> =
Lazy::new(|| Module::<$backend_test>::new($size));
$(
$(#[$attr])*
#[test]
fn $test_name() {
($impl)($basek, &*MODULE_REF, &*MODULE_TEST);
}
)+
}
};
}
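// Illustrative invocation only; the backend types and the test path below are hypothetical
// placeholders. Each `$impl` must be a function taking
// `(usize, &Module<BackendRef>, &Module<BackendTest>)`:
//
// cross_backend_test_suite! {
//     mod ref_vs_test_suite,
//     backend_ref = ReferenceBackend,
//     backend_test = CandidateBackend,
//     size = 1 << 10,
//     basek = 12,
//     tests = {
//         svp_apply_dft => crate::tests::svp::test_svp_apply_dft,
//     }
// }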

View File

@@ -0,0 +1,54 @@
use std::fmt::Debug;
use crate::{
layouts::{FillUniform, ReaderFrom, Reset, WriterTo},
source::Source,
};
/// Generic round-trip test for serialization and deserialization.
///
/// `T` must implement the I/O traits (`WriterTo`, `ReaderFrom`) as well as zeroing
/// (`Reset`), cloning, and uniform random filling (`FillUniform`).
pub fn test_reader_writer_interface<T>(mut original: T)
where
T: WriterTo + ReaderFrom + PartialEq + Eq + Debug + Clone + Reset + FillUniform,
{
// Fill original with uniform random data
let mut source = Source::new([0u8; 32]);
original.fill_uniform(50, &mut source);
// Serialize into a buffer
let mut buffer = Vec::new();
original.write_to(&mut buffer).expect("write_to failed");
// Prepare receiver: same shape, but zeroed
let mut receiver = original.clone();
receiver.reset();
// Deserialize from buffer
let mut reader: &[u8] = &buffer;
receiver.read_from(&mut reader).expect("read_from failed");
// Ensure serialization round-trip correctness
assert_eq!(
&original, &receiver,
"Deserialized object does not match the original"
);
}
#[test]
fn scalar_znx_serialize() {
let original: crate::layouts::ScalarZnx<Vec<u8>> = crate::layouts::ScalarZnx::alloc(1024, 3);
test_reader_writer_interface(original);
}
#[test]
fn vec_znx_serialize() {
let original: crate::layouts::VecZnx<Vec<u8>> = crate::layouts::VecZnx::alloc(1024, 3, 4);
test_reader_writer_interface(original);
}
#[test]
fn mat_znx_serialize() {
let original: crate::layouts::MatZnx<Vec<u8>> = crate::layouts::MatZnx::alloc(1024, 3, 2, 2, 4);
test_reader_writer_interface(original);
}

View File

@@ -0,0 +1,470 @@
use rand::RngCore;
use crate::{
api::{
ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDft, SvpApplyDftToDft, SvpApplyDftToDftAdd, SvpApplyDftToDftInplace,
SvpPPolAlloc, SvpPrepare, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAlloc, VecZnxDftApply,
VecZnxIdftApplyConsume,
},
layouts::{Backend, DataViewMut, DigestU64, FillUniform, Module, ScalarZnx, ScratchOwned, SvpPPol, VecZnx, VecZnxDft},
source::Source,
};
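/// Cross-backend consistency test for `svp_apply_dft`: the same random scalar is prepared
/// on the reference backend and the backend under test, applied column-wise to a random
/// `VecZnx`, brought back from the DFT domain and normalized; the test asserts that both
/// backends agree and that no input was modified.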
pub fn test_svp_apply_dft<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: SvpPrepare<BR>
+ SvpApplyDft<BR>
+ SvpPPolAlloc<BR>
+ VecZnxDftAlloc<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxIdftApplyConsume<BR>
+ VecZnxBigNormalizeTmpBytes,
Module<BT>: SvpPrepare<BT>
+ SvpApplyDft<BT>
+ SvpPPolAlloc<BT>
+ VecZnxDftAlloc<BT>
+ VecZnxBigNormalize<BT>
+ VecZnxIdftApplyConsume<BT>
+ VecZnxBigNormalizeTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
let cols: usize = 2;
let mut source: Source = Source::new([0u8; 32]);
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
let mut scalar: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, cols);
scalar.fill_uniform(basek, &mut source);
let scalar_digest: u64 = scalar.digest_u64();
let mut svp_ref: SvpPPol<Vec<u8>, BR> = module_ref.svp_ppol_alloc(cols);
let mut svp_test: SvpPPol<Vec<u8>, BT> = module_test.svp_ppol_alloc(cols);
for j in 0..cols {
module_ref.svp_prepare(&mut svp_ref, j, &scalar, j);
module_test.svp_prepare(&mut svp_test, j, &scalar, j);
}
assert_eq!(scalar.digest_u64(), scalar_digest);
let svp_ref_digest: u64 = svp_ref.digest_u64();
let svp_test_digest: u64 = svp_test.digest_u64();
for a_size in [1, 2, 3, 4] {
// Create a random input VecZnx
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
// Allocate output VecZnxDft on the reference backend and the backend under test
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, res_size);
// Fill output with garbage
source.fill_bytes(res_dft_ref.data_mut());
source.fill_bytes(res_dft_test.data_mut());
for j in 0..cols {
module_ref.svp_apply_dft(&mut res_dft_ref, j, &svp_ref, j, &a, j);
module_test.svp_apply_dft(&mut res_dft_test, j, &svp_test, j, &a, j);
}
// Assert no change to inputs
assert_eq!(svp_ref.digest_u64(), svp_ref_digest);
assert_eq!(svp_test.digest_u64(), svp_test_digest);
assert_eq!(a.digest_u64(), a_digest);
let res_big_ref: crate::layouts::VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
let res_big_test: crate::layouts::VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
&mut res_ref,
j,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
&mut res_test,
j,
&res_big_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_ref, res_test);
}
}
}
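/// Cross-backend consistency test for `svp_apply_dft_to_dft`: like `test_svp_apply_dft`,
/// but the input is first moved into the DFT domain with `vec_znx_dft_apply` before the
/// prepared scalar is applied.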
pub fn test_svp_apply_dft_to_dft<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: SvpPrepare<BR>
+ SvpApplyDftToDft<BR>
+ SvpPPolAlloc<BR>
+ VecZnxDftAlloc<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxDftApply<BR>
+ VecZnxIdftApplyConsume<BR>
+ VecZnxBigNormalizeTmpBytes,
Module<BT>: SvpPrepare<BT>
+ SvpApplyDftToDft<BT>
+ SvpPPolAlloc<BT>
+ VecZnxDftAlloc<BT>
+ VecZnxBigNormalize<BT>
+ VecZnxDftApply<BT>
+ VecZnxIdftApplyConsume<BT>
+ VecZnxBigNormalizeTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
let cols: usize = 2;
let mut source: Source = Source::new([0u8; 32]);
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
let mut scalar: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, cols);
scalar.fill_uniform(basek, &mut source);
let scalar_digest: u64 = scalar.digest_u64();
let mut svp_ref: SvpPPol<Vec<u8>, BR> = module_ref.svp_ppol_alloc(cols);
let mut svp_test: SvpPPol<Vec<u8>, BT> = module_test.svp_ppol_alloc(cols);
for j in 0..cols {
module_ref.svp_prepare(&mut svp_ref, j, &scalar, j);
module_test.svp_prepare(&mut svp_test, j, &scalar, j);
}
assert_eq!(scalar.digest_u64(), scalar_digest);
let svp_ref_digest: u64 = svp_ref.digest_u64();
let svp_test_digest: u64 = svp_test.digest_u64();
for a_size in [1, 2, 3, 4] {
// Create a random input VecZnx
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
let a_digest: u64 = a.digest_u64();
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
let mut a_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j);
module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j);
}
assert_eq!(a.digest_u64(), a_digest);
let a_dft_ref_digest: u64 = a_dft_ref.digest_u64();
let a_dft_test_digest: u64 = a_dft_test.digest_u64();
for res_size in [1, 2, 3, 4] {
// Allocate output VecZnxDft on the reference backend and the backend under test
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, res_size);
// Fill output with garbage
source.fill_bytes(res_dft_ref.data_mut());
source.fill_bytes(res_dft_test.data_mut());
for j in 0..cols {
module_ref.svp_apply_dft_to_dft(&mut res_dft_ref, j, &svp_ref, j, &a_dft_ref, j);
module_test.svp_apply_dft_to_dft(&mut res_dft_test, j, &svp_test, j, &a_dft_test, j);
}
// Assert no change to inputs
assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest);
assert_eq!(a_dft_test.digest_u64(), a_dft_test_digest);
assert_eq!(svp_ref.digest_u64(), svp_ref_digest);
assert_eq!(svp_test.digest_u64(), svp_test_digest);
assert_eq!(a.digest_u64(), a_digest);
let res_big_ref: crate::layouts::VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
let res_big_test: crate::layouts::VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
println!("res_big_ref: {}", res_big_ref);
println!("res_big_test: {}", res_big_test);
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
&mut res_ref,
j,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
&mut res_test,
j,
&res_big_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_ref, res_test);
}
}
}
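/// Cross-backend consistency test for `svp_apply_dft_to_dft_add`: the product of the
/// prepared scalar and the DFT of `a` is accumulated onto an existing DFT output, and the
/// normalized results of both backends are compared.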
pub fn test_svp_apply_dft_to_dft_add<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: SvpPrepare<BR>
+ SvpApplyDftToDftAdd<BR>
+ SvpPPolAlloc<BR>
+ VecZnxDftAlloc<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxDftApply<BR>
+ VecZnxIdftApplyConsume<BR>
+ VecZnxBigNormalizeTmpBytes,
Module<BT>: SvpPrepare<BT>
+ SvpApplyDftToDftAdd<BT>
+ SvpPPolAlloc<BT>
+ VecZnxDftAlloc<BT>
+ VecZnxBigNormalize<BT>
+ VecZnxDftApply<BT>
+ VecZnxIdftApplyConsume<BT>
+ VecZnxBigNormalizeTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
let cols: usize = 2;
let mut source: Source = Source::new([0u8; 32]);
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
let mut scalar: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, cols);
scalar.fill_uniform(basek, &mut source);
let scalar_digest: u64 = scalar.digest_u64();
let mut svp_ref: SvpPPol<Vec<u8>, BR> = module_ref.svp_ppol_alloc(cols);
let mut svp_test: SvpPPol<Vec<u8>, BT> = module_test.svp_ppol_alloc(cols);
for j in 0..cols {
module_ref.svp_prepare(&mut svp_ref, j, &scalar, j);
module_test.svp_prepare(&mut svp_test, j, &scalar, j);
}
assert_eq!(scalar.digest_u64(), scalar_digest);
let svp_ref_digest: u64 = svp_ref.digest_u64();
let svp_test_digest: u64 = svp_test.digest_u64();
for a_size in [1, 2, 3, 4] {
// Create a random input VecZnx
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
let a_digest: u64 = a.digest_u64();
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
let mut a_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j);
module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j);
}
let a_dft_ref_digest: u64 = a_dft_ref.digest_u64();
let a_dft_test_digest: u64 = a_dft_test.digest_u64();
for res_size in [1, 2, 3, 4] {
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
res.fill_uniform(basek, &mut source);
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, res_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(1, 0, &mut res_dft_ref, j, &res, j);
module_test.vec_znx_dft_apply(1, 0, &mut res_dft_test, j, &res, j);
}
for j in 0..cols {
module_ref.svp_apply_dft_to_dft_add(&mut res_dft_ref, j, &svp_ref, j, &a_dft_ref, j);
module_test.svp_apply_dft_to_dft_add(&mut res_dft_test, j, &svp_test, j, &a_dft_test, j);
}
// Assert no change to inputs
assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest);
assert_eq!(a_dft_test.digest_u64(), a_dft_test_digest);
assert_eq!(svp_ref.digest_u64(), svp_ref_digest);
assert_eq!(svp_test.digest_u64(), svp_test_digest);
assert_eq!(a.digest_u64(), a_digest);
let res_big_ref: crate::layouts::VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
let res_big_test: crate::layouts::VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
&mut res_ref,
j,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
&mut res_test,
j,
&res_big_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_ref, res_test);
}
}
}
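/// Cross-backend consistency test for `svp_apply_dft_to_dft_inplace`: the DFT of `res` is
/// multiplied in place by the prepared scalar on both backends, then brought back from the
/// DFT domain and normalized for comparison.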
pub fn test_svp_apply_dft_to_dft_inplace<BR: Backend, BT: Backend>(
basek: usize,
module_ref: &Module<BR>,
module_test: &Module<BT>,
) where
Module<BR>: SvpPrepare<BR>
+ SvpApplyDftToDftInplace<BR>
+ SvpPPolAlloc<BR>
+ VecZnxDftAlloc<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxDftApply<BR>
+ VecZnxIdftApplyConsume<BR>
+ VecZnxBigNormalizeTmpBytes,
Module<BT>: SvpPrepare<BT>
+ SvpApplyDftToDftInplace<BT>
+ SvpPPolAlloc<BT>
+ VecZnxDftAlloc<BT>
+ VecZnxBigNormalize<BT>
+ VecZnxDftApply<BT>
+ VecZnxIdftApplyConsume<BT>
+ VecZnxBigNormalizeTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
let cols: usize = 2;
let mut source: Source = Source::new([0u8; 32]);
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
let mut scalar: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, cols);
scalar.fill_uniform(basek, &mut source);
let scalar_digest: u64 = scalar.digest_u64();
let mut svp_ref: SvpPPol<Vec<u8>, BR> = module_ref.svp_ppol_alloc(cols);
let mut svp_test: SvpPPol<Vec<u8>, BT> = module_test.svp_ppol_alloc(cols);
for j in 0..cols {
module_ref.svp_prepare(&mut svp_ref, j, &scalar, j);
module_test.svp_prepare(&mut svp_test, j, &scalar, j);
}
assert_eq!(scalar.digest_u64(), scalar_digest);
let svp_ref_digest: u64 = svp_ref.digest_u64();
let svp_test_digest: u64 = svp_test.digest_u64();
for res_size in [1, 2, 3, 4] {
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
res.fill_uniform(basek, &mut source);
let res_digest: u64 = res.digest_u64();
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, res_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(1, 0, &mut res_dft_ref, j, &res, j);
module_test.vec_znx_dft_apply(1, 0, &mut res_dft_test, j, &res, j);
}
assert_eq!(res.digest_u64(), res_digest);
for j in 0..cols {
module_ref.svp_apply_dft_to_dft_inplace(&mut res_dft_ref, j, &svp_ref, j);
module_test.svp_apply_dft_to_dft_inplace(&mut res_dft_test, j, &svp_test, j);
}
// Assert no change to inputs
assert_eq!(svp_ref.digest_u64(), svp_ref_digest);
assert_eq!(svp_test.digest_u64(), svp_test_digest);
let res_big_ref: crate::layouts::VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
let res_big_test: crate::layouts::VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
&mut res_ref,
j,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
&mut res_test,
j,
&res_big_test,
j,
scratch_test.borrow(),
);
}
println!("res_ref: {}", res_ref);
println!("res_test: {}", res_test);
assert_eq!(res_ref, res_test);
}
}

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,930 @@
use rand::RngCore;
use crate::{
api::{
ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigAlloc, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAdd,
VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftApply, VecZnxDftCopy, VecZnxDftSub, VecZnxDftSubABInplace,
VecZnxDftSubBAInplace, VecZnxIdftApply, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxIdftApplyTmpBytes,
},
layouts::{Backend, DataViewMut, DigestU64, FillUniform, Module, ScratchOwned, VecZnx, VecZnxBig, VecZnxDft},
source::Source,
};
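/// Cross-backend consistency test for `vec_znx_dft_add`: two random vectors are moved into
/// the DFT domain, added column-wise on both backends, brought back with
/// `vec_znx_idft_apply_consume`, normalized and compared.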
pub fn test_vec_znx_dft_add<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxDftAdd<BR>
+ VecZnxDftAlloc<BR>
+ VecZnxDftApply<BR>
+ VecZnxIdftApplyConsume<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxBigNormalizeTmpBytes,
Module<BT>: VecZnxDftAdd<BT>
+ VecZnxDftAlloc<BT>
+ VecZnxDftApply<BT>
+ VecZnxIdftApplyConsume<BT>
+ VecZnxBigNormalize<BT>
+ VecZnxBigNormalizeTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
let a_digest = a.digest_u64();
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
let mut a_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j);
module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j);
}
assert_eq!(a.digest_u64(), a_digest);
let a_dft_ref_digest: u64 = a_dft_ref.digest_u64();
let a_dft_test_digest: u64 = a_dft_test.digest_u64();
for b_size in [1, 2, 3, 4] {
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, b_size);
b.fill_uniform(basek, &mut source);
let b_digest: u64 = b.digest_u64();
let mut b_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, b_size);
let mut b_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, b_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(1, 0, &mut b_dft_ref, j, &b, j);
module_test.vec_znx_dft_apply(1, 0, &mut b_dft_test, j, &b, j);
}
assert_eq!(b.digest_u64(), b_digest);
let b_dft_ref_digest: u64 = b_dft_ref.digest_u64();
let b_dft_test_digest: u64 = b_dft_test.digest_u64();
for res_size in [1, 2, 3, 4] {
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, res_size);
// Fill the output DFT buffers with garbage
source.fill_bytes(res_dft_ref.data_mut());
source.fill_bytes(res_dft_test.data_mut());
// Apply the operation column-wise on both backends
for i in 0..cols {
module_ref.vec_znx_dft_add(&mut res_dft_ref, i, &a_dft_ref, i, &b_dft_ref, i);
module_test.vec_znx_dft_add(&mut res_dft_test, i, &a_dft_test, i, &b_dft_test, i);
}
assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest);
assert_eq!(a_dft_test.digest_u64(), a_dft_test_digest);
assert_eq!(b_dft_ref.digest_u64(), b_dft_ref_digest);
assert_eq!(b_dft_test.digest_u64(), b_dft_test_digest);
let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let res_ref_digest: u64 = res_big_ref.digest_u64();
let res_test_digest: u64 = res_big_test.digest_u64();
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
&mut res_small_ref,
j,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
&mut res_small_test,
j,
&res_big_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
assert_eq!(res_big_test.digest_u64(), res_test_digest);
assert_eq!(res_small_ref, res_small_test);
}
}
}
}
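/// Cross-backend consistency test for `vec_znx_dft_add_inplace`: `a_dft` is added in place
/// onto the DFT of `res` on both backends before normalization and comparison.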
pub fn test_vec_znx_dft_add_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxDftAddInplace<BR>
+ VecZnxDftAlloc<BR>
+ VecZnxDftApply<BR>
+ VecZnxIdftApplyConsume<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxBigNormalizeTmpBytes,
Module<BT>: VecZnxDftAddInplace<BT>
+ VecZnxDftAlloc<BT>
+ VecZnxDftApply<BT>
+ VecZnxIdftApplyConsume<BT>
+ VecZnxBigNormalize<BT>
+ VecZnxBigNormalizeTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
let a_digest = a.digest_u64();
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
let mut a_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j);
module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j);
}
assert_eq!(a.digest_u64(), a_digest);
let a_dft_ref_digest: u64 = a_dft_ref.digest_u64();
let a_dft_test_digest: u64 = a_dft_test.digest_u64();
for res_size in [1, 2, 3, 4] {
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
res.fill_uniform(basek, &mut source);
let res_digest: u64 = res.digest_u64();
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, res_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(1, 0, &mut res_dft_ref, j, &res, j);
module_test.vec_znx_dft_apply(1, 0, &mut res_dft_test, j, &res, j);
}
assert_eq!(res.digest_u64(), res_digest);
// Apply the operation column-wise on both backends
for i in 0..cols {
module_ref.vec_znx_dft_add_inplace(&mut res_dft_ref, i, &a_dft_ref, i);
module_test.vec_znx_dft_add_inplace(&mut res_dft_test, i, &a_dft_test, i);
}
assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest);
assert_eq!(a_dft_test.digest_u64(), a_dft_test_digest);
let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let res_ref_digest: u64 = res_big_ref.digest_u64();
let res_test_digest: u64 = res_big_test.digest_u64();
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
&mut res_small_ref,
j,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
&mut res_small_test,
j,
&res_big_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
assert_eq!(res_big_test.digest_u64(), res_test_digest);
assert_eq!(res_small_ref, res_small_test);
}
}
}
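/// Cross-backend consistency test for `vec_znx_dft_copy`: a DFT-domain vector is copied
/// with several `(steps, offset)` parameter combinations on both backends, then normalized
/// and compared.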
pub fn test_vec_znx_copy<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxDftCopy<BR>
+ VecZnxDftAlloc<BR>
+ VecZnxDftApply<BR>
+ VecZnxIdftApplyConsume<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxBigNormalizeTmpBytes,
Module<BT>: VecZnxDftCopy<BT>
+ VecZnxDftAlloc<BT>
+ VecZnxDftApply<BT>
+ VecZnxIdftApplyConsume<BT>
+ VecZnxBigNormalize<BT>
+ VecZnxBigNormalizeTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
for a_size in [1, 2, 6, 11] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
let a_digest = a.digest_u64();
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
let mut a_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j);
module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j);
}
assert_eq!(a.digest_u64(), a_digest);
let a_dft_ref_digest: u64 = a_dft_ref.digest_u64();
let a_dft_test_digest: u64 = a_dft_test.digest_u64();
for res_size in [1, 2, 6, 11] {
for params in [[1, 0], [1, 1], [1, 2], [2, 2]] {
let steps: usize = params[0];
let offset: usize = params[1];
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, res_size);
// Fill the output DFT buffers with garbage
source.fill_bytes(res_dft_ref.data_mut());
source.fill_bytes(res_dft_test.data_mut());
// Apply the operation column-wise on both backends
for i in 0..cols {
module_ref.vec_znx_dft_copy(steps, offset, &mut res_dft_ref, i, &a_dft_ref, i);
module_test.vec_znx_dft_copy(steps, offset, &mut res_dft_test, i, &a_dft_test, i);
}
assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest);
assert_eq!(a_dft_test.digest_u64(), a_dft_test_digest);
let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let res_ref_digest: u64 = res_big_ref.digest_u64();
let res_test_digest: u64 = res_big_test.digest_u64();
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
&mut res_small_ref,
j,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
&mut res_small_test,
j,
&res_big_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
assert_eq!(res_big_test.digest_u64(), res_test_digest);
assert_eq!(res_small_ref, res_small_test);
}
}
}
}
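/// Cross-backend consistency test for `vec_znx_idft_apply`: the forward DFT is applied with
/// several `(steps, offset)` combinations, the inverse DFT is computed into a separately
/// allocated `VecZnxBig` using scratch space, and the normalized results of both backends
/// are compared.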
pub fn test_vec_znx_idft_apply<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxDftApply<BR>
+ VecZnxDftAlloc<BR>
+ VecZnxBigAlloc<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxBigNormalizeTmpBytes
+ VecZnxIdftApply<BR>,
Module<BT>: VecZnxDftApply<BT>
+ VecZnxDftAlloc<BT>
+ VecZnxBigAlloc<BT>
+ VecZnxBigNormalize<BT>
+ VecZnxBigNormalizeTmpBytes
+ VecZnxIdftApply<BT>,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
for params in [[1, 0], [1, 1], [1, 2], [2, 2]] {
let steps: usize = params[0];
let offset: usize = params[1];
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(steps, offset, &mut res_dft_ref, j, &a, j);
module_test.vec_znx_dft_apply(steps, offset, &mut res_dft_test, j, &a, j);
}
assert_eq!(a.digest_u64(), a_digest);
let res_dft_ref_digest: u64 = res_dft_ref.digest_u64();
let res_dft_test_digest: u64 = res_dft_test.digest_u64();
let mut res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_big_alloc(cols, res_size);
let mut res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_big_alloc(cols, res_size);
for j in 0..cols {
module_ref.vec_znx_idft_apply(&mut res_big_ref, j, &res_dft_ref, j, scratch_ref.borrow());
module_test.vec_znx_idft_apply(
&mut res_big_test,
j,
&res_dft_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_dft_ref.digest_u64(), res_dft_ref_digest);
assert_eq!(res_dft_test.digest_u64(), res_dft_test_digest);
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let res_ref_digest: u64 = res_big_ref.digest_u64();
let res_test_digest: u64 = res_big_test.digest_u64();
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
&mut res_small_ref,
j,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
&mut res_small_test,
j,
&res_big_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
assert_eq!(res_big_test.digest_u64(), res_test_digest);
assert_eq!(res_small_ref, res_small_test);
}
}
}
}
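/// Cross-backend consistency test for `vec_znx_idft_apply_tmpa`: inverse-DFT variant that
/// is allowed to use the input DFT vector as temporary storage.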
pub fn test_vec_znx_idft_apply_tmpa<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxDftApply<BR>
+ VecZnxDftAlloc<BR>
+ VecZnxBigAlloc<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxBigNormalizeTmpBytes
+ VecZnxIdftApplyTmpA<BR>,
Module<BT>: VecZnxDftApply<BT>
+ VecZnxDftAlloc<BT>
+ VecZnxBigAlloc<BT>
+ VecZnxBigNormalize<BT>
+ VecZnxBigNormalizeTmpBytes
+ VecZnxIdftApplyTmpA<BT>,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
for params in [[1, 0], [1, 1], [1, 2], [2, 2]] {
let steps: usize = params[0];
let offset: usize = params[1];
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(steps, offset, &mut res_dft_ref, j, &a, j);
module_test.vec_znx_dft_apply(steps, offset, &mut res_dft_test, j, &a, j);
}
assert_eq!(a.digest_u64(), a_digest);
let mut res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_big_alloc(cols, res_size);
let mut res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_big_alloc(cols, res_size);
for j in 0..cols {
module_ref.vec_znx_idft_apply_tmpa(&mut res_big_ref, j, &mut res_dft_ref, j);
module_test.vec_znx_idft_apply_tmpa(&mut res_big_test, j, &mut res_dft_test, j);
}
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let res_ref_digest: u64 = res_big_ref.digest_u64();
let res_test_digest: u64 = res_big_test.digest_u64();
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
&mut res_small_ref,
j,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
&mut res_small_test,
j,
&res_big_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
assert_eq!(res_big_test.digest_u64(), res_test_digest);
assert_eq!(res_small_ref, res_small_test);
}
}
}
}
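/// Cross-backend consistency test for `vec_znx_idft_apply_consume`: inverse-DFT variant
/// that consumes the DFT vector and reuses its buffer for the result.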
pub fn test_vec_znx_idft_apply_consume<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxDftApply<BR>
+ VecZnxIdftApplyTmpBytes
+ VecZnxDftAlloc<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxBigNormalizeTmpBytes
+ VecZnxIdftApplyConsume<BR>,
Module<BT>: VecZnxDftApply<BT>
+ VecZnxIdftApplyTmpBytes
+ VecZnxDftAlloc<BT>
+ VecZnxBigNormalize<BT>
+ VecZnxBigNormalizeTmpBytes
+ VecZnxIdftApplyConsume<BT>,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let mut scratch_ref: ScratchOwned<BR> =
ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes() | module_ref.vec_znx_idft_apply_tmp_bytes());
let mut scratch_test: ScratchOwned<BT> =
ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes() | module_test.vec_znx_idft_apply_tmp_bytes());
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
for params in [[1, 0], [1, 1], [1, 2], [2, 2]] {
let steps: usize = params[0];
let offset: usize = params[1];
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(steps, offset, &mut res_dft_ref, j, &a, j);
module_test.vec_znx_dft_apply(steps, offset, &mut res_dft_test, j, &a, j);
}
assert_eq!(a.digest_u64(), a_digest);
let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let res_ref_digest: u64 = res_big_ref.digest_u64();
let res_test_digest: u64 = res_big_test.digest_u64();
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
&mut res_small_ref,
j,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
&mut res_small_test,
j,
&res_big_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
assert_eq!(res_big_test.digest_u64(), res_test_digest);
assert_eq!(res_small_ref, res_small_test);
}
}
}
}
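/// Cross-backend consistency test for `vec_znx_dft_sub`: the column-wise difference of two
/// DFT-domain vectors is computed on both backends, then normalized and compared.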
pub fn test_vec_znx_dft_sub<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxDftSub<BR>
+ VecZnxDftAlloc<BR>
+ VecZnxDftApply<BR>
+ VecZnxIdftApplyConsume<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxBigNormalizeTmpBytes,
Module<BT>: VecZnxDftSub<BT>
+ VecZnxDftAlloc<BT>
+ VecZnxDftApply<BT>
+ VecZnxIdftApplyConsume<BT>
+ VecZnxBigNormalize<BT>
+ VecZnxBigNormalizeTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
let a_digest = a.digest_u64();
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
let mut a_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j);
module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j);
}
assert_eq!(a.digest_u64(), a_digest);
let a_dft_ref_digest: u64 = a_dft_ref.digest_u64();
let a_dft_test_digest: u64 = a_dft_test.digest_u64();
for b_size in [1, 2, 3, 4] {
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, b_size);
b.fill_uniform(basek, &mut source);
let b_digest: u64 = b.digest_u64();
let mut b_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, b_size);
let mut b_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, b_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(1, 0, &mut b_dft_ref, j, &b, j);
module_test.vec_znx_dft_apply(1, 0, &mut b_dft_test, j, &b, j);
}
assert_eq!(b.digest_u64(), b_digest);
let b_dft_ref_digest: u64 = b_dft_ref.digest_u64();
let b_dft_test_digest: u64 = b_dft_test.digest_u64();
for res_size in [1, 2, 3, 4] {
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, res_size);
// Fill the output DFT buffers with garbage
source.fill_bytes(res_dft_ref.data_mut());
source.fill_bytes(res_dft_test.data_mut());
// Apply the operation column-wise on both backends
for i in 0..cols {
module_ref.vec_znx_dft_sub(&mut res_dft_ref, i, &a_dft_ref, i, &b_dft_ref, i);
module_test.vec_znx_dft_sub(&mut res_dft_test, i, &a_dft_test, i, &b_dft_test, i);
}
assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest);
assert_eq!(a_dft_test.digest_u64(), a_dft_test_digest);
assert_eq!(b_dft_ref.digest_u64(), b_dft_ref_digest);
assert_eq!(b_dft_test.digest_u64(), b_dft_test_digest);
let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let res_ref_digest: u64 = res_big_ref.digest_u64();
let res_test_digest: u64 = res_big_test.digest_u64();
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
&mut res_small_ref,
j,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
&mut res_small_test,
j,
&res_big_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
assert_eq!(res_big_test.digest_u64(), res_test_digest);
assert_eq!(res_small_ref, res_small_test);
}
}
}
}
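/// Cross-backend consistency test for `vec_znx_dft_sub_ab_inplace`: in-place subtraction
/// between the DFT of `res` and `a_dft`, followed by normalization and comparison of both
/// backends.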
pub fn test_vec_znx_dft_sub_ab_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxDftSubABInplace<BR>
+ VecZnxDftAlloc<BR>
+ VecZnxDftApply<BR>
+ VecZnxIdftApplyConsume<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxBigNormalizeTmpBytes,
Module<BT>: VecZnxDftSubABInplace<BT>
+ VecZnxDftAlloc<BT>
+ VecZnxDftApply<BT>
+ VecZnxIdftApplyConsume<BT>
+ VecZnxBigNormalize<BT>
+ VecZnxBigNormalizeTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
let a_digest = a.digest_u64();
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
let mut a_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j);
module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j);
}
assert_eq!(a.digest_u64(), a_digest);
let a_dft_ref_digest: u64 = a_dft_ref.digest_u64();
let a_dft_test_digest: u64 = a_dft_test.digest_u64();
for res_size in [1, 2, 3, 4] {
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
res.fill_uniform(basek, &mut source);
let res_digest: u64 = res.digest_u64();
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, res_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(1, 0, &mut res_dft_ref, j, &res, j);
module_test.vec_znx_dft_apply(1, 0, &mut res_dft_test, j, &res, j);
}
assert_eq!(res.digest_u64(), res_digest);
// Apply the operation column-wise on both backends
for i in 0..cols {
module_ref.vec_znx_dft_sub_ab_inplace(&mut res_dft_ref, i, &a_dft_ref, i);
module_test.vec_znx_dft_sub_ab_inplace(&mut res_dft_test, i, &a_dft_test, i);
}
assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest);
assert_eq!(a_dft_test.digest_u64(), a_dft_test_digest);
let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let res_ref_digest: u64 = res_big_ref.digest_u64();
let res_test_digest: u64 = res_big_test.digest_u64();
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
&mut res_small_ref,
j,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
&mut res_small_test,
j,
&res_big_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
assert_eq!(res_big_test.digest_u64(), res_test_digest);
assert_eq!(res_small_ref, res_small_test);
}
}
}
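/// Cross-backend consistency test for `vec_znx_dft_sub_ba_inplace`: the mirrored in-place
/// subtraction variant, followed by normalization and comparison of both backends.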
pub fn test_vec_znx_dft_sub_ba_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxDftSubBAInplace<BR>
+ VecZnxDftAlloc<BR>
+ VecZnxDftApply<BR>
+ VecZnxIdftApplyConsume<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxBigNormalizeTmpBytes,
Module<BT>: VecZnxDftSubBAInplace<BT>
+ VecZnxDftAlloc<BT>
+ VecZnxDftApply<BT>
+ VecZnxIdftApplyConsume<BT>
+ VecZnxBigNormalize<BT>
+ VecZnxBigNormalizeTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
let a_digest = a.digest_u64();
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
let mut a_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j);
module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j);
}
assert_eq!(a.digest_u64(), a_digest);
let a_dft_ref_digest: u64 = a_dft_ref.digest_u64();
let a_dft_test_digest: u64 = a_dft_test.digest_u64();
for res_size in [1, 2, 3, 4] {
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
res.fill_uniform(basek, &mut source);
let res_digest: u64 = res.digest_u64();
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, res_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(1, 0, &mut res_dft_ref, j, &res, j);
module_test.vec_znx_dft_apply(1, 0, &mut res_dft_test, j, &res, j);
}
assert_eq!(res.digest_u64(), res_digest);
// Apply the operation column-wise on both backends
for i in 0..cols {
module_ref.vec_znx_dft_sub_ba_inplace(&mut res_dft_ref, i, &a_dft_ref, i);
module_test.vec_znx_dft_sub_ba_inplace(&mut res_dft_test, i, &a_dft_test, i);
}
assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest);
assert_eq!(a_dft_test.digest_u64(), a_dft_test_digest);
let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let res_ref_digest: u64 = res_big_ref.digest_u64();
let res_test_digest: u64 = res_big_test.digest_u64();
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
&mut res_small_ref,
j,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
&mut res_small_test,
j,
&res_big_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
assert_eq!(res_big_test.digest_u64(), res_test_digest);
assert_eq!(res_small_ref, res_small_test);
}
}
}

View File

@@ -0,0 +1,384 @@
use crate::{
api::{
ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigNormalize, VecZnxDftAlloc, VecZnxDftApply,
VecZnxIdftApplyConsume, VmpApplyDft, VmpApplyDftTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
VmpApplyDftToDftAddTmpBytes, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare,
},
layouts::{DataViewMut, DigestU64, FillUniform, MatZnx, Module, ScratchOwned, VecZnx, VecZnxBig},
source::Source,
};
use rand::RngCore;
use crate::layouts::{Backend, VecZnxDft, VmpPMat};
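/// Cross-backend consistency test for `vmp_apply_dft`: a random `MatZnx` is prepared into a
/// `VmpPMat` on both backends, the vector-matrix product is applied to a coefficient-domain
/// `VecZnx`, and the normalized results are compared for all combinations of input/output
/// columns and sizes.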
pub fn test_vmp_apply_dft<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: ModuleNew<BR>
+ VmpApplyDftTmpBytes
+ VmpApplyDft<BR>
+ VmpPMatAlloc<BR>
+ VecZnxDftAlloc<BR>
+ VmpPrepare<BR>
+ VecZnxDftAlloc<BR>
+ VecZnxIdftApplyConsume<BR>
+ VecZnxBigNormalize<BR>,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
Module<BT>: ModuleNew<BT>
+ VmpApplyDftTmpBytes
+ VmpApplyDft<BT>
+ VmpPMatAlloc<BT>
+ VecZnxDftAlloc<BT>
+ VmpPrepare<BT>
+ VecZnxDftAlloc<BT>
+ VecZnxIdftApplyConsume<BT>
+ VecZnxBigNormalize<BT>,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
let max_size: usize = 4;
let max_cols: usize = 2;
let mut source: Source = Source::new([0u8; 32]);
let mut scratch_ref: ScratchOwned<BR> =
ScratchOwned::alloc(module_ref.vmp_apply_dft_tmp_bytes(max_size, max_size, max_size, max_cols, max_cols, max_size));
let mut scratch_test: ScratchOwned<BT> =
ScratchOwned::alloc(module_test.vmp_apply_dft_tmp_bytes(max_size, max_size, max_size, max_cols, max_cols, max_size));
for cols_in in 1..max_cols + 1 {
for cols_out in 1..max_cols + 1 {
for size_in in 1..max_size + 1 {
for size_out in 1..max_size + 1 {
let rows: usize = cols_in;
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_in, size_in);
a.fill_uniform(basek, &mut source);
let a_digest: u64 = a.digest_u64();
let mut mat: MatZnx<Vec<u8>> = MatZnx::alloc(n, rows, cols_in, cols_out, size_out);
mat.fill_uniform(basek, &mut source);
let mat_digest: u64 = mat.digest_u64();
let mut pmat_ref: VmpPMat<Vec<u8>, BR> = module_ref.vmp_pmat_alloc(rows, cols_in, cols_out, size_out);
let mut pmat_test: VmpPMat<Vec<u8>, BT> = module_test.vmp_pmat_alloc(rows, cols_in, cols_out, size_out);
module_ref.vmp_prepare(&mut pmat_ref, &mat, scratch_ref.borrow());
module_test.vmp_prepare(&mut pmat_test, &mat, scratch_test.borrow());
assert_eq!(mat.digest_u64(), mat_digest);
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols_out, size_out);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols_out, size_out);
source.fill_bytes(res_dft_ref.data_mut());
source.fill_bytes(res_dft_test.data_mut());
module_ref.vmp_apply_dft(&mut res_dft_ref, &a, &pmat_ref, scratch_ref.borrow());
module_test.vmp_apply_dft(&mut res_dft_test, &a, &pmat_test, scratch_test.borrow());
assert_eq!(a.digest_u64(), a_digest);
let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_out, size_out);
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_out, size_out);
let res_ref_digest: u64 = res_big_ref.digest_u64();
let res_test_digest: u64 = res_big_test.digest_u64();
for j in 0..cols_out {
module_ref.vec_znx_big_normalize(
basek,
&mut res_small_ref,
j,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
&mut res_small_test,
j,
&res_big_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
assert_eq!(res_big_test.digest_u64(), res_test_digest);
assert_eq!(res_small_ref, res_small_test);
}
}
}
}
}
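/// Cross-backend consistency test for `vmp_apply_dft_to_dft`: same as `test_vmp_apply_dft`,
/// but the input vector is moved into the DFT domain before the vector-matrix product.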
pub fn test_vmp_apply_dft_to_dft<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: ModuleNew<BR>
+ VmpApplyDftToDftTmpBytes
+ VmpApplyDftToDft<BR>
+ VmpPMatAlloc<BR>
+ VecZnxDftAlloc<BR>
+ VmpPrepare<BR>
+ VecZnxIdftApplyConsume<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxDftApply<BR>,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
Module<BT>: ModuleNew<BT>
+ VmpApplyDftToDftTmpBytes
+ VmpApplyDftToDft<BT>
+ VmpPMatAlloc<BT>
+ VecZnxDftAlloc<BT>
+ VmpPrepare<BT>
+ VecZnxIdftApplyConsume<BT>
+ VecZnxBigNormalize<BT>
+ VecZnxDftApply<BT>,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
let max_size: usize = 4;
let max_cols: usize = 2;
let mut source: Source = Source::new([0u8; 32]);
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(
module_ref.vmp_apply_dft_to_dft_tmp_bytes(max_size, max_size, max_size, max_cols, max_cols, max_size),
);
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(
module_test.vmp_apply_dft_to_dft_tmp_bytes(max_size, max_size, max_size, max_cols, max_cols, max_size),
);
for cols_in in 1..max_cols + 1 {
for cols_out in 1..max_cols + 1 {
for size_in in 1..max_size + 1 {
for size_out in 1..max_size + 1 {
let rows: usize = size_in;
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_in, size_in);
a.fill_uniform(basek, &mut source);
let a_digest: u64 = a.digest_u64();
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols_in, size_in);
let mut a_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols_in, size_in);
for j in 0..cols_in {
module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j);
module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j);
}
assert_eq!(a.digest_u64(), a_digest);
let mut mat: MatZnx<Vec<u8>> = MatZnx::alloc(n, rows, cols_in, cols_out, size_out);
mat.fill_uniform(basek, &mut source);
let mat_digest: u64 = mat.digest_u64();
let mut pmat_ref: VmpPMat<Vec<u8>, BR> = module_ref.vmp_pmat_alloc(rows, cols_in, cols_out, size_out);
let mut pmat_test: VmpPMat<Vec<u8>, BT> = module_test.vmp_pmat_alloc(rows, cols_in, cols_out, size_out);
module_ref.vmp_prepare(&mut pmat_ref, &mat, scratch_ref.borrow());
module_test.vmp_prepare(&mut pmat_test, &mat, scratch_test.borrow());
assert_eq!(mat.digest_u64(), mat_digest);
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols_out, size_out);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols_out, size_out);
source.fill_bytes(res_dft_ref.data_mut());
source.fill_bytes(res_dft_test.data_mut());
module_ref.vmp_apply_dft_to_dft(
&mut res_dft_ref,
&a_dft_ref,
&pmat_ref,
scratch_ref.borrow(),
);
module_test.vmp_apply_dft_to_dft(
&mut res_dft_test,
&a_dft_test,
&pmat_test,
scratch_test.borrow(),
);
let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_out, size_out);
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_out, size_out);
let res_ref_digest: u64 = res_big_ref.digest_u64();
let res_test_digest: u64 = res_big_test.digest_u64();
for j in 0..cols_out {
module_ref.vec_znx_big_normalize(
basek,
&mut res_small_ref,
j,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
&mut res_small_test,
j,
&res_big_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
assert_eq!(res_big_test.digest_u64(), res_test_digest);
assert_eq!(res_small_ref, res_small_test);
}
}
}
}
}
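/// Cross-backend consistency test for `vmp_apply_dft_to_dft_add`: the accumulating
/// vector-matrix product is applied on top of an existing DFT output for several limb
/// offsets, and the normalized results of both backends are compared.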
pub fn test_vmp_apply_dft_to_dft_add<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: ModuleNew<BR>
+ VmpApplyDftToDftAddTmpBytes
+ VmpApplyDftToDftAdd<BR>
+ VmpPMatAlloc<BR>
+ VecZnxDftAlloc<BR>
+ VmpPrepare<BR>
+ VecZnxIdftApplyConsume<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxDftApply<BR>,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
Module<BT>: ModuleNew<BT>
+ VmpApplyDftToDftAddTmpBytes
+ VmpApplyDftToDftAdd<BT>
+ VmpPMatAlloc<BT>
+ VecZnxDftAlloc<BT>
+ VmpPrepare<BT>
+ VecZnxIdftApplyConsume<BT>
+ VecZnxBigNormalize<BT>
+ VecZnxDftApply<BT>,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
let max_size: usize = 4;
let max_cols: usize = 2;
let mut source: Source = Source::new([0u8; 32]);
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(
module_ref.vmp_apply_dft_to_dft_add_tmp_bytes(max_size, max_size, max_size, max_cols, max_cols, max_size),
);
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(
module_test.vmp_apply_dft_to_dft_add_tmp_bytes(max_size, max_size, max_size, max_cols, max_cols, max_size),
);
for cols_in in 1..max_cols + 1 {
for cols_out in 1..max_cols + 1 {
for size_in in 1..max_size + 1 {
for size_out in 1..max_size + 1 {
let rows: usize = size_in;
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_in, size_in);
a.fill_uniform(basek, &mut source);
let a_digest: u64 = a.digest_u64();
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols_in, size_in);
let mut a_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols_in, size_in);
for j in 0..cols_in {
module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j);
module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j);
}
assert_eq!(a.digest_u64(), a_digest);
let mut mat: MatZnx<Vec<u8>> = MatZnx::alloc(n, rows, cols_in, cols_out, size_out);
mat.fill_uniform(basek, &mut source);
let mat_digest: u64 = mat.digest_u64();
let mut pmat_ref: VmpPMat<Vec<u8>, BR> = module_ref.vmp_pmat_alloc(rows, cols_in, cols_out, size_out);
let mut pmat_test: VmpPMat<Vec<u8>, BT> = module_test.vmp_pmat_alloc(rows, cols_in, cols_out, size_out);
module_ref.vmp_prepare(&mut pmat_ref, &mat, scratch_ref.borrow());
module_test.vmp_prepare(&mut pmat_test, &mat, scratch_test.borrow());
assert_eq!(mat.digest_u64(), mat_digest);
for limb_offset in 0..size_out {
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_out, size_out);
res.fill_uniform(basek, &mut source);
let res_digest: u64 = res.digest_u64();
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols_out, size_out);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols_out, size_out);
for j in 0..cols_out {
module_ref.vec_znx_dft_apply(1, 0, &mut res_dft_ref, j, &res, j);
module_test.vec_znx_dft_apply(1, 0, &mut res_dft_test, j, &res, j);
}
assert_eq!(res.digest_u64(), res_digest);
module_ref.vmp_apply_dft_to_dft_add(
&mut res_dft_ref,
&a_dft_ref,
&pmat_ref,
limb_offset * cols_out,
scratch_ref.borrow(),
);
module_test.vmp_apply_dft_to_dft_add(
&mut res_dft_test,
&a_dft_test,
&pmat_test,
limb_offset * cols_out,
scratch_test.borrow(),
);
let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_out, size_out);
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_out, size_out);
let res_ref_digest: u64 = res_big_ref.digest_u64();
let res_test_digest: u64 = res_big_test.digest_u64();
for j in 0..cols_out {
module_ref.vec_znx_big_normalize(
basek,
&mut res_small_ref,
j,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
&mut res_small_test,
j,
&res_big_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
assert_eq!(res_big_test.digest_u64(), res_test_digest);
assert_eq!(res_small_ref, res_small_test);
}
}
}
}
}
}