rework as discussed

This commit is contained in:
Jean-Philippe Bossuat
2025-05-05 17:35:35 +02:00
parent bd105497fd
commit ffa363804b
16 changed files with 1154 additions and 1153 deletions

View File

@@ -1,53 +1,47 @@
use crate::znx_base::ZnxViewMut;
use crate::{Backend, Module, VecZnx};
use crate::{Backend, Module, VecZnx, VecZnxToMut};
use rand_distr::{Distribution, Normal};
use sampling::source::Source;
pub trait Sampling {
/// Fills the first `size` size with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\]
fn fill_uniform<DataMut: AsMut<[u8]> + AsRef<[u8]>>(
&self,
log_base2k: usize,
a: &mut VecZnx<DataMut>,
col_i: usize,
size: usize,
source: &mut Source,
);
fn fill_uniform<A>(&self, log_base2k: usize, a: &mut A, col_i: usize, size: usize, source: &mut Source)
where
A: VecZnxToMut;
/// Adds vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\].
fn add_dist_f64<DataMut: AsMut<[u8]> + AsRef<[u8]>, D: Distribution<f64>>(
fn add_dist_f64<A, D: Distribution<f64>>(
&self,
log_base2k: usize,
a: &mut VecZnx<DataMut>,
a: &mut A,
col_i: usize,
log_k: usize,
source: &mut Source,
dist: D,
bound: f64,
);
) where
A: VecZnxToMut;
/// Adds a discrete normal vector scaled by 2^{-log_k} with the provided standard deviation and bounded to \[-bound, bound\].
fn add_normal<DataMut: AsMut<[u8]> + AsRef<[u8]>>(
fn add_normal<A>(
&self,
log_base2k: usize,
a: &mut VecZnx<DataMut>,
a: &mut A,
col_i: usize,
log_k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
);
) where
A: VecZnxToMut;
}
impl<B: Backend> Sampling for Module<B> {
fn fill_uniform<DataMut: AsMut<[u8]> + AsRef<[u8]>>(
&self,
log_base2k: usize,
a: &mut VecZnx<DataMut>,
col_i: usize,
size: usize,
source: &mut Source,
) {
fn fill_uniform<A>(&self, log_base2k: usize, a: &mut A, col_i: usize, size: usize, source: &mut Source)
where
A: VecZnxToMut,
{
let mut a: VecZnx<&mut [u8]> = a.to_mut();
let base2k: u64 = 1 << log_base2k;
let mask: u64 = base2k - 1;
let base2k_half: i64 = (base2k >> 1) as i64;
@@ -58,16 +52,19 @@ impl<B: Backend> Sampling for Module<B> {
})
}
fn add_dist_f64<DataMut: AsMut<[u8]> + AsRef<[u8]>, D: Distribution<f64>>(
fn add_dist_f64<A, D: Distribution<f64>>(
&self,
log_base2k: usize,
a: &mut VecZnx<DataMut>,
a: &mut A,
col_i: usize,
log_k: usize,
source: &mut Source,
dist: D,
bound: f64,
) {
) where
A: VecZnxToMut,
{
let mut a: VecZnx<&mut [u8]> = a.to_mut();
assert!(
(bound.log2().ceil() as i64) < 64,
"invalid bound: ceil(log2(bound))={} > 63",
@@ -96,16 +93,10 @@ impl<B: Backend> Sampling for Module<B> {
}
}
fn add_normal<DataMut: AsMut<[u8]> + AsRef<[u8]>>(
&self,
log_base2k: usize,
a: &mut VecZnx<DataMut>,
col_i: usize,
log_k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) {
fn add_normal<A>(&self, log_base2k: usize, a: &mut A, col_i: usize, log_k: usize, source: &mut Source, sigma: f64, bound: f64)
where
A: VecZnxToMut,
{
self.add_dist_f64(
log_base2k,
a,