Added more serialization tests and generalized methods to any n

This commit is contained in:
Pro7ech
2025-08-13 15:28:52 +02:00
parent 068470783e
commit 940742ce6c
117 changed files with 3658 additions and 2577 deletions

View File

@@ -14,16 +14,16 @@ use crate::{
VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxSwithcDegree, ZnxInfos, ZnxSliceSize, ZnxView,
ZnxViewMut, ZnxZero,
},
layouts::{Backend, Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxOwned, VecZnxToMut, VecZnxToRef},
layouts::{Backend, Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxToMut, VecZnxToRef},
oep::{
VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarInplaceImpl,
VecZnxAllocBytesImpl, VecZnxAllocImpl, VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl,
VecZnxDecodeCoeffsi64Impl, VecZnxDecodeVecFloatImpl, VecZnxDecodeVeci64Impl, VecZnxEncodeCoeffsi64Impl,
VecZnxEncodeVeci64Impl, VecZnxFillDistF64Impl, VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxFromBytesImpl,
VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl,
VecZnxNegateInplaceImpl, VecZnxNormalizeImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl,
VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxStdImpl,
VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxDecodeCoeffsi64Impl,
VecZnxDecodeVecFloatImpl, VecZnxDecodeVeci64Impl, VecZnxEncodeCoeffsi64Impl, VecZnxEncodeVeci64Impl,
VecZnxFillDistF64Impl, VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl,
VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl,
VecZnxNormalizeImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl,
VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxStdImpl, VecZnxSubABInplaceImpl,
VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
},
},
implementation::cpu_spqlios::{
@@ -32,33 +32,6 @@ use crate::{
},
};
/// Blanket implementation of `VecZnxAllocImpl` for every CPU/AVX backend.
///
/// Pure delegation: allocates a fresh owned vector via `VecZnxOwned::alloc`
/// with `i64` as the coefficient scalar type.
/// NOTE(review): this impl sits in the removed side of a diff hunk
/// (`@@ -32,33 +32,6 @@`), so the commit appears to delete it — confirm
/// against the full file before relying on it.
unsafe impl<B: Backend> VecZnxAllocImpl<B> for B
where
    B: CPUAVX,
{
    // Presumably `n` is the per-column coefficient count, `cols` the column
    // count and `size` the limb/level count — TODO confirm from VecZnxOwned.
    fn vec_znx_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxOwned {
        VecZnxOwned::alloc::<i64>(n, cols, size)
    }
}
/// Blanket implementation of `VecZnxFromBytesImpl` for every CPU/AVX backend.
///
/// Pure delegation: reconstructs an owned vector from a raw byte buffer via
/// `VecZnxOwned::from_bytes`, interpreting the payload as `i64` coefficients.
/// NOTE(review): shown in the removed side of a diff hunk — the commit appears
/// to delete this impl; verify against the full file.
unsafe impl<B: Backend> VecZnxFromBytesImpl<B> for B
where
    B: CPUAVX,
{
    // `bytes` is consumed and becomes the vector's backing storage — no
    // length validation is visible here; presumably `from_bytes` checks that
    // `bytes.len()` matches `n`/`cols`/`size` — TODO confirm.
    fn vec_znx_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxOwned {
        VecZnxOwned::from_bytes::<i64>(n, cols, size, bytes)
    }
}
/// Blanket implementation of `VecZnxAllocBytesImpl` for every CPU/AVX backend.
///
/// Pure delegation: reports the byte footprint that `VecZnxOwned::alloc`
/// would require for the same `(n, cols, size)` with `i64` coefficients.
/// Useful for sizing scratch space or serialization buffers ahead of time.
/// NOTE(review): shown in the removed side of a diff hunk — the commit appears
/// to delete this impl; verify against the full file.
unsafe impl<B: Backend> VecZnxAllocBytesImpl<B> for B
where
    B: CPUAVX,
{
    fn vec_znx_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
        VecZnxOwned::alloc_bytes::<i64>(n, cols, size)
    }
}
unsafe impl<B: Backend> VecZnxNormalizeTmpBytesImpl<B> for B
where
B: CPUAVX,
@@ -156,9 +129,8 @@ where
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(b.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), res.n());
assert_eq!(b.n(), res.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
@@ -192,8 +164,7 @@ where
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), res.n());
}
unsafe {
vec_znx::vec_znx_add(
@@ -232,8 +203,7 @@ where
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), res.n());
}
unsafe {
@@ -269,9 +239,8 @@ where
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(b.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), res.n());
assert_eq!(b.n(), res.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
@@ -304,8 +273,7 @@ where
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), res.n());
}
unsafe {
vec_znx::vec_znx_sub(
@@ -337,8 +305,7 @@ where
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), res.n());
}
unsafe {
vec_znx::vec_znx_sub(
@@ -377,8 +344,7 @@ where
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), res.n());
}
unsafe {
@@ -411,8 +377,7 @@ where
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), res.n());
}
unsafe {
vec_znx::vec_znx_negate(
@@ -437,10 +402,6 @@ where
A: VecZnxToMut,
{
let mut a: VecZnx<&mut [u8]> = a.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
}
unsafe {
vec_znx::vec_znx_negate(
module.ptr() as *const module_info_t,
@@ -604,8 +565,7 @@ where
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), res.n());
}
unsafe {
vec_znx::vec_znx_automorphism(
@@ -633,7 +593,6 @@ where
let mut a: VecZnx<&mut [u8]> = a.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert!(
k & 1 != 0,
"invalid galois element: must be odd but is {}",
@@ -668,8 +627,8 @@ where
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), res.n());
assert_eq!(res.n(), res.n());
}
unsafe {
vec_znx::vec_znx_mul_xp_minus_one(
@@ -697,7 +656,7 @@ where
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), module.n());
assert_eq!(res.n(), res.n());
}
unsafe {
vec_znx::vec_znx_mul_xp_minus_one(
@@ -749,7 +708,7 @@ pub fn vec_znx_split_ref<R, A, B: Backend>(
let (n_in, n_out) = (a.n(), res[0].to_mut().n());
let (mut buf, _) = scratch.take_vec_znx(module, 1, a.size());
let (mut buf, _) = scratch.take_vec_znx(n_in.max(n_out), 1, a.size());
debug_assert!(
n_out < n_in,