Ref. + AVX code & generic tests + benches (#85)

This commit is contained in:
Jean-Philippe Bossuat
2025-09-15 16:16:11 +02:00
committed by GitHub
parent 99b9e3e10e
commit 56dbd29c59
286 changed files with 27797 additions and 7270 deletions

View File

@@ -1,17 +1,14 @@
use poulpy_hal::{
api::TakeSlice,
layouts::{Scratch, Zn, ZnToMut, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut},
oep::{
TakeSliceImpl, ZnAddDistF64Impl, ZnAddNormalImpl, ZnFillDistF64Impl, ZnFillNormalImpl, ZnFillUniformImpl,
ZnNormalizeInplaceImpl,
},
oep::{TakeSliceImpl, ZnAddNormalImpl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl},
reference::zn::{zn_add_normal, zn_fill_normal, zn_fill_uniform},
source::Source,
};
use rand_distr::Normal;
use crate::cpu_spqlios::{FFT64, ffi::zn64};
use crate::cpu_spqlios::{FFT64Spqlios, ffi::zn64};
unsafe impl ZnNormalizeInplaceImpl<Self> for FFT64
unsafe impl ZnNormalizeInplaceImpl<Self> for FFT64Spqlios
where
Self: TakeSliceImpl<Self>,
{
@@ -39,113 +36,17 @@ where
}
}
unsafe impl ZnFillUniformImpl<Self> for FFT64 {
fn zn_fill_uniform_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
unsafe impl ZnFillUniformImpl<Self> for FFT64Spqlios {
fn zn_fill_uniform_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: ZnToMut,
{
let mut a: Zn<&mut [u8]> = res.to_mut();
let base2k: u64 = 1 << basek;
let mask: u64 = base2k - 1;
let base2k_half: i64 = (base2k >> 1) as i64;
(0..k.div_ceil(basek)).for_each(|j| {
a.at_mut(res_col, j)[..n]
.iter_mut()
.for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half);
})
zn_fill_uniform(n, basek, res, res_col, source);
}
}
unsafe impl ZnFillDistF64Impl<Self> for FFT64 {
fn zn_fill_dist_f64_impl<R, D: rand::prelude::Distribution<f64>>(
n: usize,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) where
R: ZnToMut,
{
let mut a: Zn<&mut [u8]> = res.to_mut();
assert!(
(bound.log2().ceil() as i64) < 64,
"invalid bound: ceil(log2(bound))={} > 63",
(bound.log2().ceil() as i64)
);
let limb: usize = k.div_ceil(basek) - 1;
let basek_rem: usize = (limb + 1) * basek - k;
if basek_rem != 0 {
a.at_mut(res_col, limb)[..n].iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*a = (dist_f64.round() as i64) << basek_rem;
});
} else {
a.at_mut(res_col, limb)[..n].iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*a = dist_f64.round() as i64
});
}
}
}
unsafe impl ZnAddDistF64Impl<Self> for FFT64 {
fn zn_add_dist_f64_impl<R, D: rand::prelude::Distribution<f64>>(
n: usize,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) where
R: ZnToMut,
{
let mut a: Zn<&mut [u8]> = res.to_mut();
assert!(
(bound.log2().ceil() as i64) < 64,
"invalid bound: ceil(log2(bound))={} > 63",
(bound.log2().ceil() as i64)
);
let limb: usize = k.div_ceil(basek) - 1;
let basek_rem: usize = (limb + 1) * basek - k;
if basek_rem != 0 {
a.at_mut(res_col, limb)[..n].iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*a += (dist_f64.round() as i64) << basek_rem;
});
} else {
a.at_mut(res_col, limb)[..n].iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*a += dist_f64.round() as i64
});
}
}
}
unsafe impl ZnFillNormalImpl<Self> for FFT64
where
Self: ZnFillDistF64Impl<Self>,
{
unsafe impl ZnFillNormalImpl<Self> for FFT64Spqlios {
#[allow(clippy::too_many_arguments)]
fn zn_fill_normal_impl<R>(
n: usize,
basek: usize,
@@ -158,23 +59,12 @@ where
) where
R: ZnToMut,
{
Self::zn_fill_dist_f64_impl(
n,
basek,
res,
res_col,
k,
source,
Normal::new(0.0, sigma).unwrap(),
bound,
);
zn_fill_normal(n, basek, res, res_col, k, source, sigma, bound);
}
}
unsafe impl ZnAddNormalImpl<Self> for FFT64
where
Self: ZnAddDistF64Impl<Self>,
{
unsafe impl ZnAddNormalImpl<Self> for FFT64Spqlios {
#[allow(clippy::too_many_arguments)]
fn zn_add_normal_impl<R>(
n: usize,
basek: usize,
@@ -187,15 +77,6 @@ where
) where
R: ZnToMut,
{
Self::zn_add_dist_f64_impl(
n,
basek,
res,
res_col,
k,
source,
Normal::new(0.0, sigma).unwrap(),
bound,
);
zn_add_normal(n, basek, res, res_col, k, source, sigma, bound);
}
}