diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint.rs index 402f412..b4aac17 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint.rs @@ -287,7 +287,7 @@ where impl ScratchTakeBDD for Scratch where Self: ScratchTakeCore {} impl FheUint { - pub fn get_bit(&self, module: &M, bit: usize, res: &mut R, ks: &K, scratch: &mut Scratch) + pub fn get_bit_lwe(&self, module: &M, bit: usize, res: &mut R, ks: &K, scratch: &mut Scratch) where R: LWEToMut, K: GGLWEPreparedToRef + GGLWEInfos, @@ -298,6 +298,37 @@ impl FheUint { res.to_mut() .from_glwe(module, self, T::bit_index(bit) << log_gap, ks, scratch); } + + pub fn get_bit_glwe(&self, module: &M, bit: usize, res: &mut R, keys: &H, scratch: &mut Scratch) + where + R: GLWEToMut, + K: GGLWEPreparedToRef + GGLWEInfos, + M: ModuleLogN + GLWERotate + GLWETrace, + H: GLWEAutomorphismKeyHelper, + K: GGLWEPreparedToRef + GGLWEInfos + GetGaloisElement, + Scratch: ScratchTakeCore, + { + let log_gap: usize = module.log_n() - T::LOG_BITS as usize; + let rot = (T::bit_index(bit) << log_gap) as i64; + module.glwe_rotate(-rot, res, self); + module.glwe_trace_inplace(res, 0, keys, scratch); + } + + pub fn get_byte(&self, module: &M, byte: usize, res: &mut R, keys: &H, scratch: &mut Scratch) + where + R: GLWEToMut, + K: GGLWEPreparedToRef + GGLWEInfos, + M: ModuleLogN + GLWERotate + GLWETrace, + H: GLWEAutomorphismKeyHelper, + K: GGLWEPreparedToRef + GGLWEInfos + GetGaloisElement, + Scratch: ScratchTakeCore, + { + let log_gap: usize = module.log_n() - T::LOG_BITS as usize; + let trace_start = (T::LOG_BITS - T::LOG_BYTES) as usize; + let rot = (T::bit_index(byte << 3) << log_gap) as i64; + module.glwe_rotate(-rot, res, self); + module.glwe_trace_inplace(res, trace_start, keys, scratch); + } } impl GLWEToRef for FheUint { diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs index 3ea8423..ccb7e07 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs @@ -254,7 +254,7 @@ where let mut lwe: LWE> = LWE::alloc_from_infos(bits); //TODO: add TakeLWE let (mut tmp_ggsw, scratch_1) = scratch.take_ggsw(res); for (bit, dst) in res.bits.iter_mut().enumerate() { - bits.get_bit(self, bit, &mut lwe, ks, scratch_1); + bits.get_bit_lwe(self, bit, &mut lwe, ks, scratch_1); cbt.execute_to_constant(self, &mut tmp_ggsw, &lwe, 1, 1, scratch_1); dst.prepare(self, &tmp_ggsw, scratch_1); } diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared_debug.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared_debug.rs index 22fd2b6..cd7b7db 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared_debug.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared_debug.rs @@ -127,7 +127,7 @@ where { let mut lwe: LWE> = LWE::alloc_from_infos(bits); //TODO: add TakeLWE for (bit, dst) in res.bits.iter_mut().enumerate() { - bits.get_bit(self, bit, &mut lwe, &key.ks, scratch); + bits.get_bit_lwe(self, bit, &mut lwe, &key.ks, scratch); key.cbt.execute_to_constant(self, dst, &lwe, 1, 1, scratch); } } diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs index 7653f65..cd316d0 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs @@ -5,8 +5,8 @@ use poulpy_backend::FFT64Ref; use crate::tfhe::{ bdd_arithmetic::tests::test_suite::{ TestContext, test_bdd_add, test_bdd_and, test_bdd_or, test_bdd_prepare, test_bdd_sll, test_bdd_slt, test_bdd_sltu, - test_bdd_sra, test_bdd_srl, test_bdd_sub, test_bdd_xor, test_fhe_uint_sext, test_fhe_uint_splice_u8, - test_fhe_uint_splice_u16, test_glwe_blind_selection, test_glwe_to_glwe_blind_rotation, + test_bdd_sra, test_bdd_srl, test_bdd_sub, test_bdd_xor, test_fhe_uint_get_bit_glwe, test_fhe_uint_sext, + test_fhe_uint_splice_u8, test_fhe_uint_splice_u16, test_glwe_blind_selection, test_glwe_to_glwe_blind_rotation, test_scalar_to_ggsw_blind_rotation, }, blind_rotation::CGGI, @@ -15,6 +15,11 @@ use crate::tfhe::{ static TEST_CONTEXT_CGGI_FFT64_REF: LazyLock> = LazyLock::new(|| TestContext::::new()); +#[test] +fn test_fhe_uint_get_bit_glwe_fft64_ref() { + test_fhe_uint_get_bit_glwe(&TEST_CONTEXT_CGGI_FFT64_REF); +} + #[test] fn test_fhe_uint_sext_fft64_ref() { test_fhe_uint_sext(&TEST_CONTEXT_CGGI_FFT64_REF); diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/fheuint.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/fheuint.rs index 78e1de4..fc1af3f 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/fheuint.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/fheuint.rs @@ -7,10 +7,11 @@ use poulpy_hal::{ layouts::{Backend, Module, Scratch, ScratchOwned}, source::Source, }; +use rand::RngCore; use crate::tfhe::{ bdd_arithmetic::{ - BDDKeyPrepared, FheUint, ScratchTakeBDD, + BDDKeyPrepared, FheUint, ScratchTakeBDD, ToBits, tests::test_suite::{TEST_GLWE_INFOS, TestContext}, }, blind_rotation::BlindRotationAlgo, @@ -171,3 +172,40 @@ where } } } + +pub fn test_fhe_uint_get_bit_glwe(test_context: &TestContext) +where + Module: GLWEEncryptSk + GLWERotate + GLWETrace + GLWESub + GLWEAdd + GLWEDecrypt, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchTakeBDD, +{ + let glwe_infos: GLWELayout = TEST_GLWE_INFOS; + + let module: &Module = &test_context.module; + let sk: &GLWESecretPrepared, BE> = &test_context.sk_glwe; + let keys: &BDDKeyPrepared, BRA, BE> = &test_context.bdd_key; + + let mut source_xa: Source = Source::new([2u8; 32]); + let mut source_xe: Source = Source::new([3u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(1 << 22); + + let mut a_enc: FheUint, u32> = FheUint::, u32>::alloc_from_infos(&glwe_infos); + let mut c_enc: FheUint, u32> = FheUint::, u32>::alloc_from_infos(&glwe_infos); + + let a: u32 = source_xa.next_u32(); + + a_enc.encrypt_sk( + module, + a, + sk, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + for i in 0..32 { + a_enc.get_bit_glwe(module, i, &mut c_enc, keys, scratch.borrow()); + assert_eq!(a.bit(i) as u32, c_enc.decrypt(module, sk, scratch.borrow())); + } +}