Improve noise stats functionality

This commit is contained in:
Pro7ech
2025-11-10 17:38:52 +01:00
parent e7bf8e9307
commit af45595848
15 changed files with 58 additions and 33 deletions

View File

@@ -85,7 +85,7 @@ fn main() {
module.glwe_sub_inplace(&mut pt_want, &pt_have); module.glwe_sub_inplace(&mut pt_want, &pt_have);
// Ideal vs. actual noise // Ideal vs. actual noise
let noise_have: f64 = pt_want.data.std(base2k.into(), 0) * (ct.k().as_u32() as f64).exp2(); let noise_have: f64 = pt_want.data.stats(base2k.into(), 0).std() * (ct.k().as_u32() as f64).exp2();
let noise_want: f64 = SIGMA; let noise_want: f64 = SIGMA;
// Check // Check

View File

@@ -60,7 +60,7 @@ where
self.vec_znx_sub_scalar_inplace(&mut pt.data, 0, (dsize - 1) + row_i * dsize, pt_want, col_i); self.vec_znx_sub_scalar_inplace(&mut pt.data, 0, (dsize - 1) + row_i * dsize, pt_want, col_i);
let noise_have: f64 = pt.data.std(base2k, 0).log2(); let noise_have: f64 = pt.data.stats(base2k, 0).std().log2();
println!("noise_have: {noise_have}"); println!("noise_have: {noise_have}");

View File

@@ -107,7 +107,7 @@ where
self.vec_znx_sub_inplace(&mut pt_have.data, 0, &pt.data, 0); self.vec_znx_sub_inplace(&mut pt_have.data, 0, &pt.data, 0);
let std_pt: f64 = pt_have.data.std(base2k, 0).log2(); let std_pt: f64 = pt_have.data.stats(base2k, 0).std().log2();
let noise: f64 = max_noise(col_j); let noise: f64 = max_noise(col_j);
assert!(std_pt <= noise, "{std_pt} > {noise}"); assert!(std_pt <= noise, "{std_pt} > {noise}");
@@ -165,7 +165,7 @@ where
self.vec_znx_sub_inplace(&mut pt_have.data, 0, &pt.data, 0); self.vec_znx_sub_inplace(&mut pt_have.data, 0, &pt.data, 0);
let std_pt: f64 = pt_have.data.std(base2k, 0).log2(); let std_pt: f64 = pt_have.data.stats(base2k, 0).std().log2();
println!("col: {col_j} row: {row_i}: {std_pt}"); println!("col: {col_j} row: {row_i}: {std_pt}");
pt.data.zero(); pt.data.zero();
} }

View File

