refactoring of vec_znx

This commit is contained in:
Jean-Philippe Bossuat
2025-04-28 10:33:15 +02:00
parent 39bbe5b917
commit 2f9a1cf6d9
13 changed files with 1218 additions and 738 deletions

View File

@@ -3,8 +3,8 @@ use rand_distr::{Distribution, Normal};
use sampling::source::Source;
pub trait Sampling {
/// Fills the first `limbs` limbs with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\]
fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, limbs: usize, source: &mut Source);
/// Fills the first `size` size with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\]
fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, size: usize, source: &mut Source);
/// Adds vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\].
fn add_dist_f64<D: Distribution<f64>>(
@@ -32,11 +32,11 @@ pub trait Sampling {
}
impl<B: Backend> Sampling for Module<B> {
fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, limbs: usize, source: &mut Source) {
fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, size: usize, source: &mut Source) {
let base2k: u64 = 1 << log_base2k;
let mask: u64 = base2k - 1;
let base2k_half: i64 = (base2k >> 1) as i64;
(0..limbs).for_each(|j| {
(0..size).for_each(|j| {
a.at_poly_mut(col_i, j)
.iter_mut()
.for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half);
@@ -114,17 +114,17 @@ mod tests {
let n: usize = 4096;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let log_base2k: usize = 17;
let limbs: usize = 5;
let size: usize = 5;
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let zero: Vec<i64> = vec![0; n];
let one_12_sqrt: f64 = 0.28867513459481287;
(0..cols).for_each(|col_i| {
let mut a: VecZnx = VecZnx::new(&module, cols, limbs);
module.fill_uniform(log_base2k, &mut a, col_i, limbs, &mut source);
let mut a: VecZnx = VecZnx::new(&module, cols, size);
module.fill_uniform(log_base2k, &mut a, col_i, size, &mut source);
(0..cols).for_each(|col_j| {
if col_j != col_i {
(0..limbs).for_each(|limb_i| {
(0..size).for_each(|limb_i| {
assert_eq!(a.at_poly(col_j, limb_i), zero);
})
} else {
@@ -146,7 +146,7 @@ mod tests {
let module: Module<FFT64> = Module::<FFT64>::new(n);
let log_base2k: usize = 17;
let log_k: usize = 2 * 17;
let limbs: usize = 5;
let size: usize = 5;
let sigma: f64 = 3.2;
let bound: f64 = 6.0 * sigma;
let mut source: Source = Source::new([0u8; 32]);
@@ -154,11 +154,11 @@ mod tests {
let zero: Vec<i64> = vec![0; n];
let k_f64: f64 = (1u64 << log_k as u64) as f64;
(0..cols).for_each(|col_i| {
let mut a: VecZnx = VecZnx::new(&module, cols, limbs);
let mut a: VecZnx = VecZnx::new(&module, cols, size);
module.add_normal(log_base2k, &mut a, col_i, log_k, &mut source, sigma, bound);
(0..cols).for_each(|col_j| {
if col_j != col_i {
(0..limbs).for_each(|limb_i| {
(0..size).for_each(|limb_i| {
assert_eq!(a.at_poly(col_j, limb_i), zero);
})
} else {