mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 21:26:41 +01:00
Ref. + AVX code & generic tests + benches (#85)
This commit is contained in:
committed by
GitHub
parent
99b9e3e10e
commit
56dbd29c59
@@ -1,17 +1,14 @@
|
||||
use poulpy_hal::{
|
||||
api::TakeSlice,
|
||||
layouts::{Scratch, Zn, ZnToMut, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut},
|
||||
oep::{
|
||||
TakeSliceImpl, ZnAddDistF64Impl, ZnAddNormalImpl, ZnFillDistF64Impl, ZnFillNormalImpl, ZnFillUniformImpl,
|
||||
ZnNormalizeInplaceImpl,
|
||||
},
|
||||
oep::{TakeSliceImpl, ZnAddNormalImpl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl},
|
||||
reference::zn::{zn_add_normal, zn_fill_normal, zn_fill_uniform},
|
||||
source::Source,
|
||||
};
|
||||
use rand_distr::Normal;
|
||||
|
||||
use crate::cpu_spqlios::{FFT64, ffi::zn64};
|
||||
use crate::cpu_spqlios::{FFT64Spqlios, ffi::zn64};
|
||||
|
||||
unsafe impl ZnNormalizeInplaceImpl<Self> for FFT64
|
||||
unsafe impl ZnNormalizeInplaceImpl<Self> for FFT64Spqlios
|
||||
where
|
||||
Self: TakeSliceImpl<Self>,
|
||||
{
|
||||
@@ -39,113 +36,17 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl ZnFillUniformImpl<Self> for FFT64 {
|
||||
fn zn_fill_uniform_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
|
||||
unsafe impl ZnFillUniformImpl<Self> for FFT64Spqlios {
|
||||
fn zn_fill_uniform_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
|
||||
where
|
||||
R: ZnToMut,
|
||||
{
|
||||
let mut a: Zn<&mut [u8]> = res.to_mut();
|
||||
let base2k: u64 = 1 << basek;
|
||||
let mask: u64 = base2k - 1;
|
||||
let base2k_half: i64 = (base2k >> 1) as i64;
|
||||
(0..k.div_ceil(basek)).for_each(|j| {
|
||||
a.at_mut(res_col, j)[..n]
|
||||
.iter_mut()
|
||||
.for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half);
|
||||
})
|
||||
zn_fill_uniform(n, basek, res, res_col, source);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl ZnFillDistF64Impl<Self> for FFT64 {
|
||||
fn zn_fill_dist_f64_impl<R, D: rand::prelude::Distribution<f64>>(
|
||||
n: usize,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: ZnToMut,
|
||||
{
|
||||
let mut a: Zn<&mut [u8]> = res.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = k.div_ceil(basek) - 1;
|
||||
let basek_rem: usize = (limb + 1) * basek - k;
|
||||
|
||||
if basek_rem != 0 {
|
||||
a.at_mut(res_col, limb)[..n].iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a = (dist_f64.round() as i64) << basek_rem;
|
||||
});
|
||||
} else {
|
||||
a.at_mut(res_col, limb)[..n].iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a = dist_f64.round() as i64
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl ZnAddDistF64Impl<Self> for FFT64 {
|
||||
fn zn_add_dist_f64_impl<R, D: rand::prelude::Distribution<f64>>(
|
||||
n: usize,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: ZnToMut,
|
||||
{
|
||||
let mut a: Zn<&mut [u8]> = res.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = k.div_ceil(basek) - 1;
|
||||
let basek_rem: usize = (limb + 1) * basek - k;
|
||||
|
||||
if basek_rem != 0 {
|
||||
a.at_mut(res_col, limb)[..n].iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a += (dist_f64.round() as i64) << basek_rem;
|
||||
});
|
||||
} else {
|
||||
a.at_mut(res_col, limb)[..n].iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a += dist_f64.round() as i64
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl ZnFillNormalImpl<Self> for FFT64
|
||||
where
|
||||
Self: ZnFillDistF64Impl<Self>,
|
||||
{
|
||||
unsafe impl ZnFillNormalImpl<Self> for FFT64Spqlios {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn zn_fill_normal_impl<R>(
|
||||
n: usize,
|
||||
basek: usize,
|
||||
@@ -158,23 +59,12 @@ where
|
||||
) where
|
||||
R: ZnToMut,
|
||||
{
|
||||
Self::zn_fill_dist_f64_impl(
|
||||
n,
|
||||
basek,
|
||||
res,
|
||||
res_col,
|
||||
k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
zn_fill_normal(n, basek, res, res_col, k, source, sigma, bound);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl ZnAddNormalImpl<Self> for FFT64
|
||||
where
|
||||
Self: ZnAddDistF64Impl<Self>,
|
||||
{
|
||||
unsafe impl ZnAddNormalImpl<Self> for FFT64Spqlios {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn zn_add_normal_impl<R>(
|
||||
n: usize,
|
||||
basek: usize,
|
||||
@@ -187,15 +77,6 @@ where
|
||||
) where
|
||||
R: ZnToMut,
|
||||
{
|
||||
Self::zn_add_dist_f64_impl(
|
||||
n,
|
||||
basek,
|
||||
res,
|
||||
res_col,
|
||||
k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
zn_add_normal(n, basek, res, res_col, k, source, sigma, bound);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user