Ref. + AVX code & generic tests + benches (#85)

This commit is contained in:
Jean-Philippe Bossuat
2025-09-15 16:16:11 +02:00
committed by GitHub
parent 99b9e3e10e
commit 56dbd29c59
286 changed files with 27797 additions and 7270 deletions

View File

@@ -1,170 +1,98 @@
use rand_distr::{Distribution, Normal};
use crate::cpu_spqlios::{FFT64, ffi::vec_znx};
use crate::cpu_spqlios::{FFT64Spqlios, ffi::vec_znx};
use poulpy_hal::{
api::{TakeSlice, VecZnxBigAddDistF64, VecZnxBigFillDistF64, VecZnxBigNormalizeTmpBytes},
api::{TakeSlice, VecZnxBigNormalizeTmpBytes},
layouts::{
Backend, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef,
ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut,
},
oep::{
TakeSliceImpl, VecZnxBigAddDistF64Impl, VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl,
VecZnxBigAddSmallImpl, VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl,
VecZnxBigAutomorphismImpl, VecZnxBigAutomorphismInplaceImpl, VecZnxBigFillDistF64Impl, VecZnxBigFillNormalImpl,
VecZnxBigFromBytesImpl, VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, VecZnxBigNormalizeTmpBytesImpl,
VecZnxBigSubABInplaceImpl, VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl, VecZnxBigSubSmallAImpl,
VecZnxBigSubSmallAInplaceImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl,
TakeSliceImpl, VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, VecZnxBigAddSmallImpl,
VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl,
VecZnxBigAutomorphismInplaceImpl, VecZnxBigAutomorphismInplaceTmpBytesImpl, VecZnxBigFromBytesImpl,
VecZnxBigFromSmallImpl, VecZnxBigNegateImpl, VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl,
VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl, VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl,
VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl,
},
reference::{
vec_znx::vec_znx_add_normal_ref,
znx::{znx_copy_ref, znx_zero_ref},
},
source::Source,
};
// Byte-size query for a `VecZnxBig` buffer.
// NOTE(review): the two header lines below look like the pre- and post-rename forms of the
// same impl as rendered by the diff view (`FFT64` -> `FFT64Spqlios`) — confirm.
unsafe impl VecZnxBigAllocBytesImpl<Self> for FFT64 {
unsafe impl VecZnxBigAllocBytesImpl<Self> for FFT64Spqlios {
/// Returns the number of bytes for the given dimensions, computed as
/// `layout_big_word_count() * n * cols * size` words of `size_of::<f64>()` bytes each.
fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
Self::layout_big_word_count() * n * cols * size * size_of::<f64>()
}
}
// Allocation hook for `VecZnxBig`.
// NOTE(review): the two header lines below look like the pre- and post-rename forms of the
// same impl as rendered by the diff view (`FFT64` -> `FFT64Spqlios`) — confirm.
unsafe impl VecZnxBigAllocImpl<Self> for FFT64 {
unsafe impl VecZnxBigAllocImpl<Self> for FFT64Spqlios {
/// Returns an owned `VecZnxBig` by forwarding `n`, `cols` and `size` unchanged to
/// `VecZnxBig::alloc`.
fn vec_znx_big_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxBigOwned<Self> {
VecZnxBig::alloc(n, cols, size)
}
}
// Construction of a `VecZnxBig` from a caller-provided byte buffer.
// NOTE(review): the two header lines below look like the pre- and post-rename forms of the
// same impl as rendered by the diff view (`FFT64` -> `FFT64Spqlios`) — confirm.
unsafe impl VecZnxBigFromBytesImpl<Self> for FFT64 {
unsafe impl VecZnxBigFromBytesImpl<Self> for FFT64Spqlios {
/// Returns an owned `VecZnxBig` by forwarding `n`, `cols`, `size` and the owned `bytes`
/// vector unchanged to `VecZnxBig::from_bytes`.
fn vec_znx_big_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<Self> {
VecZnxBig::from_bytes(n, cols, size, bytes)
}
}
// NOTE(review): this span interleaves two definitions as rendered by the diff view:
// `VecZnxBigAddDistF64Impl` for `FFT64` (presumably the removed side) and
// `VecZnxBigFromSmallImpl` for `FFT64Spqlios` (presumably the added side). Read the lines
// that mention `dist`, `bound`, `limb` and `basek_rem` as one function and the lines that
// mention `a`, `a_col`, `res_size`, `a_size` and `min_size` as the other.
unsafe impl VecZnxBigAddDistF64Impl<Self> for FFT64 {
// add_dist_f64_impl: adds one rounded sample of `dist` to every coefficient of a single
// limb of column `res_col` of `res`.
fn add_dist_f64_impl<R: VecZnxBigToMut<Self>, D: Distribution<f64>>(
_module: &Module<Self>,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) {
unsafe impl VecZnxBigFromSmallImpl<Self> for FFT64Spqlios {
// vec_znx_big_from_small_impl: copies column `a_col` of `a` into column `res_col` of `res`,
// limb by limb, and zeroes any limbs of `res` beyond `a.size()`.
fn vec_znx_big_from_small_impl<R, A>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxToRef,
{
let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut();
// (add_dist) Panics unless ceil(log2(bound)) is below 64.
assert!(
(bound.log2().ceil() as i64) < 64,
"invalid bound: ceil(log2(bound))={} > 63",
(bound.log2().ceil() as i64)
);
let a: VecZnx<&[u8]> = a.to_ref();
// (add_dist) `limb` is ceil(k / basek) - 1 and `basek_rem` is (limb + 1) * basek - k; the
// latter is used as a left-shift amount below.
// NOTE(review): `k.div_ceil(basek) - 1` underflows `usize` when `k == 0` — confirm callers
// never pass `k == 0`.
let limb: usize = k.div_ceil(basek) - 1;
let basek_rem: usize = (limb + 1) * basek - k;
// (from_small) Debug builds only: both operands must report the same `n()`.
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), a.n());
}
// (add_dist) Both branches resample until |sample| <= bound, round to i64 and accumulate
// into the coefficient; the first branch additionally shifts left by `basek_rem`.
if basek_rem != 0 {
res.at_mut(res_col, limb).iter_mut().for_each(|x| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*x += (dist_f64.round() as i64) << basek_rem;
});
} else {
res.at_mut(res_col, limb).iter_mut().for_each(|x| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*x += dist_f64.round() as i64
});
// (from_small) Only the limbs present in both operands are copied.
let res_size: usize = res.size();
let a_size: usize = a.size();
let min_size: usize = res_size.min(a_size);
for j in 0..min_size {
znx_copy_ref(res.at_mut(res_col, j), a.at(a_col, j));
}
// (from_small) Remaining limbs of `res` (if `res` has more limbs than `a`) are cleared.
for j in min_size..res_size {
znx_zero_ref(res.at_mut(res_col, j));
}
}
}
// NOTE(review): this span is a diff rendering. It shows `VecZnxBigAddNormalImpl` (both the
// `FFT64` and `FFT64Spqlios` headers), then the `FFT64`-only `VecZnxBigFillDistF64Impl` and
// `VecZnxBigFillNormalImpl`, and finally — after the `vec_znx_big_fill_dist_f64` call — what
// looks like the replacement body of `add_normal_impl` — confirm against the real file.
unsafe impl VecZnxBigAddNormalImpl<Self> for FFT64 {
unsafe impl VecZnxBigAddNormalImpl<Self> for FFT64Spqlios {
// add_normal_impl: adds noise drawn with standard deviation `sigma` and cut-off `bound`
// to column `res_col` of `res`.
fn add_normal_impl<R: VecZnxBigToMut<Self>>(
module: &Module<Self>,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) {
// Delegating variant: forwards to `vec_znx_big_add_dist_f64` with a zero-mean
// `Normal` distribution; `unwrap()` panics if `Normal::new` rejects `sigma`.
module.vec_znx_big_add_dist_f64(
basek,
res,
res_col,
k,
source,
Normal::new(0.0, sigma).unwrap(),
bound,
);
}
}
unsafe impl VecZnxBigFillDistF64Impl<Self> for FFT64 {
// fill_dist_f64_impl: overwrites every coefficient of one limb of column `res_col` of `res`
// with a rounded sample of `dist` (assignment, not accumulation).
fn fill_dist_f64_impl<R: VecZnxBigToMut<Self>, D: Distribution<f64>>(
_module: &Module<Self>,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) {
let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut();
// Panics unless ceil(log2(bound)) is below 64.
assert!(
(bound.log2().ceil() as i64) < 64,
"invalid bound: ceil(log2(bound))={} > 63",
(bound.log2().ceil() as i64)
);
// `limb` is ceil(k / basek) - 1; `basek_rem` is (limb + 1) * basek - k and is used as a
// left-shift amount below.
// NOTE(review): `k.div_ceil(basek) - 1` underflows `usize` when `k == 0` — confirm callers
// never pass `k == 0`.
let limb: usize = k.div_ceil(basek) - 1;
let basek_rem: usize = (limb + 1) * basek - k;
// Both branches resample until |sample| <= bound and store the rounded value; the first
// branch additionally shifts it left by `basek_rem`.
if basek_rem != 0 {
res.at_mut(res_col, limb).iter_mut().for_each(|x| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*x = (dist_f64.round() as i64) << basek_rem;
});
} else {
res.at_mut(res_col, limb).iter_mut().for_each(|x| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*x = dist_f64.round() as i64
});
}
}
}
unsafe impl VecZnxBigFillNormalImpl<Self> for FFT64 {
// fill_normal_impl: forwards to `vec_znx_big_fill_dist_f64` with a zero-mean `Normal`
// distribution of standard deviation `sigma`.
fn fill_normal_impl<R: VecZnxBigToMut<Self>>(
module: &Module<Self>,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) {
module.vec_znx_big_fill_dist_f64(
basek,
res,
res_col,
k,
source,
Normal::new(0.0, sigma).unwrap(),
bound,
);
// The lines from here to the end mention `FFT64Spqlios` and `vec_znx_add_normal_ref`;
// presumably they are the new body of `add_normal_impl` rather than part of
// `fill_normal_impl` — confirm.
// The `VecZnxBig` view is repackaged field-by-field as a plain `VecZnx` so that the
// reference routine `vec_znx_add_normal_ref` can operate on it.
let res: VecZnxBig<&mut [u8], FFT64Spqlios> = res.to_mut();
let mut res_znx: VecZnx<&mut [u8]> = VecZnx {
data: res.data,
n: res.n,
cols: res.cols,
size: res.size,
max_size: res.max_size,
};
vec_znx_add_normal_ref(basek, &mut res_znx, res_col, k, sigma, bound, source);
}
}
unsafe impl VecZnxBigAddImpl<Self> for FFT64 {
unsafe impl VecZnxBigAddImpl<Self> for FFT64Spqlios {
/// Adds `a` to `b` and stores the result in `res`.
fn vec_znx_big_add_impl<R, A, B>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
@@ -199,7 +127,7 @@ unsafe impl VecZnxBigAddImpl<Self> for FFT64 {
}
}
unsafe impl VecZnxBigAddInplaceImpl<Self> for FFT64 {
unsafe impl VecZnxBigAddInplaceImpl<Self> for FFT64Spqlios {
/// Adds `a` to `res` and stores the result in `res`.
fn vec_znx_big_add_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
@@ -230,7 +158,7 @@ unsafe impl VecZnxBigAddInplaceImpl<Self> for FFT64 {
}
}
unsafe impl VecZnxBigAddSmallImpl<Self> for FFT64 {
unsafe impl VecZnxBigAddSmallImpl<Self> for FFT64Spqlios {
/// Adds `a` to `b` and stores the result on `c`.
fn vec_znx_big_add_small_impl<R, A, B>(
module: &Module<Self>,
@@ -272,7 +200,7 @@ unsafe impl VecZnxBigAddSmallImpl<Self> for FFT64 {
}
}
unsafe impl VecZnxBigAddSmallInplaceImpl<Self> for FFT64 {
unsafe impl VecZnxBigAddSmallInplaceImpl<Self> for FFT64Spqlios {
/// Adds `a` to `res` and stores the result in `res`.
fn vec_znx_big_add_small_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
@@ -303,7 +231,7 @@ unsafe impl VecZnxBigAddSmallInplaceImpl<Self> for FFT64 {
}
}
unsafe impl VecZnxBigSubImpl<Self> for FFT64 {
unsafe impl VecZnxBigSubImpl<Self> for FFT64Spqlios {
/// Subtracts `a` to `b` and stores the result on `c`.
fn vec_znx_big_sub_impl<R, A, B>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
@@ -338,7 +266,7 @@ unsafe impl VecZnxBigSubImpl<Self> for FFT64 {
}
}
unsafe impl VecZnxBigSubABInplaceImpl<Self> for FFT64 {
unsafe impl VecZnxBigSubABInplaceImpl<Self> for FFT64Spqlios {
/// Subtracts `a` from `b` and stores the result on `b`.
fn vec_znx_big_sub_ab_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
@@ -369,7 +297,7 @@ unsafe impl VecZnxBigSubABInplaceImpl<Self> for FFT64 {
}
}
unsafe impl VecZnxBigSubBAInplaceImpl<Self> for FFT64 {
unsafe impl VecZnxBigSubBAInplaceImpl<Self> for FFT64Spqlios {
/// Subtracts `b` from `a` and stores the result on `b`.
fn vec_znx_big_sub_ba_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
@@ -400,7 +328,7 @@ unsafe impl VecZnxBigSubBAInplaceImpl<Self> for FFT64 {
}
}
unsafe impl VecZnxBigSubSmallAImpl<Self> for FFT64 {
unsafe impl VecZnxBigSubSmallAImpl<Self> for FFT64Spqlios {
/// Subtracts `b` from `a` and stores the result on `c`.
fn vec_znx_big_sub_small_a_impl<R, A, B>(
module: &Module<Self>,
@@ -442,7 +370,7 @@ unsafe impl VecZnxBigSubSmallAImpl<Self> for FFT64 {
}
}
unsafe impl VecZnxBigSubSmallAInplaceImpl<Self> for FFT64 {
unsafe impl VecZnxBigSubSmallAInplaceImpl<Self> for FFT64Spqlios {
/// Subtracts `a` from `res` and stores the result on `res`.
fn vec_znx_big_sub_small_a_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
@@ -473,7 +401,7 @@ unsafe impl VecZnxBigSubSmallAInplaceImpl<Self> for FFT64 {
}
}
unsafe impl VecZnxBigSubSmallBImpl<Self> for FFT64 {
unsafe impl VecZnxBigSubSmallBImpl<Self> for FFT64Spqlios {
/// Subtracts `b` from `a` and stores the result on `c`.
fn vec_znx_big_sub_small_b_impl<R, A, B>(
module: &Module<Self>,
@@ -515,7 +443,7 @@ unsafe impl VecZnxBigSubSmallBImpl<Self> for FFT64 {
}
}
unsafe impl VecZnxBigSubSmallBInplaceImpl<Self> for FFT64 {
unsafe impl VecZnxBigSubSmallBInplaceImpl<Self> for FFT64Spqlios {
/// Subtracts `res` from `a` and stores the result on `res`.
fn vec_znx_big_sub_small_b_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
@@ -546,7 +474,29 @@ unsafe impl VecZnxBigSubSmallBInplaceImpl<Self> for FFT64 {
}
}
unsafe impl VecZnxBigNegateInplaceImpl<Self> for FFT64 {
unsafe impl VecZnxBigNegateImpl<Self> for FFT64Spqlios {
    /// Runs the spqlios `vec_znx_negate` routine with column `a_col` of `a` as the source
    /// and column `res_col` of `res` as the destination.
    ///
    /// Both operands are passed starting at limb 0, together with the limb count (`size()`)
    /// and stride (`sl()`) reported by their respective views.
    fn vec_znx_big_negate_impl<R, A>(
        module: &Module<Self>,
        res: &mut R,
        res_col: usize,
        a: &A,
        a_col: usize,
    ) where
        R: VecZnxBigToMut<Self>,
        A: VecZnxBigToRef<Self>,
    {
        let mut dst: VecZnxBig<&mut [u8], Self> = res.to_mut();
        let src: VecZnxBig<&[u8], Self> = a.to_ref();

        // Gather the FFI arguments in the same order the call consumes them.
        let module_ptr = module.ptr();
        let dst_ptr = dst.at_mut_ptr(res_col, 0);
        let dst_limbs = dst.size() as u64;
        let dst_stride = dst.sl() as u64;
        let src_ptr = src.at_ptr(a_col, 0);
        let src_limbs = src.size() as u64;
        let src_stride = src.sl() as u64;

        // SAFETY: both pointers are derived from the `dst`/`src` views, which stay borrowed
        // for the whole call, and the sizes/strides are the ones those same views report.
        // NOTE(review): assumes the FFI routine tolerates `dst` and `src` having different
        // limb counts — confirm against the spqlios contract.
        unsafe {
            vec_znx::vec_znx_negate(
                module_ptr,
                dst_ptr,
                dst_limbs,
                dst_stride,
                src_ptr,
                src_limbs,
                src_stride,
            )
        }
    }
}
unsafe impl VecZnxBigNegateInplaceImpl<Self> for FFT64Spqlios {
fn vec_znx_big_negate_inplace_impl<A>(module: &Module<Self>, a: &mut A, a_col: usize)
where
A: VecZnxBigToMut<Self>,
@@ -566,13 +516,13 @@ unsafe impl VecZnxBigNegateInplaceImpl<Self> for FFT64 {
}
}
// Scratch-size query for the normalization routine.
// NOTE(review): the two header lines below look like the pre- and post-rename forms of the
// same impl as rendered by the diff view (`FFT64` -> `FFT64Spqlios`) — confirm.
unsafe impl VecZnxBigNormalizeTmpBytesImpl<Self> for FFT64 {
unsafe impl VecZnxBigNormalizeTmpBytesImpl<Self> for FFT64Spqlios {
/// Returns the value reported by the spqlios FFI routine
/// `vec_znx_normalize_base2k_tmp_bytes` for this module, cast to `usize`.
fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module<Self>) -> usize {
// The FFI call only receives the module pointer obtained from `module.ptr()`.
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr()) as usize }
}
}
unsafe impl VecZnxBigNormalizeImpl<Self> for FFT64
unsafe impl VecZnxBigNormalizeImpl<Self> for FFT64Spqlios
where
Self: TakeSliceImpl<Self>,
{
@@ -613,7 +563,7 @@ where
}
}
unsafe impl VecZnxBigAutomorphismImpl<Self> for FFT64 {
unsafe impl VecZnxBigAutomorphismImpl<Self> for FFT64Spqlios {
/// Applies the automorphism X^i -> X^ik on `a` and stores the result in `res`.
fn vec_znx_big_automorphism_impl<R, A>(module: &Module<Self>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
@@ -642,10 +592,21 @@ unsafe impl VecZnxBigAutomorphismImpl<Self> for FFT64 {
}
}
unsafe impl VecZnxBigAutomorphismInplaceImpl<Self> for FFT64 {
// Scratch-size query paired with the in-place automorphism.
unsafe impl VecZnxBigAutomorphismInplaceTmpBytesImpl<Self> for FFT64Spqlios {
/// Always returns 0; the module argument is ignored.
fn vec_znx_big_automorphism_inplace_tmp_bytes_impl(_module: &Module<Self>) -> usize {
0
}
}
unsafe impl VecZnxBigAutomorphismInplaceImpl<Self> for FFT64Spqlios {
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`.
fn vec_znx_big_automorphism_inplace_impl<A>(module: &Module<Self>, k: i64, a: &mut A, a_col: usize)
where
fn vec_znx_big_automorphism_inplace_impl<A>(
module: &Module<Self>,
k: i64,
a: &mut A,
a_col: usize,
_scratch: &mut Scratch<Self>,
) where
A: VecZnxBigToMut<Self>,
{
let mut a: VecZnxBig<&mut [u8], Self> = a.to_mut();