@@ -1,6 +1,6 @@
use poulpy_hal::{ use poulpy_hal::{
api::{ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxNormalizeInplace, VecZnxSubInplace}, api::{ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxNormalizeInplace, VecZnxSubInplace},
layouts::{Backend, DataRef, Module, Scratch, ScratchOwned}, layouts::{Backend, DataRef, Module, Scratch, ScratchOwned, Stats},
}; };
use crate::{ use crate::{
@@ -10,7 +10,7 @@ use crate::{
}; };
impl<D: DataRef> GLWE<D> { impl<D: DataRef> GLWE<D> {
pub fn noise<M, S, P, BE: Backend>(&self, module: &M, sk_prepared: &S, pt_want: &P, scratch: &mut Scratch<BE>) -> f64 pub fn noise<M, S, P, BE: Backend>(&self, module: &M, sk_prepared: &S, pt_want: &P, scratch: &mut Scratch<BE>) -> Stats
where where
M: GLWENoise<BE>, M: GLWENoise<BE>,
S: GLWESecretPreparedToRef<BE>, S: GLWESecretPreparedToRef<BE>,
@@ -30,7 +30,7 @@ impl<D: DataRef> GLWE<D> {
} }
pub trait GLWENoise<BE: Backend> { pub trait GLWENoise<BE: Backend> {
fn glwe_noise<R, S, P>(&self, res: &R, sk_prepared: &S, pt_want: &P, scratch: &mut Scratch<BE>) -> f64 fn glwe_noise<R, S, P>(&self, res: &R, sk_prepared: &S, pt_want: &P, scratch: &mut Scratch<BE>) -> Stats
where where
R: GLWEToRef, R: GLWEToRef,
S: GLWESecretPreparedToRef<BE>, S: GLWESecretPreparedToRef<BE>,
@@ -49,7 +49,7 @@ where
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>, ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
fn glwe_noise<R, S, P>(&self, res: &R, sk_prepared: &S, pt_want: &P, scratch: &mut Scratch<BE>) -> f64 fn glwe_noise<R, S, P>(&self, res: &R, sk_prepared: &S, pt_want: &P, scratch: &mut Scratch<BE>) -> Stats
where where
R: GLWEToRef, R: GLWEToRef,
S: GLWESecretPreparedToRef<BE>, S: GLWESecretPreparedToRef<BE>,
@@ -63,7 +63,7 @@ where
self.glwe_decrypt(res, &mut pt_have, sk_prepared, scratch); self.glwe_decrypt(res, &mut pt_have, sk_prepared, scratch);
self.vec_znx_sub_inplace(&mut pt_have.data, 0, &pt_want.data, 0); self.vec_znx_sub_inplace(&mut pt_have.data, 0, &pt_want.data, 0);
self.vec_znx_normalize_inplace(res_ref.base2k().into(), &mut pt_have.data, 0, scratch); self.vec_znx_normalize_inplace(res_ref.base2k().into(), &mut pt_have.data, 0, scratch);
pt_have.data.std(res_ref.base2k().into(), 0).log2() pt_have.data.stats(res_ref.base2k().into(), 0)
} }
fn glwe_assert_noise<R, S, P>(&self, res: &R, sk_prepared: &S, pt_want: &P, max_noise: f64) fn glwe_assert_noise<R, S, P>(&self, res: &R, sk_prepared: &S, pt_want: &P, max_noise: f64)
@@ -74,7 +74,10 @@ where
{ {
let res: &GLWE<&[u8]> = &res.to_ref(); let res: &GLWE<&[u8]> = &res.to_ref();
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(self.glwe_decrypt_tmp_bytes(res)); let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(self.glwe_decrypt_tmp_bytes(res));
let noise_have: f64 = self.glwe_noise(res, sk_prepared, pt_want, scratch.borrow()); let noise_have: f64 = self
.glwe_noise(res, sk_prepared, pt_want, scratch.borrow())
.std()
.log2();
assert!(noise_have <= max_noise, "{noise_have} {max_noise}"); assert!(noise_have <= max_noise, "{noise_have} {max_noise}");
} }
} }

View File

