From 9d5bc4363208efa16c2473d95eeae79fdcb5cf3f Mon Sep 17 00:00:00 2001 From: Pro7ech Date: Wed, 29 Oct 2025 10:05:44 +0100 Subject: [PATCH] Update bit encoding to byte interleaving to enable trivial byte-level manipulation --- .../src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs | 15 +++--- .../src/tfhe/bdd_arithmetic/blind_rotation.rs | 2 +- .../bdd_arithmetic/ciphertexts/fhe_uint.rs | 51 ++++++++++--------- .../ciphertexts/fhe_uint_prepared.rs | 6 +-- .../ciphertexts/fhe_uint_prepared_debug.rs | 2 +- poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs | 35 ++++++++++--- .../bdd_arithmetic/tests/test_suite/mod.rs | 2 +- 7 files changed, 72 insertions(+), 41 deletions(-) diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs index 3195917..d69bfe1 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs @@ -1,7 +1,10 @@ use std::marker::PhantomData; use poulpy_core::{GLWECopy, GLWEPacking, ScratchTakeCore, layouts::GGSWPrepared}; -use poulpy_hal::layouts::{Backend, DataMut, DataRef, Module, Scratch}; +use poulpy_hal::{ + api::ModuleLogN, + layouts::{Backend, DataMut, DataRef, Module, Scratch}, +}; use crate::tfhe::{ bdd_arithmetic::{ @@ -18,7 +21,7 @@ impl ExecuteBDDCircuit2WTo1W for Module< pub trait ExecuteBDDCircuit2WTo1W where - Self: Sized + ExecuteBDDCircuit + GLWEPacking + GLWECopy, + Self: Sized + ModuleLogN + ExecuteBDDCircuit + GLWEPacking + GLWECopy, { /// Operations Z x Z -> Z fn execute_bdd_circuit_2w_to_1w( @@ -45,7 +48,7 @@ where _phantom: PhantomData, }; - let (mut out_bits, scratch_1) = scratch.take_glwe_slice(T::WORD_SIZE, out); + let (mut out_bits, scratch_1) = scratch.take_glwe_slice(T::BITS as usize, out); // Evaluates out[i] = circuit[i](a, b) self.execute_bdd_circuit(&mut out_bits, &helper, circuit, scratch_1); @@ -62,15 +65,15 @@ struct FheUintHelper<'a, T: UnsignedInteger, BE: Backend> { impl<'a, T: UnsignedInteger, BE: Backend> GetGGSWBit for FheUintHelper<'a, T, BE> { fn get_bit(&self, bit: usize) -> GGSWPrepared<&[u8], BE> { - let lo: usize = bit % T::WORD_SIZE; - let hi: usize = bit / T::WORD_SIZE; + let lo: usize = bit % T::BITS as usize; + let hi: usize = bit / T::BITS as usize; self.data[hi].get_bit(lo) } } impl<'a, T: UnsignedInteger, BE: Backend> BitSize for FheUintHelper<'a, T, BE> { fn bit_size(&self) -> usize { - T::WORD_SIZE * self.data.len() + T::BITS as usize * self.data.len() } } diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/blind_rotation.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/blind_rotation.rs index b50329c..593abf7 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/blind_rotation.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/blind_rotation.rs @@ -195,7 +195,7 @@ where K: GetGGSWBit, Scratch: ScratchTakeCore, { - assert!(bit_rsh + bit_mask <= T::WORD_SIZE); + assert!(bit_rsh + bit_mask <= T::BITS as usize); let mut res: GLWE<&mut [u8]> = res.to_mut(); diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint.rs index cb23e8e..f8a5e12 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint.rs @@ -1,4 +1,3 @@ -use itertools::Itertools; use poulpy_core::{ GLWECopy, GLWEDecrypt, GLWEEncryptSk, GLWEPacking, GLWERotate, LWEFromGLWE, ScratchTakeCore, layouts::{ @@ -7,7 +6,7 @@ use poulpy_core::{ }, }; use poulpy_hal::{ - api::ModuleN, + api::ModuleLogN, layouts::{Backend, Data, DataMut, DataRef, Scratch}, source::Source, }; @@ -68,22 +67,23 @@ impl FheUint { scratch: &mut Scratch, ) where S: GLWESecretPreparedToRef + GLWEInfos, - M: ModuleN + GLWEEncryptSk, + M: ModuleLogN + GLWEEncryptSk, Scratch: ScratchTakeCore, { #[cfg(debug_assertions)] { - assert!(module.n().is_multiple_of(T::WORD_SIZE)); + assert!(module.n().is_multiple_of(T::BITS as usize)); assert_eq!(self.n(), module.n() as u32); assert_eq!(sk.n(), module.n() as u32); } - let gap: usize = module.n() / T::WORD_SIZE; - let mut data_bits: Vec = vec![0i64; module.n()]; - for i in 0..T::WORD_SIZE { - data_bits[i * gap] = data.bit(i) as i64 + let log_gap: usize = module.log_n() - T::LOG_BITS as usize; + + // Interleaves bytes + for i in 0..T::BITS as usize { + data_bits[T::bit_index(i) << log_gap] = data.bit(i) as i64 } let pt_infos = GLWEPlaintextLayout { @@ -104,18 +104,16 @@ impl FheUint { pub fn decrypt(&self, module: &M, sk: &S, scratch: &mut Scratch) -> T where S: GLWESecretPreparedToRef + GLWEInfos, - M: GLWEDecrypt, + M: ModuleLogN + GLWEDecrypt, Scratch: ScratchTakeCore, { #[cfg(debug_assertions)] { - assert!(module.n().is_multiple_of(T::WORD_SIZE)); + assert!(module.n().is_multiple_of(T::BITS as usize)); assert_eq!(self.n(), module.n() as u32); assert_eq!(sk.n(), module.n() as u32); } - let gap: usize = module.n() / T::WORD_SIZE; - let pt_infos = GLWEPlaintextLayout { n: self.n(), base2k: self.base2k(), @@ -126,11 +124,18 @@ impl FheUint { self.bits.decrypt(module, &mut pt, sk, scratch_1); - let mut data: Vec = vec![0i64; module.n()]; + let mut data_bits: Vec = vec![0i64; module.n()]; + pt.decode_vec_i64(&mut data_bits, TorusPrecision(2)); - pt.decode_vec_i64(&mut data, TorusPrecision(2)); + let mut bits: Vec = vec![0u8; T::BITS as usize]; + + let log_gap: usize = module.log_n() - T::LOG_BITS as usize; + + // Retrives from interleaved bytes + for i in 0..T::BITS as usize { + bits[i] = data_bits[T::bit_index(i) << log_gap] as u8 + } - let bits: Vec = data.iter().step_by(gap).map(|c| *c as u8).collect_vec(); T::from_bits(&bits) } } @@ -146,15 +151,14 @@ impl FheUint { ) where D1: DataMut, ATK: DataRef, - M: GLWEPacking + GLWECopy, + M: ModuleLogN + GLWEPacking + GLWECopy, Scratch: ScratchTakeCore, { // Repacks the GLWE ciphertexts bits - let gap: usize = module.n() / T::WORD_SIZE; - let log_gap: usize = (usize::BITS - (gap - 1).leading_zeros()) as usize; + let log_gap: usize = module.log_n() - T::LOG_BITS as usize; let mut cts: HashMap> = HashMap::new(); - for (i, ct) in tmp_res.iter_mut().enumerate().take(T::WORD_SIZE) { - cts.insert(i * gap, ct); + for (i, ct) in tmp_res.iter_mut().enumerate().take(T::BITS as usize) { + cts.insert(T::bit_index(i) << log_gap, ct); } module.glwe_pack(&mut cts, log_gap, auto_keys, scratch); @@ -169,11 +173,12 @@ impl FheUint { where L: LWEToMut, K: GGLWEPreparedToRef + GGLWEInfos, - M: LWEFromGLWE + GLWERotate, + M: ModuleLogN + LWEFromGLWE + GLWERotate, Scratch: ScratchTakeCore, { - let gap: usize = module.n() / T::WORD_SIZE; - res.to_mut().from_glwe(module, self, bit * gap, ks, scratch); + let log_gap: usize = module.log_n() - T::LOG_BITS as usize; + res.to_mut() + .from_glwe(module, self, T::bit_index(bit) << log_gap, ks, scratch); } } diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs index 050cff0..3068581 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs @@ -63,7 +63,7 @@ where rank: Rank, ) -> FheUintPrepared, T, BE> { FheUintPrepared { - bits: (0..T::WORD_SIZE) + bits: (0..T::BITS) .map(|_| GGSWPrepared::alloc(self, base2k, k, dnum, dsize, rank)) .collect(), _phantom: PhantomData, @@ -125,7 +125,7 @@ where { use poulpy_hal::{api::ScratchTakeBasic, layouts::ZnxZero}; - assert!(self.n().is_multiple_of(T::WORD_SIZE)); + assert!(self.n().is_multiple_of(T::BITS as usize)); assert_eq!(res.n(), self.n() as u32); assert_eq!(sk.n(), self.n() as u32); @@ -133,7 +133,7 @@ where let (mut pt, scratch_2) = scratch_1.take_scalar_znx(self.n(), 1); pt.zero(); - for i in 0..T::WORD_SIZE { + for i in 0..T::BITS as usize { use poulpy_hal::layouts::ZnxViewMut; pt.at_mut(0, 0)[0] = value.bit(i) as i64; tmp_ggsw.encrypt_sk(self, &pt, sk, source_xa, source_xe, scratch_2); diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared_debug.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared_debug.rs index cc21983..6f130d9 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared_debug.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared_debug.rs @@ -42,7 +42,7 @@ impl FheUintPreparedDebug, T> { M: ModuleN, { Self { - bits: (0..T::WORD_SIZE) + bits: (0..T::BITS) .map(|_| GGSW::alloc(module.n().into(), base2k, k, rank, dnum, dsize)) .collect(), _phantom: PhantomData, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs index f73eb30..4898b98 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs @@ -15,23 +15,46 @@ pub use key::*; pub mod tests; pub trait UnsignedInteger: Copy + 'static { - const WORD_SIZE: usize; + const BITS: u32; + const LOG_BITS: u32; + const LOG_BYTES: u32; + const LOG_BYTES_MASK: usize; + + #[inline(always)] + fn bit_index(i: usize) -> usize { + ((i & Self::LOG_BYTES_MASK) << 3) | (i >> Self::LOG_BYTES) + } } impl UnsignedInteger for u8 { - const WORD_SIZE: usize = 8; + const BITS: u32 = u8::BITS; + const LOG_BITS: u32 = (u32::BITS - (Self::BITS - 1).leading_zeros()); + const LOG_BYTES: u32 = Self::LOG_BITS - 3; + const LOG_BYTES_MASK: usize = (1 << Self::LOG_BYTES) - 1; } impl UnsignedInteger for u16 { - const WORD_SIZE: usize = 16; + const BITS: u32 = u16::BITS; + const LOG_BITS: u32 = (u32::BITS - (Self::BITS - 1).leading_zeros()); + const LOG_BYTES: u32 = Self::LOG_BITS - 3; + const LOG_BYTES_MASK: usize = (1 << Self::LOG_BYTES) - 1; } impl UnsignedInteger for u32 { - const WORD_SIZE: usize = 32; + const BITS: u32 = u32::BITS; + const LOG_BITS: u32 = (u32::BITS - (Self::BITS - 1).leading_zeros()); + const LOG_BYTES: u32 = Self::LOG_BITS - 3; + const LOG_BYTES_MASK: usize = (1 << Self::LOG_BYTES) - 1; } impl UnsignedInteger for u64 { - const WORD_SIZE: usize = 64; + const BITS: u32 = u64::BITS; + const LOG_BITS: u32 = (u32::BITS - (Self::BITS - 1).leading_zeros()); + const LOG_BYTES: u32 = Self::LOG_BITS >> 3; + const LOG_BYTES_MASK: usize = (1 << Self::LOG_BYTES) - 1; } impl UnsignedInteger for u128 { - const WORD_SIZE: usize = 128; + const BITS: u32 = u128::BITS; + const LOG_BITS: u32 = (u32::BITS - (Self::BITS - 1).leading_zeros()); + const LOG_BYTES: u32 = Self::LOG_BITS >> 3; + const LOG_BYTES_MASK: usize = (1 << Self::LOG_BYTES) - 1; } pub trait ToBits { diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs index 30e5b8f..766ae73 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs @@ -121,7 +121,7 @@ impl TestContext { } } -pub(crate) const TEST_N_GLWE: u32 = 512; +pub(crate) const TEST_N_GLWE: u32 = 256; pub(crate) const TEST_N_LWE: u32 = 77; pub(crate) const TEST_BASE2K: u32 = 13; pub(crate) const TEST_K_GLWE: u32 = 26;