Improve noise stats functionality

This commit is contained in:
Pro7ech
2025-11-10 17:38:52 +01:00
parent e7bf8e9307
commit af45595848
15 changed files with 58 additions and 33 deletions

View File

@@ -18,6 +18,7 @@ pub use module::*;
pub use scalar_znx::*;
pub use scratch::*;
pub use serialization::*;
pub use stats::*;
pub use svp_ppol::*;
pub use vec_znx::*;
pub use vec_znx_big::*;

View File

@@ -6,15 +6,33 @@ use rug::{
use crate::layouts::{Backend, DataRef, VecZnx, VecZnxBig, VecZnxBigToRef, ZnxInfos};
pub struct Stats {
max: f64,
std: f64,
}
impl Stats {
pub fn max(&self) -> f64 {
self.max
}
pub fn std(&self) -> f64 {
self.std
}
}
impl<D: DataRef> VecZnx<D> {
pub fn std(&self, base2k: usize, col: usize) -> f64 {
pub fn stats(&self, base2k: usize, col: usize) -> Stats {
let prec: u32 = (self.size() * base2k) as u32;
let mut data: Vec<Float> = (0..self.n()).map(|_| Float::with_val(prec, 0)).collect();
self.decode_vec_float(base2k, col, &mut data);
// std = sqrt(sum((xi - avg)^2) / n)
let mut avg: Float = Float::with_val(prec, 0);
let mut max: Float = Float::with_val(prec, 0);
data.iter().for_each(|x| {
avg.add_assign_round(x, Round::Nearest);
max.max_mut(&Float::with_val(53, x.abs_ref()));
});
avg.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest);
data.iter_mut().for_each(|x| {
@@ -24,12 +42,15 @@ impl<D: DataRef> VecZnx<D> {
data.iter().for_each(|x| std += x * x);
std.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest);
std = std.sqrt();
std.to_f64()
Stats {
std: std.to_f64(),
max: max.to_f64(),
}
}
}
impl<D: DataRef, B: Backend + Backend<ScalarBig = i64>> VecZnxBig<D, B> {
pub fn std(&self, base2k: usize, col: usize) -> f64 {
pub fn stats(&self, base2k: usize, col: usize) -> Stats {
let self_ref: VecZnxBig<&[u8], B> = self.to_ref();
let znx: VecZnx<&[u8]> = VecZnx {
data: self_ref.data,
@@ -38,6 +59,6 @@ impl<D: DataRef, B: Backend + Backend<ScalarBig = i64>> VecZnxBig<D, B> {
size: self_ref.size,
max_size: self_ref.max_size,
};
znx.std(base2k, col)
znx.stats(base2k, col)
}
}

View File

@@ -324,7 +324,7 @@ where
assert_eq!(a.at(col_j, limb_i), zero);
})
} else {
let std: f64 = a.std(base2k, col_i) * k_f64;
let std: f64 = a.stats(base2k, col_i).std() * k_f64;
assert!(
(std - sigma * sqrt2).abs() < 0.1,
"std={} ~!= {}",

View File

@@ -662,8 +662,8 @@ mod tests {
vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut res_ref, j, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut res_test, j, &mut carry);
assert!(res_ref.std(base2k, j).log2() - (k as f64) <= (k * base2k) as f64);
assert!(res_test.std(base2k, j).log2() - (k as f64) <= (k * base2k) as f64);
assert!(res_ref.stats(base2k, j).std().log2() - (k as f64) <= (k * base2k) as f64);
assert!(res_test.stats(base2k, j).std().log2() - (k as f64) <= (k * base2k) as f64);
}
}
}

View File

@@ -717,7 +717,7 @@ where
assert_eq!(a.at(col_j, limb_i), zero);
})
} else {
let std: f64 = a.std(base2k, col_i);
let std: f64 = a.stats(base2k, col_i).std();
assert!(
(std - one_12_sqrt).abs() < 0.01,
"std={std} ~!= {one_12_sqrt}",
@@ -750,7 +750,7 @@ where
assert_eq!(a.at(col_j, limb_i), zero);
})
} else {
let std: f64 = a.std(base2k, col_i) * k_f64;
let std: f64 = a.stats(base2k, col_i).std() * k_f64;
assert!((std - sigma).abs() < 0.1, "std={std} ~!= {sigma}");
}
})
@@ -782,7 +782,7 @@ where
assert_eq!(a.at(col_j, limb_i), zero);
})
} else {
let std: f64 = a.std(base2k, col_i) * k_f64;
let std: f64 = a.stats(base2k, col_i).std() * k_f64;
assert!(
(std - sigma * sqrt2).abs() < 0.1,
"std={std} ~!= {}",