From 39bbe5b91704469c7956ed3adaad0dd8294113ed Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Mon, 28 Apr 2025 09:02:42 +0200 Subject: [PATCH] added tests for sampling (and indirectly stats) --- base2k/.vscode/settings.json | 8 +++ base2k/src/encoding.rs | 12 ++--- base2k/src/sampling.rs | 102 ++++++++++++++++++++++++++++------- base2k/src/stats.rs | 13 +++-- base2k/src/vec_znx_big.rs | 2 +- 5 files changed, 107 insertions(+), 30 deletions(-) create mode 100644 base2k/.vscode/settings.json diff --git a/base2k/.vscode/settings.json b/base2k/.vscode/settings.json new file mode 100644 index 0000000..eecbcdc --- /dev/null +++ b/base2k/.vscode/settings.json @@ -0,0 +1,8 @@ +{ + "github.copilot.enable": { + "*": false, + "plaintext": false, + "markdown": false, + "scminput": false + } +} \ No newline at end of file diff --git a/base2k/src/encoding.rs b/base2k/src/encoding.rs index 6034b95..980dab4 100644 --- a/base2k/src/encoding.rs +++ b/base2k/src/encoding.rs @@ -271,9 +271,9 @@ mod tests { let n: usize = 8; let module: Module = Module::::new(n); let log_base2k: usize = 17; - let cols: usize = 5; - let log_k: usize = cols * log_base2k - 5; - let mut a: VecZnx = VecZnx::new(&module, 2, cols); + let limbs: usize = 5; + let log_k: usize = limbs * log_base2k - 5; + let mut a: VecZnx = VecZnx::new(&module, 2, limbs); let mut source: Source = Source::new([0u8; 32]); let raw: &mut [i64] = a.raw_mut(); raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); @@ -293,9 +293,9 @@ mod tests { let n: usize = 8; let module: Module = Module::::new(n); let log_base2k: usize = 17; - let cols: usize = 5; - let log_k: usize = cols * log_base2k - 5; - let mut a: VecZnx = VecZnx::new(&module, 2, cols); + let limbs: usize = 5; + let log_k: usize = limbs * log_base2k - 5; + let mut a: VecZnx = VecZnx::new(&module, 2, limbs); let mut source = Source::new([0u8; 32]); let raw: &mut [i64] = a.raw_mut(); raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); diff --git a/base2k/src/sampling.rs b/base2k/src/sampling.rs index c415b80..80d174c 100644 --- a/base2k/src/sampling.rs +++ b/base2k/src/sampling.rs @@ -1,4 +1,4 @@ -use crate::{Backend, Module, VecZnx, ZnxInfos, ZnxLayout}; +use crate::{Backend, Module, VecZnx, ZnxLayout}; use rand_distr::{Distribution, Normal}; use sampling::source::Source; @@ -59,28 +59,25 @@ impl Sampling for Module { (bound.log2().ceil() as i64) ); + let limb: usize = (log_k + log_base2k - 1) / log_base2k - 1; let log_base2k_rem: usize = log_k % log_base2k; if log_base2k_rem != 0 { - a.at_poly_mut(col_i, a.limbs() - 1) - .iter_mut() - .for_each(|a| { - let mut dist_f64: f64 = dist.sample(source); - while dist_f64.abs() > bound { - dist_f64 = dist.sample(source) - } - *a += (dist_f64.round() as i64) << log_base2k_rem; - }); + a.at_poly_mut(col_i, limb).iter_mut().for_each(|a| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *a += (dist_f64.round() as i64) << log_base2k_rem; + }); } else { - a.at_poly_mut(col_i, a.limbs() - 1) - .iter_mut() - .for_each(|a| { - let mut dist_f64: f64 = dist.sample(source); - while dist_f64.abs() > bound { - dist_f64 = dist.sample(source) - } - *a += dist_f64.round() as i64 - }); + a.at_poly_mut(col_i, limb).iter_mut().for_each(|a| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *a += dist_f64.round() as i64 + }); } } @@ -105,3 +102,70 @@ impl Sampling for Module { ); } } + +#[cfg(test)] +mod tests { + use super::Sampling; + use crate::{FFT64, Module, Stats, VecZnx, ZnxBase, ZnxLayout}; + use sampling::source::Source; + + #[test] + fn fill_uniform() { + let n: usize = 4096; + let module: Module = Module::::new(n); + let log_base2k: usize = 17; + let limbs: usize = 5; + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + let zero: Vec = vec![0; n]; + let one_12_sqrt: f64 = 0.28867513459481287; + (0..cols).for_each(|col_i| { + let mut a: VecZnx = VecZnx::new(&module, cols, limbs); + module.fill_uniform(log_base2k, &mut a, col_i, limbs, &mut source); + (0..cols).for_each(|col_j| { + if col_j != col_i { + (0..limbs).for_each(|limb_i| { + assert_eq!(a.at_poly(col_j, limb_i), zero); + }) + } else { + let std: f64 = a.std(col_i, log_base2k); + assert!( + (std - one_12_sqrt).abs() < 0.01, + "std={} ~!= {}", + std, + one_12_sqrt + ); + } + }) + }); + } + + #[test] + fn add_normal() { + let n: usize = 4096; + let module: Module = Module::::new(n); + let log_base2k: usize = 17; + let log_k: usize = 2 * 17; + let limbs: usize = 5; + let sigma: f64 = 3.2; + let bound: f64 = 6.0 * sigma; + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + let zero: Vec = vec![0; n]; + let k_f64: f64 = (1u64 << log_k as u64) as f64; + (0..cols).for_each(|col_i| { + let mut a: VecZnx = VecZnx::new(&module, cols, limbs); + module.add_normal(log_base2k, &mut a, col_i, log_k, &mut source, sigma, bound); + (0..cols).for_each(|col_j| { + if col_j != col_i { + (0..limbs).for_each(|limb_i| { + assert_eq!(a.at_poly(col_j, limb_i), zero); + }) + } else { + let std: f64 = a.std(col_i, log_base2k) * k_f64; + assert!((std - sigma).abs() < 0.1, "std={} ~!= {}", std, sigma); + } + }) + }); + } +} diff --git a/base2k/src/stats.rs b/base2k/src/stats.rs index 44e441f..7fcf7c3 100644 --- a/base2k/src/stats.rs +++ b/base2k/src/stats.rs @@ -3,11 +3,16 @@ use rug::Float; use rug::float::Round; use rug::ops::{AddAssignRound, DivAssignRound, SubAssignRound}; -impl VecZnx { - pub fn std(&self, poly_idx: usize, log_base2k: usize) -> f64 { - let prec: u32 = (self.cols() * log_base2k) as u32; +pub trait Stats { + /// Returns the standard devaition of the i-th polynomial. + fn std(&self, col_i: usize, log_base2k: usize) -> f64; +} + +impl Stats for VecZnx { + fn std(&self, col_i: usize, log_base2k: usize) -> f64 { + let prec: u32 = (self.limbs() * log_base2k) as u32; let mut data: Vec = (0..self.n()).map(|_| Float::with_val(prec, 0)).collect(); - self.decode_vec_float(poly_idx, log_base2k, &mut data); + self.decode_vec_float(col_i, log_base2k, &mut data); // std = sqrt(sum((xi - avg)^2) / n) let mut avg: Float = Float::with_val(prec, 0); data.iter().for_each(|x| { diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index 7a8cc48..8c67a8d 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,5 +1,5 @@ use crate::ffi::vec_znx_big::{self, vec_znx_big_t}; -use crate::{Backend, FFT64, Module, VecZnx, VecZnxDft, ZnxBase, ZnxInfos, ZnxLayout, alloc_aligned, assert_alignement}; +use crate::{Backend, FFT64, Module, VecZnx, ZnxBase, ZnxInfos, ZnxLayout, alloc_aligned, assert_alignement}; use std::marker::PhantomData; pub struct VecZnxBig {