@@ -159,7 +159,7 @@ where
col_i, col_i,
); );
let noise_have: f64 = pt.data.std(base2k, 0).log2(); let noise_have: f64 = pt.data.stats(base2k, 0).std().log2();
let noise_want: f64 = log2_std_noise_gglwe_product( let noise_want: f64 = log2_std_noise_gglwe_product(
n as f64, n as f64,
base2k * di, base2k * di,
@@ -306,7 +306,7 @@ where
col_i, col_i,
); );
let noise_have: f64 = pt.data.std(base2k, 0).log2(); let noise_have: f64 = pt.data.stats(base2k, 0).std().log2();
let noise_want: f64 = log2_std_noise_gglwe_product( let noise_want: f64 = log2_std_noise_gglwe_product(
n as f64, n as f64,
base2k * di, base2k * di,

View File

@@ -74,7 +74,7 @@ where
module.glwe_sub_inplace(&mut pt_want, &pt_have); module.glwe_sub_inplace(&mut pt_want, &pt_have);
let noise_have: f64 = pt_want.data.std(base2k, 0) * (ct.k().as_u32() as f64).exp2(); let noise_have: f64 = pt_want.data.stats(base2k, 0).std() * (ct.k().as_u32() as f64).exp2();
let noise_want: f64 = SIGMA; let noise_want: f64 = SIGMA;
assert!(noise_have <= noise_want + 0.2); assert!(noise_have <= noise_want + 0.2);
@@ -147,7 +147,7 @@ where
module.glwe_sub_inplace(&mut pt_want, &pt_have); module.glwe_sub_inplace(&mut pt_want, &pt_have);
let noise_have: f64 = pt_want.data.std(base2k, 0) * (ct.k().as_u32() as f64).exp2(); let noise_have: f64 = pt_want.data.stats(base2k, 0).std() * (ct.k().as_u32() as f64).exp2();
let noise_want: f64 = SIGMA; let noise_want: f64 = SIGMA;
assert!( assert!(
@@ -203,7 +203,7 @@ where
); );
ct.decrypt(module, &mut pt, &sk_prepared, scratch.borrow()); ct.decrypt(module, &mut pt, &sk_prepared, scratch.borrow());
assert!((SIGMA - pt.data.std(base2k, 0) * (k_ct as f64).exp2()) <= 0.2); assert!((SIGMA - pt.data.stats(base2k, 0).std() * (k_ct as f64).exp2()) <= 0.2);
} }
} }
@@ -271,7 +271,7 @@ where
module.glwe_sub_inplace(&mut pt_want, &pt_have); module.glwe_sub_inplace(&mut pt_want, &pt_have);
let noise_have: f64 = pt_want.data.std(base2k, 0).log2(); let noise_have: f64 = pt_want.data.stats(base2k, 0).std().log2();
let noise_want: f64 = ((((rank as f64) + 1.0) * n as f64 * 0.5 * SIGMA * SIGMA).sqrt()).log2() - (k_ct as f64); let noise_want: f64 = ((((rank as f64) + 1.0) * n as f64 * 0.5 * SIGMA * SIGMA).sqrt()).log2() - (k_ct as f64);
assert!( assert!(

View File

@@ -150,7 +150,7 @@ where
module.glwe_sub_inplace(&mut pt, &pt_want); module.glwe_sub_inplace(&mut pt, &pt_want);
let noise_have: f64 = pt.std().log2(); let noise_have: f64 = pt.stats().std().log2();
assert!( assert!(
noise_have < -((k_ct - base2k) as f64), noise_have < -((k_ct - base2k) as f64),

View File

@@ -123,7 +123,7 @@ where
module.vec_znx_sub_inplace(&mut pt_want.data, 0, &pt_have.data, 0); module.vec_znx_sub_inplace(&mut pt_want.data, 0, &pt_have.data, 0);
module.vec_znx_normalize_inplace(base2k, &mut pt_want.data, 0, scratch.borrow()); module.vec_znx_normalize_inplace(base2k, &mut pt_want.data, 0, scratch.borrow());
let noise_have: f64 = pt_want.std().log2(); let noise_have: f64 = pt_want.stats().std().log2();
let mut noise_want: f64 = var_noise_gglwe_product( let mut noise_want: f64 = var_noise_gglwe_product(
n as f64, n as f64,

View File

@@ -1,5 +1,5 @@
use crate::layouts::{GLWEPlaintext, LWEInfos, LWEPlaintext, TorusPrecision}; use crate::layouts::{GLWEPlaintext, LWEInfos, LWEPlaintext, TorusPrecision};
use poulpy_hal::layouts::{DataMut, DataRef}; use poulpy_hal::layouts::{DataMut, DataRef, Stats};
use rug::Float; use rug::Float;
impl<D: DataMut> GLWEPlaintext<D> { impl<D: DataMut> GLWEPlaintext<D> {
@@ -29,8 +29,8 @@ impl<D: DataRef> GLWEPlaintext<D> {
self.data.decode_vec_float(self.base2k().into(), 0, data); self.data.decode_vec_float(self.base2k().into(), 0, data);
} }
pub fn std(&self) -> f64 { pub fn stats(&self) -> Stats {
self.data.std(self.base2k().into(), 0) self.data.stats(self.base2k().into(), 0)
} }
} }

View File

@@ -18,6 +18,7 @@ pub use module::*;
pub use scalar_znx::*; pub use scalar_znx::*;
pub use scratch::*; pub use scratch::*;
pub use serialization::*; pub use serialization::*;
pub use stats::*;
pub use svp_ppol::*; pub use svp_ppol::*;
pub use vec_znx::*; pub use vec_znx::*;
pub use vec_znx_big::*; pub use vec_znx_big::*;

View File

@@ -6,15 +6,33 @@ use rug::{
use crate::layouts::{Backend, DataRef, VecZnx, VecZnxBig, VecZnxBigToRef, ZnxInfos}; use crate::layouts::{Backend, DataRef, VecZnx, VecZnxBig, VecZnxBigToRef, ZnxInfos};
pub struct Stats {
max: f64,
std: f64,
}
impl Stats {
pub fn max(&self) -> f64 {
self.max
}
pub fn std(&self) -> f64 {
self.std
}
}
impl<D: DataRef> VecZnx<D> { impl<D: DataRef> VecZnx<D> {
pub fn std(&self, base2k: usize, col: usize) -> f64 { pub fn stats(&self, base2k: usize, col: usize) -> Stats {
let prec: u32 = (self.size() * base2k) as u32; let prec: u32 = (self.size() * base2k) as u32;
let mut data: Vec<Float> = (0..self.n()).map(|_| Float::with_val(prec, 0)).collect(); let mut data: Vec<Float> = (0..self.n()).map(|_| Float::with_val(prec, 0)).collect();
self.decode_vec_float(base2k, col, &mut data); self.decode_vec_float(base2k, col, &mut data);
// std = sqrt(sum((xi - avg)^2) / n) // std = sqrt(sum((xi - avg)^2) / n)
let mut avg: Float = Float::with_val(prec, 0); let mut avg: Float = Float::with_val(prec, 0);
let mut max: Float = Float::with_val(prec, 0);
data.iter().for_each(|x| { data.iter().for_each(|x| {
avg.add_assign_round(x, Round::Nearest); avg.add_assign_round(x, Round::Nearest);
max.max_mut(&Float::with_val(53, x.abs_ref()));
}); });
avg.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest); avg.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest);
data.iter_mut().for_each(|x| { data.iter_mut().for_each(|x| {
@@ -24,12 +42,15 @@ impl<D: DataRef> VecZnx<D> {
data.iter().for_each(|x| std += x * x); data.iter().for_each(|x| std += x * x);
std.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest); std.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest);
std = std.sqrt(); std = std.sqrt();
std.to_f64() Stats {
std: std.to_f64(),
max: max.to_f64(),
}
} }
} }
impl<D: DataRef, B: Backend + Backend<ScalarBig = i64>> VecZnxBig<D, B> { impl<D: DataRef, B: Backend + Backend<ScalarBig = i64>> VecZnxBig<D, B> {
pub fn std(&self, base2k: usize, col: usize) -> f64 { pub fn stats(&self, base2k: usize, col: usize) -> Stats {
let self_ref: VecZnxBig<&[u8], B> = self.to_ref(); let self_ref: VecZnxBig<&[u8], B> = self.to_ref();
let znx: VecZnx<&[u8]> = VecZnx { let znx: VecZnx<&[u8]> = VecZnx {
data: self_ref.data, data: self_ref.data,
@@ -38,6 +59,6 @@ impl<D: DataRef, B: Backend + Backend<ScalarBig = i64>> VecZnxBig<D, B> {
size: self_ref.size, size: self_ref.size,
max_size: self_ref.max_size, max_size: self_ref.max_size,
}; };
znx.std(base2k, col) znx.stats(base2k, col)
} }
} }

View File

@@ -324,7 +324,7 @@ where
assert_eq!(a.at(col_j, limb_i), zero); assert_eq!(a.at(col_j, limb_i), zero);
}) })
} else { } else {
let std: f64 = a.std(base2k, col_i) * k_f64; let std: f64 = a.stats(base2k, col_i).std() * k_f64;
assert!( assert!(
(std - sigma * sqrt2).abs() < 0.1, (std - sigma * sqrt2).abs() < 0.1,
"std={} ~!= {}", "std={} ~!= {}",

View File

@@ -662,8 +662,8 @@ mod tests {
vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut res_ref, j, &mut carry); vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut res_ref, j, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut res_test, j, &mut carry); vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut res_test, j, &mut carry);
assert!(res_ref.std(base2k, j).log2() - (k as f64) <= (k * base2k) as f64); assert!(res_ref.stats(base2k, j).std().log2() - (k as f64) <= (k * base2k) as f64);
assert!(res_test.std(base2k, j).log2() - (k as f64) <= (k * base2k) as f64); assert!(res_test.stats(base2k, j).std().log2() - (k as f64) <= (k * base2k) as f64);
} }
} }
} }

