This commit is contained in:
Jean-Philippe Bossuat
2025-04-25 15:24:09 +02:00
parent 90b34e171d
commit 2a96f89047
16 changed files with 864 additions and 895 deletions

View File

@@ -1,16 +1,17 @@
use crate::{Infos, Module, VecZnx};
use crate::{Backend, Infos, Module, VecZnx};
use rand_distr::{Distribution, Normal};
use sampling::source::Source;
pub trait Sampling {
/// Fills the first `cols` cols with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\]
fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, cols: usize, source: &mut Source);
/// Fills the first `limbs` limbs with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\]
fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, limbs: usize, source: &mut Source);
/// Adds vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\].
fn add_dist_f64<D: Distribution<f64>>(
&self,
log_base2k: usize,
a: &mut VecZnx,
col_i: usize,
log_k: usize,
source: &mut Source,
dist: D,
@@ -18,24 +19,35 @@ pub trait Sampling {
);
/// Adds a discrete normal vector scaled by 2^{-log_k} with the provided standard deviation and bounded to \[-bound, bound\].
fn add_normal(&self, log_base2k: usize, a: &mut VecZnx, log_k: usize, source: &mut Source, sigma: f64, bound: f64);
fn add_normal(
&self,
log_base2k: usize,
a: &mut VecZnx,
col_i: usize,
log_k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
);
}
impl Sampling for Module {
fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, cols: usize, source: &mut Source) {
impl<B: Backend> Sampling for Module<B> {
fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, limbs: usize, source: &mut Source) {
let base2k: u64 = 1 << log_base2k;
let mask: u64 = base2k - 1;
let base2k_half: i64 = (base2k >> 1) as i64;
let size: usize = a.n() * cols;
a.raw_mut()[..size]
.iter_mut()
.for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half);
(0..limbs).for_each(|j| {
a.at_poly_mut(col_i, j)
.iter_mut()
.for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half);
})
}
fn add_dist_f64<D: Distribution<f64>>(
&self,
log_base2k: usize,
a: &mut VecZnx,
col_i: usize,
log_k: usize,
source: &mut Source,
dist: D,
@@ -50,28 +62,42 @@ impl Sampling for Module {
let log_base2k_rem: usize = log_k % log_base2k;
if log_base2k_rem != 0 {
a.at_mut(a.cols() - 1).iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*a += (dist_f64.round() as i64) << log_base2k_rem;
});
a.at_poly_mut(col_i, a.limbs() - 1)
.iter_mut()
.for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*a += (dist_f64.round() as i64) << log_base2k_rem;
});
} else {
a.at_mut(a.cols() - 1).iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*a += dist_f64.round() as i64
});
a.at_poly_mut(col_i, a.limbs() - 1)
.iter_mut()
.for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*a += dist_f64.round() as i64
});
}
}
fn add_normal(&self, log_base2k: usize, a: &mut VecZnx, log_k: usize, source: &mut Source, sigma: f64, bound: f64) {
fn add_normal(
&self,
log_base2k: usize,
a: &mut VecZnx,
col_i: usize,
log_k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) {
self.add_dist_f64(
log_base2k,
a,
col_i,
log_k,
source,
Normal::new(0.0, sigma).unwrap(),