mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
Added more serialization tests + generalize methods to any n
This commit is contained in:
@@ -14,16 +14,16 @@ use crate::{
|
||||
VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxSwithcDegree, ZnxInfos, ZnxSliceSize, ZnxView,
|
||||
ZnxViewMut, ZnxZero,
|
||||
},
|
||||
layouts::{Backend, Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxOwned, VecZnxToMut, VecZnxToRef},
|
||||
layouts::{Backend, Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxToMut, VecZnxToRef},
|
||||
oep::{
|
||||
VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarInplaceImpl,
|
||||
VecZnxAllocBytesImpl, VecZnxAllocImpl, VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl,
|
||||
VecZnxDecodeCoeffsi64Impl, VecZnxDecodeVecFloatImpl, VecZnxDecodeVeci64Impl, VecZnxEncodeCoeffsi64Impl,
|
||||
VecZnxEncodeVeci64Impl, VecZnxFillDistF64Impl, VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxFromBytesImpl,
|
||||
VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl,
|
||||
VecZnxNegateInplaceImpl, VecZnxNormalizeImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl,
|
||||
VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxStdImpl,
|
||||
VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
|
||||
VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxDecodeCoeffsi64Impl,
|
||||
VecZnxDecodeVecFloatImpl, VecZnxDecodeVeci64Impl, VecZnxEncodeCoeffsi64Impl, VecZnxEncodeVeci64Impl,
|
||||
VecZnxFillDistF64Impl, VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl,
|
||||
VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl,
|
||||
VecZnxNormalizeImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl,
|
||||
VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxStdImpl, VecZnxSubABInplaceImpl,
|
||||
VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
|
||||
},
|
||||
},
|
||||
implementation::cpu_spqlios::{
|
||||
@@ -32,33 +32,6 @@ use crate::{
|
||||
},
|
||||
};
|
||||
|
||||
unsafe impl<B: Backend> VecZnxAllocImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn vec_znx_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxOwned {
|
||||
VecZnxOwned::alloc::<i64>(n, cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> VecZnxFromBytesImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn vec_znx_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxOwned {
|
||||
VecZnxOwned::from_bytes::<i64>(n, cols, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> VecZnxAllocBytesImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn vec_znx_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
|
||||
VecZnxOwned::alloc_bytes::<i64>(n, cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> VecZnxNormalizeTmpBytesImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
@@ -156,9 +129,8 @@ where
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(b.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
assert_eq!(a.n(), res.n());
|
||||
assert_eq!(b.n(), res.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
@@ -192,8 +164,7 @@ where
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
@@ -232,8 +203,7 @@ where
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
|
||||
unsafe {
|
||||
@@ -269,9 +239,8 @@ where
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(b.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
assert_eq!(a.n(), res.n());
|
||||
assert_eq!(b.n(), res.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
@@ -304,8 +273,7 @@ where
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
@@ -337,8 +305,7 @@ where
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
@@ -377,8 +344,7 @@ where
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
|
||||
unsafe {
|
||||
@@ -411,8 +377,7 @@ where
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_negate(
|
||||
@@ -437,10 +402,6 @@ where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
let mut a: VecZnx<&mut [u8]> = a.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_negate(
|
||||
module.ptr() as *const module_info_t,
|
||||
@@ -604,8 +565,7 @@ where
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
@@ -633,7 +593,6 @@ where
|
||||
let mut a: VecZnx<&mut [u8]> = a.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert!(
|
||||
k & 1 != 0,
|
||||
"invalid galois element: must be odd but is {}",
|
||||
@@ -668,8 +627,8 @@ where
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
assert_eq!(a.n(), res.n());
|
||||
assert_eq!(res.n(), res.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_mul_xp_minus_one(
|
||||
@@ -697,7 +656,7 @@ where
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), module.n());
|
||||
assert_eq!(res.n(), res.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_mul_xp_minus_one(
|
||||
@@ -749,7 +708,7 @@ pub fn vec_znx_split_ref<R, A, B: Backend>(
|
||||
|
||||
let (n_in, n_out) = (a.n(), res[0].to_mut().n());
|
||||
|
||||
let (mut buf, _) = scratch.take_vec_znx(module, 1, a.size());
|
||||
let (mut buf, _) = scratch.take_vec_znx(n_in.max(n_out), 1, a.size());
|
||||
|
||||
debug_assert!(
|
||||
n_out < n_in,
|
||||
|
||||
Reference in New Issue
Block a user