View File

@@ -717,7 +717,7 @@ where
assert_eq!(a.at(col_j, limb_i), zero); assert_eq!(a.at(col_j, limb_i), zero);
}) })
} else { } else {
let std: f64 = a.std(base2k, col_i); let std: f64 = a.stats(base2k, col_i).std();
assert!( assert!(
(std - one_12_sqrt).abs() < 0.01, (std - one_12_sqrt).abs() < 0.01,
"std={std} ~!= {one_12_sqrt}", "std={std} ~!= {one_12_sqrt}",
@@ -750,7 +750,7 @@ where
assert_eq!(a.at(col_j, limb_i), zero); assert_eq!(a.at(col_j, limb_i), zero);
}) })
} else { } else {
let std: f64 = a.std(base2k, col_i) * k_f64; let std: f64 = a.stats(base2k, col_i).std() * k_f64;
assert!((std - sigma).abs() < 0.1, "std={std} ~!= {sigma}"); assert!((std - sigma).abs() < 0.1, "std={std} ~!= {sigma}");
} }
}) })
@@ -782,7 +782,7 @@ where
assert_eq!(a.at(col_j, limb_i), zero); assert_eq!(a.at(col_j, limb_i), zero);
}) })
} else { } else {
let std: f64 = a.std(base2k, col_i) * k_f64; let std: f64 = a.stats(base2k, col_i).std() * k_f64;
assert!( assert!(
(std - sigma * sqrt2).abs() < 0.1, (std - sigma * sqrt2).abs() < 0.1,
"std={std} ~!= {}", "std={std} ~!= {}",

View File

@@ -8,7 +8,7 @@ use poulpy_core::{
}; };
use poulpy_hal::{ use poulpy_hal::{
api::ModuleLogN, api::ModuleLogN,
layouts::{Backend, Data, DataMut, DataRef, Scratch}, layouts::{Backend, Data, DataMut, DataRef, Scratch, Stats},
source::Source, source::Source,
}; };
use std::{collections::HashMap, marker::PhantomData}; use std::{collections::HashMap, marker::PhantomData};
@@ -114,7 +114,7 @@ impl<D: DataMut, T: UnsignedInteger + ToBits> FheUint<D, T> {
} }
impl<D: DataRef, T: UnsignedInteger + FromBits> FheUint<D, T> { impl<D: DataRef, T: UnsignedInteger + FromBits> FheUint<D, T> {
pub fn noise<S, M, BE: Backend>(&self, module: &M, want: u32, sk: &S, scratch: &mut Scratch<BE>) -> f64 pub fn noise<S, M, BE: Backend>(&self, module: &M, want: u32, sk: &S, scratch: &mut Scratch<BE>) -> Stats
where where
S: GLWESecretPreparedToRef<BE> + GLWEInfos, S: GLWESecretPreparedToRef<BE> + GLWEInfos,
M: ModuleLogN + GLWEDecrypt<BE> + GLWENoise<BE>, M: ModuleLogN + GLWEDecrypt<BE> + GLWENoise<BE>,