everything compiles. Scratchpad not yet implemented

This commit is contained in:
Janmajaya Mall
2025-05-03 16:37:20 +05:30
parent 3ed6fa8ab5
commit ff8370e023
19 changed files with 919 additions and 504 deletions

View File

@@ -1,16 +1,24 @@
use crate::{Backend, Module, VecZnx, znx_base::ZnxLayout};
use crate::znx_base::ZnxViewMut;
use crate::{Backend, Module, VecZnx};
use rand_distr::{Distribution, Normal};
use sampling::source::Source;
pub trait Sampling {
/// Fills the first `size` size with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\]
fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, size: usize, source: &mut Source);
/// Adds vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\].
fn add_dist_f64<D: Distribution<f64>>(
fn fill_uniform<DataMut: AsMut<[u8]> + AsRef<[u8]>>(
&self,
log_base2k: usize,
a: &mut VecZnx,
a: &mut VecZnx<DataMut>,
col_i: usize,
size: usize,
source: &mut Source,
);
/// Adds vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\].
fn add_dist_f64<DataMut: AsMut<[u8]> + AsRef<[u8]>, D: Distribution<f64>>(
&self,
log_base2k: usize,
a: &mut VecZnx<DataMut>,
col_i: usize,
log_k: usize,
source: &mut Source,
@@ -19,10 +27,10 @@ pub trait Sampling {
);
/// Adds a discrete normal vector scaled by 2^{-log_k} with the provided standard deviation and bounded to \[-bound, bound\].
fn add_normal(
fn add_normal<DataMut: AsMut<[u8]> + AsRef<[u8]>>(
&self,
log_base2k: usize,
a: &mut VecZnx,
a: &mut VecZnx<DataMut>,
col_i: usize,
log_k: usize,
source: &mut Source,
@@ -32,22 +40,29 @@ pub trait Sampling {
}
impl<B: Backend> Sampling for Module<B> {
fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_a: usize, size: usize, source: &mut Source) {
fn fill_uniform<DataMut: AsMut<[u8]> + AsRef<[u8]>>(
&self,
log_base2k: usize,
a: &mut VecZnx<DataMut>,
col_i: usize,
size: usize,
source: &mut Source,
) {
let base2k: u64 = 1 << log_base2k;
let mask: u64 = base2k - 1;
let base2k_half: i64 = (base2k >> 1) as i64;
(0..size).for_each(|j| {
a.at_mut(col_a, j)
a.at_mut(col_i, j)
.iter_mut()
.for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half);
})
}
fn add_dist_f64<D: Distribution<f64>>(
fn add_dist_f64<DataMut: AsMut<[u8]> + AsRef<[u8]>, D: Distribution<f64>>(
&self,
log_base2k: usize,
a: &mut VecZnx,
col_a: usize,
a: &mut VecZnx<DataMut>,
col_i: usize,
log_k: usize,
source: &mut Source,
dist: D,
@@ -63,7 +78,7 @@ impl<B: Backend> Sampling for Module<B> {
let log_base2k_rem: usize = log_k % log_base2k;
if log_base2k_rem != 0 {
a.at_mut(col_a, limb).iter_mut().for_each(|a| {
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
@@ -71,7 +86,7 @@ impl<B: Backend> Sampling for Module<B> {
*a += (dist_f64.round() as i64) << log_base2k_rem;
});
} else {
a.at_mut(col_a, limb).iter_mut().for_each(|a| {
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
@@ -81,11 +96,11 @@ impl<B: Backend> Sampling for Module<B> {
}
}
fn add_normal(
fn add_normal<DataMut: AsMut<[u8]> + AsRef<[u8]>>(
&self,
log_base2k: usize,
a: &mut VecZnx,
col_a: usize,
a: &mut VecZnx<DataMut>,
col_i: usize,
log_k: usize,
source: &mut Source,
sigma: f64,
@@ -94,7 +109,7 @@ impl<B: Backend> Sampling for Module<B> {
self.add_dist_f64(
log_base2k,
a,
col_a,
col_i,
log_k,
source,
Normal::new(0.0, sigma).unwrap(),
@@ -106,7 +121,9 @@ impl<B: Backend> Sampling for Module<B> {
#[cfg(test)]
mod tests {
use super::Sampling;
use crate::{FFT64, Module, Stats, VecZnx, VecZnxOps, znx_base::ZnxLayout};
use crate::vec_znx_ops::*;
use crate::znx_base::*;
use crate::{FFT64, Module, Stats, VecZnx};
use sampling::source::Source;
#[test]
@@ -120,7 +137,7 @@ mod tests {
let zero: Vec<i64> = vec![0; n];
let one_12_sqrt: f64 = 0.28867513459481287;
(0..cols).for_each(|col_i| {
let mut a: VecZnx = module.new_vec_znx(cols, size);
let mut a: VecZnx<_> = module.new_vec_znx(cols, size);
module.fill_uniform(log_base2k, &mut a, col_i, size, &mut source);
(0..cols).for_each(|col_j| {
if col_j != col_i {
@@ -154,7 +171,7 @@ mod tests {
let zero: Vec<i64> = vec![0; n];
let k_f64: f64 = (1u64 << log_k as u64) as f64;
(0..cols).for_each(|col_i| {
let mut a: VecZnx = module.new_vec_znx(cols, size);
let mut a: VecZnx<_> = module.new_vec_znx(cols, size);
module.add_normal(log_base2k, &mut a, col_i, log_k, &mut source, sigma, bound);
(0..cols).for_each(|col_j| {
if col_j != col_i {