mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
Ref. + AVX code & generic tests + benches (#85)
This commit is contained in:
committed by
GitHub
parent
99b9e3e10e
commit
56dbd29c59
53
poulpy-hal/src/reference/znx/sampling.rs
Normal file
53
poulpy-hal/src/reference/znx/sampling.rs
Normal file
@@ -0,0 +1,53 @@
|
||||
use rand_distr::{Distribution, Normal};
|
||||
|
||||
use crate::source::Source;
|
||||
|
||||
pub fn znx_fill_uniform_ref(basek: usize, res: &mut [i64], source: &mut Source) {
|
||||
let pow2k: u64 = 1 << basek;
|
||||
let mask: u64 = pow2k - 1;
|
||||
let pow2k_half: i64 = (pow2k >> 1) as i64;
|
||||
res.iter_mut()
|
||||
.for_each(|xi| *xi = (source.next_u64n(pow2k, mask) as i64) - pow2k_half)
|
||||
}
|
||||
|
||||
pub fn znx_fill_dist_f64_ref<D: rand::prelude::Distribution<f64>>(res: &mut [i64], dist: D, bound: f64, source: &mut Source) {
|
||||
res.iter_mut().for_each(|xi| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*xi = dist_f64.round() as i64
|
||||
})
|
||||
}
|
||||
|
||||
pub fn znx_add_dist_f64_ref<D: rand::prelude::Distribution<f64>>(res: &mut [i64], dist: D, bound: f64, source: &mut Source) {
|
||||
res.iter_mut().for_each(|xi| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*xi += dist_f64.round() as i64
|
||||
})
|
||||
}
|
||||
|
||||
pub fn znx_fill_normal_f64_ref(res: &mut [i64], sigma: f64, bound: f64, source: &mut Source) {
|
||||
let normal: Normal<f64> = Normal::new(0.0, sigma).unwrap();
|
||||
res.iter_mut().for_each(|xi| {
|
||||
let mut dist_f64: f64 = normal.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = normal.sample(source)
|
||||
}
|
||||
*xi = dist_f64.round() as i64
|
||||
})
|
||||
}
|
||||
|
||||
pub fn znx_add_normal_f64_ref(res: &mut [i64], sigma: f64, bound: f64, source: &mut Source) {
|
||||
let normal: Normal<f64> = Normal::new(0.0, sigma).unwrap();
|
||||
res.iter_mut().for_each(|xi| {
|
||||
let mut dist_f64: f64 = normal.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = normal.sample(source)
|
||||
}
|
||||
*xi += dist_f64.round() as i64
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user