diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/blind_selection.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/blind_selection.rs new file mode 100644 index 0000000..3242b2e --- /dev/null +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/blind_selection.rs @@ -0,0 +1,90 @@ +use std::collections::HashMap; + +use poulpy_core::{ + GLWECopy, GLWEDecrypt, ScratchTakeCore, + layouts::{GGSWInfos, GGSWPrepared, GLWE, GLWEInfos, GLWEToMut, GLWEToRef}, +}; +use poulpy_hal::layouts::{Backend, Module, Scratch, ZnxZero}; + +use crate::tfhe::bdd_arithmetic::{Cmux, GetGGSWBit, UnsignedInteger}; + +impl GLWEBlinSelection for Module where Self: GLWECopy + Cmux + GLWEDecrypt {} + +pub trait GLWEBlinSelection +where + Self: GLWECopy + Cmux + GLWEDecrypt, +{ + fn glwe_blind_selection_tmp_bytes(&self, res_infos: &R, k_infos: &K) -> usize + where + R: GLWEInfos, + K: GGSWInfos, + { + self.cmux_tmp_bytes(res_infos, res_infos, k_infos) + GLWE::bytes_of_from_infos(res_infos) + } + + #[allow(clippy::too_many_arguments)] + fn glwe_blind_selection( + &self, + res: &mut R, + mut a: HashMap, + fhe_uint: &K, + bit_rsh: usize, + bit_mask: usize, + scratch: &mut Scratch, + ) where + R: GLWEToMut, + A: GLWEToMut + GLWEToRef, + K: GetGGSWBit, + Scratch: ScratchTakeCore, + { + assert!(bit_rsh + bit_mask <= T::BITS as usize); + + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + + for i in 0..bit_mask { + let t: usize = 1 << (bit_mask - i - 1); + + let bit: &GGSWPrepared<&[u8], BE> = &fhe_uint.get_bit(bit_rsh + bit_mask - i - 1); // MSB -> LSB traversal + + for j in 0..t { + let hi: Option<&mut A> = a.remove(&j); + let lo: Option<&mut A> = a.remove(&(j + t)); + + match (lo, hi) { + (Some(lo), Some(hi)) => { + self.cmux_inplace(lo, hi, bit, scratch); + a.insert(j, lo); + } + + (Some(lo), None) => { + let (mut zero, scratch_1) = scratch.take_glwe(res); + zero.data_mut().zero(); + self.cmux_inplace(lo, &zero, bit, scratch_1); + a.insert(j, lo); + } + + (None, Some(hi)) => { + let (mut zero, scratch_1) = scratch.take_glwe(res); + zero.data_mut().zero(); + self.cmux_inplace(&mut zero, hi, bit, scratch_1); + self.glwe_copy(hi, &zero); + a.insert(j, hi); + } + + (None, None) => { + // No low or high branch — nothing to insert + // leave empty; future iterations will combine actual ciphertexts + } + } + } + } + + let out: Option<&mut A> = a.remove(&0); + + if let Some(out) = out { + self.glwe_copy(res, out); + } else { + res.data_mut().zero(); + } + } +} diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs index 5b3661f..e3fb490 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs @@ -1,5 +1,6 @@ mod bdd_2w_to_1w; mod blind_rotation; +mod blind_selection; mod ciphertexts; mod circuits; mod eval; @@ -7,6 +8,7 @@ mod key; pub use bdd_2w_to_1w::*; pub use blind_rotation::*; +pub use blind_selection::*; pub use ciphertexts::*; pub(crate) use circuits::*; pub use eval::*; diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs index e4a6de4..2238266 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs @@ -6,7 +6,7 @@ use crate::tfhe::{ bdd_arithmetic::tests::test_suite::{ TestContext, test_bdd_add, test_bdd_and, test_bdd_or, test_bdd_prepare, test_bdd_sll, test_bdd_slt, test_bdd_sltu, test_bdd_sra, test_bdd_srl, test_bdd_sub, test_bdd_xor, test_fhe_uint_splice_u8, test_fhe_uint_splice_u16, - test_glwe_to_glwe_blind_rotation, test_scalar_to_ggsw_blind_rotation, + test_glwe_blind_selection, test_glwe_to_glwe_blind_rotation, test_scalar_to_ggsw_blind_rotation, }, blind_rotation::CGGI, }; @@ -14,6 +14,11 @@ use crate::tfhe::{ static TEST_CONTEXT_CGGI_FFT64_REF: LazyLock> = LazyLock::new(|| TestContext::::new()); +#[test] +fn test_glwe_blind_selection_fft64_ref() { + test_glwe_blind_selection(&TEST_CONTEXT_CGGI_FFT64_REF) +} + #[test] fn test_fhe_uint_splice_u8_fft64_ref() { test_fhe_uint_splice_u8(&TEST_CONTEXT_CGGI_FFT64_REF) diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/glwe_blind_selection.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/glwe_blind_selection.rs new file mode 100644 index 0000000..7bbe62a --- /dev/null +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/glwe_blind_selection.rs @@ -0,0 +1,147 @@ +use std::collections::HashMap; + +use poulpy_core::{ + GGSWEncryptSk, GLWEDecrypt, GLWEEncryptSk, ScratchTakeCore, + layouts::{ + Base2K, Dnum, Dsize, GGSWLayout, GGSWPreparedFactory, GLWE, GLWELayout, GLWEPlaintext, GLWESecretPrepared, + GLWESecretPreparedFactory, Rank, TorusPrecision, + }, +}; +use poulpy_hal::{ + api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow}, + layouts::{Backend, Module, Scratch, ScratchOwned}, + source::Source, +}; +use rand::RngCore; + +use crate::tfhe::{ + bdd_arithmetic::{ + FheUintPrepared, GLWEBlinSelection, + tests::test_suite::{TEST_BASE2K, TEST_RANK, TestContext}, + }, + blind_rotation::BlindRotationAlgo, +}; + +pub fn test_glwe_blind_selection(test_context: &TestContext) +where + Module: ModuleNew + + GLWESecretPreparedFactory + + GGSWPreparedFactory + + GGSWEncryptSk + + GLWEBlinSelection + + GLWEDecrypt + + GLWEEncryptSk, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchTakeCore, +{ + let module: &Module = &test_context.module; + let sk_glwe_prep: &GLWESecretPrepared, BE> = &test_context.sk_glwe; + + let base2k: Base2K = TEST_BASE2K.into(); + let rank: Rank = TEST_RANK.into(); + let k_glwe: TorusPrecision = TorusPrecision(26); + let k_ggsw: TorusPrecision = TorusPrecision(39); + let dnum: Dnum = Dnum(3); + + let glwe_infos: GLWELayout = GLWELayout { + n: module.n().into(), + base2k, + k: k_glwe, + rank, + }; + let ggsw_infos: GGSWLayout = GGSWLayout { + n: module.n().into(), + base2k, + k: k_ggsw, + rank, + dnum, + dsize: Dsize(1), + }; + + let mut source: Source = Source::new([6u8; 32]); + let mut source_xa: Source = Source::new([2u8; 32]); + let mut source_xe: Source = Source::new([3u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(1 << 22); + + let mut res: GLWE> = GLWE::alloc_from_infos(&glwe_infos); + + let k: u32 = source.next_u32(); + + let mut k_enc_prep: FheUintPrepared, u32, BE> = + FheUintPrepared::, u32, BE>::alloc_from_infos(module, &ggsw_infos); + k_enc_prep.encrypt_sk( + module, + k, + sk_glwe_prep, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + let digit = 5; + let mask: u32 = (1 << digit) - 1; + + // Starting bit + let mut bit_start: usize = 0; + + let mut data = vec![0i64; 1 << digit]; + data.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); + + for _ in 0..32_usize.div_ceil(digit) { + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_infos); + + let mut cts_map: HashMap>> = HashMap::new(); + let mut cts: Vec>> = Vec::new(); + + for value in data.iter().take(1 << digit) { + pt.encode_coeff_i64(*value, TorusPrecision(base2k.as_u32()), 0); + let mut ct = GLWE::alloc_from_infos(&glwe_infos); + ct.encrypt_sk( + module, + &pt, + sk_glwe_prep, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + cts.push(ct); + } + + for (i, ct) in cts.iter_mut().enumerate() { + if i.is_multiple_of(3) { + cts_map.insert(i, ct); + } + } + + // How many bits to take + let bit_size: usize = (32 - bit_start).min(digit); + + module.glwe_blind_selection( + &mut res, + cts_map, + &k_enc_prep, + bit_start, + bit_size, + scratch.borrow(), + ); + + res.decrypt(module, &mut pt, sk_glwe_prep, scratch.borrow()); + + let idx = ((k >> bit_start) & mask) as usize; + if idx.is_multiple_of(3) { + assert_eq!(0, pt.decode_coeff_i64(TorusPrecision(base2k.as_u32()), 0)); + } else { + assert_eq!( + data[idx], + pt.decode_coeff_i64(TorusPrecision(base2k.as_u32()), 0) + ); + } + + bit_start += digit; + + if bit_start >= 32 { + break; + } + } +} diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs index 53a02ae..6a51a97 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs @@ -3,6 +3,7 @@ mod and; mod fheuint; mod ggsw_blind_rotations; mod glwe_blind_rotation; +mod glwe_blind_selection; mod or; mod prepare; mod sll; @@ -18,6 +19,7 @@ pub use and::*; pub use fheuint::*; pub use ggsw_blind_rotations::*; pub use glwe_blind_rotation::*; +pub use glwe_blind_selection::*; pub use or::*; use poulpy_hal::{ api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow},