From 785bb46df20590562d20922353578c37ac25216d Mon Sep 17 00:00:00 2001 From: Pro7ech Date: Thu, 6 Nov 2025 11:18:41 +0100 Subject: [PATCH] fix decoding to use rounded division instead of arithmetic right shift --- poulpy-hal/src/layouts/encoding.rs | 24 +++++++++++++--- .../bdd_arithmetic/ciphertexts/fhe_uint.rs | 28 +++++++++++++++++-- 2 files changed, 46 insertions(+), 6 deletions(-) diff --git a/poulpy-hal/src/layouts/encoding.rs b/poulpy-hal/src/layouts/encoding.rs index 4d61f93..37a4677 100644 --- a/poulpy-hal/src/layouts/encoding.rs +++ b/poulpy-hal/src/layouts/encoding.rs @@ -164,13 +164,15 @@ impl VecZnx { data.copy_from_slice(a.at(col, 0)); let rem: usize = base2k - (k % base2k); if k < base2k { - data.iter_mut().for_each(|x| *x >>= rem); + let scale = 1 << rem as i64; + data.iter_mut().for_each(|x| *x = div_round(*x, scale)); } else { (1..size).for_each(|i| { if i == size - 1 && rem != base2k { let k_rem: usize = (base2k - rem) % base2k; + let scale: i64 = 1 << rem as i64; izip!(a.at(col, i).iter(), data.iter_mut()).for_each(|(x, y)| { - *y = (*y << k_rem) + (x >> rem); + *y = (*y << k_rem) + div_round(*x, scale); }); } else { izip!(a.at(col, i).iter(), data.iter_mut()).for_each(|(x, y)| { @@ -197,7 +199,8 @@ impl VecZnx { let x: i64 = a.at(col, j)[idx]; if j == size - 1 && rem != base2k { let k_rem: usize = (base2k - rem) % base2k; - res = (res << k_rem) + (x >> rem); + let scale: i64 = 1 << rem as i64; + res = (res << k_rem) + div_round(x, scale); } else { res = (res << base2k) + x; } @@ -293,7 +296,8 @@ impl Zn { let x: i64 = a.at(0, j)[0]; if j == size - 1 && rem != base2k { let k_rem: usize = (base2k - rem) % base2k; - res = (res << k_rem) + (x >> rem); + let scale: i64 = 1 << rem as i64; + res = (res << k_rem) + div_round(x, scale); } else { res = (res << base2k) + x; } @@ -324,3 +328,15 @@ impl Zn { res } } + +#[inline] +pub fn div_round(a: i64, b: i64) -> i64 { + assert!(b != 0, "division by zero"); + let div: i64 = a / b; + let rem: i64 = a % b; + if (2 * rem.abs()) >= b.abs() { + div + a.signum() * b.signum() + } else { + div + } +} diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint.rs index b4aac17..5ae7145 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint.rs @@ -1,5 +1,6 @@ use poulpy_core::{ - GLWEAdd, GLWECopy, GLWEDecrypt, GLWEEncryptSk, GLWEPacking, GLWERotate, GLWESub, GLWETrace, LWEFromGLWE, ScratchTakeCore, + GLWEAdd, GLWECopy, GLWEDecrypt, GLWEEncryptSk, GLWENoise, GLWEPacking, GLWERotate, GLWESub, GLWETrace, LWEFromGLWE, + ScratchTakeCore, layouts::{ Base2K, Degree, GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEAutomorphismKeyHelper, GLWEInfos, GLWEPlaintextLayout, GLWESecretPreparedToRef, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos, LWEToMut, Rank, TorusPrecision, @@ -101,7 +102,7 @@ impl FheUint { let pt_infos = GLWEPlaintextLayout { n: self.n(), base2k: self.base2k(), - k: 1_usize.into(), + k: 2_usize.into(), }; let (mut pt, scratch_1) = scratch.take_glwe_plaintext(&pt_infos); @@ -113,6 +114,29 @@ impl FheUint { } impl FheUint { + pub fn noise(&self, module: &M, want: u32, sk: &S, scratch: &mut Scratch) -> f64 + where + S: GLWESecretPreparedToRef + GLWEInfos, + M: ModuleLogN + GLWEDecrypt + GLWENoise, + Scratch: ScratchTakeCore, + { + #[cfg(debug_assertions)] + { + assert!(module.n().is_multiple_of(T::BITS as usize)); + assert_eq!(self.n(), module.n() as u32); + assert_eq!(sk.n(), module.n() as u32); + } + + let (mut pt, scratch_1) = scratch.take_glwe_plaintext(self); + let mut data_bits = vec![0i64; module.n()]; + let log_gap: usize = module.log_n() - T::LOG_BITS as usize; + for i in 0..T::BITS as usize { + data_bits[T::bit_index(i) << log_gap] = want.bit(i) as i64 + } + pt.encode_vec_i64(&data_bits, TorusPrecision(2)); + self.bits.noise(module, sk, &pt, scratch_1) + } + pub fn decrypt(&self, module: &M, sk: &S, scratch: &mut Scratch) -> T where S: GLWESecretPreparedToRef + GLWEInfos,