mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
Add glwe blind selection
This commit is contained in:
90
poulpy-schemes/src/tfhe/bdd_arithmetic/blind_selection.rs
Normal file
90
poulpy-schemes/src/tfhe/bdd_arithmetic/blind_selection.rs
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use poulpy_core::{
|
||||||
|
GLWECopy, GLWEDecrypt, ScratchTakeCore,
|
||||||
|
layouts::{GGSWInfos, GGSWPrepared, GLWE, GLWEInfos, GLWEToMut, GLWEToRef},
|
||||||
|
};
|
||||||
|
use poulpy_hal::layouts::{Backend, Module, Scratch, ZnxZero};
|
||||||
|
|
||||||
|
use crate::tfhe::bdd_arithmetic::{Cmux, GetGGSWBit, UnsignedInteger};
|
||||||
|
|
||||||
|
impl<T: UnsignedInteger, BE: Backend> GLWEBlinSelection<T, BE> for Module<BE> where Self: GLWECopy + Cmux<BE> + GLWEDecrypt<BE> {}
|
||||||
|
|
||||||
|
pub trait GLWEBlinSelection<T: UnsignedInteger, BE: Backend>
|
||||||
|
where
|
||||||
|
Self: GLWECopy + Cmux<BE> + GLWEDecrypt<BE>,
|
||||||
|
{
|
||||||
|
fn glwe_blind_selection_tmp_bytes<R, K>(&self, res_infos: &R, k_infos: &K) -> usize
|
||||||
|
where
|
||||||
|
R: GLWEInfos,
|
||||||
|
K: GGSWInfos,
|
||||||
|
{
|
||||||
|
self.cmux_tmp_bytes(res_infos, res_infos, k_infos) + GLWE::bytes_of_from_infos(res_infos)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
fn glwe_blind_selection<R, A, K>(
|
||||||
|
&self,
|
||||||
|
res: &mut R,
|
||||||
|
mut a: HashMap<usize, &mut A>,
|
||||||
|
fhe_uint: &K,
|
||||||
|
bit_rsh: usize,
|
||||||
|
bit_mask: usize,
|
||||||
|
scratch: &mut Scratch<BE>,
|
||||||
|
) where
|
||||||
|
R: GLWEToMut,
|
||||||
|
A: GLWEToMut + GLWEToRef,
|
||||||
|
K: GetGGSWBit<BE>,
|
||||||
|
Scratch<BE>: ScratchTakeCore<BE>,
|
||||||
|
{
|
||||||
|
assert!(bit_rsh + bit_mask <= T::BITS as usize);
|
||||||
|
|
||||||
|
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
|
||||||
|
|
||||||
|
for i in 0..bit_mask {
|
||||||
|
let t: usize = 1 << (bit_mask - i - 1);
|
||||||
|
|
||||||
|
let bit: &GGSWPrepared<&[u8], BE> = &fhe_uint.get_bit(bit_rsh + bit_mask - i - 1); // MSB -> LSB traversal
|
||||||
|
|
||||||
|
for j in 0..t {
|
||||||
|
let hi: Option<&mut A> = a.remove(&j);
|
||||||
|
let lo: Option<&mut A> = a.remove(&(j + t));
|
||||||
|
|
||||||
|
match (lo, hi) {
|
||||||
|
(Some(lo), Some(hi)) => {
|
||||||
|
self.cmux_inplace(lo, hi, bit, scratch);
|
||||||
|
a.insert(j, lo);
|
||||||
|
}
|
||||||
|
|
||||||
|
(Some(lo), None) => {
|
||||||
|
let (mut zero, scratch_1) = scratch.take_glwe(res);
|
||||||
|
zero.data_mut().zero();
|
||||||
|
self.cmux_inplace(lo, &zero, bit, scratch_1);
|
||||||
|
a.insert(j, lo);
|
||||||
|
}
|
||||||
|
|
||||||
|
(None, Some(hi)) => {
|
||||||
|
let (mut zero, scratch_1) = scratch.take_glwe(res);
|
||||||
|
zero.data_mut().zero();
|
||||||
|
self.cmux_inplace(&mut zero, hi, bit, scratch_1);
|
||||||
|
self.glwe_copy(hi, &zero);
|
||||||
|
a.insert(j, hi);
|
||||||
|
}
|
||||||
|
|
||||||
|
(None, None) => {
|
||||||
|
// No low or high branch — nothing to insert
|
||||||
|
// leave empty; future iterations will combine actual ciphertexts
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let out: Option<&mut A> = a.remove(&0);
|
||||||
|
|
||||||
|
if let Some(out) = out {
|
||||||
|
self.glwe_copy(res, out);
|
||||||
|
} else {
|
||||||
|
res.data_mut().zero();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
mod bdd_2w_to_1w;
|
mod bdd_2w_to_1w;
|
||||||
mod blind_rotation;
|
mod blind_rotation;
|
||||||
|
mod blind_selection;
|
||||||
mod ciphertexts;
|
mod ciphertexts;
|
||||||
mod circuits;
|
mod circuits;
|
||||||
mod eval;
|
mod eval;
|
||||||
@@ -7,6 +8,7 @@ mod key;
|
|||||||
|
|
||||||
pub use bdd_2w_to_1w::*;
|
pub use bdd_2w_to_1w::*;
|
||||||
pub use blind_rotation::*;
|
pub use blind_rotation::*;
|
||||||
|
pub use blind_selection::*;
|
||||||
pub use ciphertexts::*;
|
pub use ciphertexts::*;
|
||||||
pub(crate) use circuits::*;
|
pub(crate) use circuits::*;
|
||||||
pub use eval::*;
|
pub use eval::*;
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ use crate::tfhe::{
|
|||||||
bdd_arithmetic::tests::test_suite::{
|
bdd_arithmetic::tests::test_suite::{
|
||||||
TestContext, test_bdd_add, test_bdd_and, test_bdd_or, test_bdd_prepare, test_bdd_sll, test_bdd_slt, test_bdd_sltu,
|
TestContext, test_bdd_add, test_bdd_and, test_bdd_or, test_bdd_prepare, test_bdd_sll, test_bdd_slt, test_bdd_sltu,
|
||||||
test_bdd_sra, test_bdd_srl, test_bdd_sub, test_bdd_xor, test_fhe_uint_splice_u8, test_fhe_uint_splice_u16,
|
test_bdd_sra, test_bdd_srl, test_bdd_sub, test_bdd_xor, test_fhe_uint_splice_u8, test_fhe_uint_splice_u16,
|
||||||
test_glwe_to_glwe_blind_rotation, test_scalar_to_ggsw_blind_rotation,
|
test_glwe_blind_selection, test_glwe_to_glwe_blind_rotation, test_scalar_to_ggsw_blind_rotation,
|
||||||
},
|
},
|
||||||
blind_rotation::CGGI,
|
blind_rotation::CGGI,
|
||||||
};
|
};
|
||||||
@@ -14,6 +14,11 @@ use crate::tfhe::{
|
|||||||
static TEST_CONTEXT_CGGI_FFT64_REF: LazyLock<TestContext<CGGI, FFT64Ref>> =
|
static TEST_CONTEXT_CGGI_FFT64_REF: LazyLock<TestContext<CGGI, FFT64Ref>> =
|
||||||
LazyLock::new(|| TestContext::<CGGI, FFT64Ref>::new());
|
LazyLock::new(|| TestContext::<CGGI, FFT64Ref>::new());
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_glwe_blind_selection_fft64_ref() {
|
||||||
|
test_glwe_blind_selection(&TEST_CONTEXT_CGGI_FFT64_REF)
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_fhe_uint_splice_u8_fft64_ref() {
|
fn test_fhe_uint_splice_u8_fft64_ref() {
|
||||||
test_fhe_uint_splice_u8(&TEST_CONTEXT_CGGI_FFT64_REF)
|
test_fhe_uint_splice_u8(&TEST_CONTEXT_CGGI_FFT64_REF)
|
||||||
|
|||||||
@@ -0,0 +1,147 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use poulpy_core::{
|
||||||
|
GGSWEncryptSk, GLWEDecrypt, GLWEEncryptSk, ScratchTakeCore,
|
||||||
|
layouts::{
|
||||||
|
Base2K, Dnum, Dsize, GGSWLayout, GGSWPreparedFactory, GLWE, GLWELayout, GLWEPlaintext, GLWESecretPrepared,
|
||||||
|
GLWESecretPreparedFactory, Rank, TorusPrecision,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
use poulpy_hal::{
|
||||||
|
api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow},
|
||||||
|
layouts::{Backend, Module, Scratch, ScratchOwned},
|
||||||
|
source::Source,
|
||||||
|
};
|
||||||
|
use rand::RngCore;
|
||||||
|
|
||||||
|
use crate::tfhe::{
|
||||||
|
bdd_arithmetic::{
|
||||||
|
FheUintPrepared, GLWEBlinSelection,
|
||||||
|
tests::test_suite::{TEST_BASE2K, TEST_RANK, TestContext},
|
||||||
|
},
|
||||||
|
blind_rotation::BlindRotationAlgo,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub fn test_glwe_blind_selection<BRA: BlindRotationAlgo, BE: Backend>(test_context: &TestContext<BRA, BE>)
|
||||||
|
where
|
||||||
|
Module<BE>: ModuleNew<BE>
|
||||||
|
+ GLWESecretPreparedFactory<BE>
|
||||||
|
+ GGSWPreparedFactory<BE>
|
||||||
|
+ GGSWEncryptSk<BE>
|
||||||
|
+ GLWEBlinSelection<u32, BE>
|
||||||
|
+ GLWEDecrypt<BE>
|
||||||
|
+ GLWEEncryptSk<BE>,
|
||||||
|
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
|
||||||
|
Scratch<BE>: ScratchTakeCore<BE>,
|
||||||
|
{
|
||||||
|
let module: &Module<BE> = &test_context.module;
|
||||||
|
let sk_glwe_prep: &GLWESecretPrepared<Vec<u8>, BE> = &test_context.sk_glwe;
|
||||||
|
|
||||||
|
let base2k: Base2K = TEST_BASE2K.into();
|
||||||
|
let rank: Rank = TEST_RANK.into();
|
||||||
|
let k_glwe: TorusPrecision = TorusPrecision(26);
|
||||||
|
let k_ggsw: TorusPrecision = TorusPrecision(39);
|
||||||
|
let dnum: Dnum = Dnum(3);
|
||||||
|
|
||||||
|
let glwe_infos: GLWELayout = GLWELayout {
|
||||||
|
n: module.n().into(),
|
||||||
|
base2k,
|
||||||
|
k: k_glwe,
|
||||||
|
rank,
|
||||||
|
};
|
||||||
|
let ggsw_infos: GGSWLayout = GGSWLayout {
|
||||||
|
n: module.n().into(),
|
||||||
|
base2k,
|
||||||
|
k: k_ggsw,
|
||||||
|
rank,
|
||||||
|
dnum,
|
||||||
|
dsize: Dsize(1),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut source: Source = Source::new([6u8; 32]);
|
||||||
|
let mut source_xa: Source = Source::new([2u8; 32]);
|
||||||
|
let mut source_xe: Source = Source::new([3u8; 32]);
|
||||||
|
|
||||||
|
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(1 << 22);
|
||||||
|
|
||||||
|
let mut res: GLWE<Vec<u8>> = GLWE::alloc_from_infos(&glwe_infos);
|
||||||
|
|
||||||
|
let k: u32 = source.next_u32();
|
||||||
|
|
||||||
|
let mut k_enc_prep: FheUintPrepared<Vec<u8>, u32, BE> =
|
||||||
|
FheUintPrepared::<Vec<u8>, u32, BE>::alloc_from_infos(module, &ggsw_infos);
|
||||||
|
k_enc_prep.encrypt_sk(
|
||||||
|
module,
|
||||||
|
k,
|
||||||
|
sk_glwe_prep,
|
||||||
|
&mut source_xa,
|
||||||
|
&mut source_xe,
|
||||||
|
scratch.borrow(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let digit = 5;
|
||||||
|
let mask: u32 = (1 << digit) - 1;
|
||||||
|
|
||||||
|
// Starting bit
|
||||||
|
let mut bit_start: usize = 0;
|
||||||
|
|
||||||
|
let mut data = vec![0i64; 1 << digit];
|
||||||
|
data.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
|
||||||
|
|
||||||
|
for _ in 0..32_usize.div_ceil(digit) {
|
||||||
|
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc_from_infos(&glwe_infos);
|
||||||
|
|
||||||
|
let mut cts_map: HashMap<usize, &mut GLWE<Vec<u8>>> = HashMap::new();
|
||||||
|
let mut cts: Vec<GLWE<Vec<u8>>> = Vec::new();
|
||||||
|
|
||||||
|
for value in data.iter().take(1 << digit) {
|
||||||
|
pt.encode_coeff_i64(*value, TorusPrecision(base2k.as_u32()), 0);
|
||||||
|
let mut ct = GLWE::alloc_from_infos(&glwe_infos);
|
||||||
|
ct.encrypt_sk(
|
||||||
|
module,
|
||||||
|
&pt,
|
||||||
|
sk_glwe_prep,
|
||||||
|
&mut source_xa,
|
||||||
|
&mut source_xe,
|
||||||
|
scratch.borrow(),
|
||||||
|
);
|
||||||
|
cts.push(ct);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (i, ct) in cts.iter_mut().enumerate() {
|
||||||
|
if i.is_multiple_of(3) {
|
||||||
|
cts_map.insert(i, ct);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// How many bits to take
|
||||||
|
let bit_size: usize = (32 - bit_start).min(digit);
|
||||||
|
|
||||||
|
module.glwe_blind_selection(
|
||||||
|
&mut res,
|
||||||
|
cts_map,
|
||||||
|
&k_enc_prep,
|
||||||
|
bit_start,
|
||||||
|
bit_size,
|
||||||
|
scratch.borrow(),
|
||||||
|
);
|
||||||
|
|
||||||
|
res.decrypt(module, &mut pt, sk_glwe_prep, scratch.borrow());
|
||||||
|
|
||||||
|
let idx = ((k >> bit_start) & mask) as usize;
|
||||||
|
if idx.is_multiple_of(3) {
|
||||||
|
assert_eq!(0, pt.decode_coeff_i64(TorusPrecision(base2k.as_u32()), 0));
|
||||||
|
} else {
|
||||||
|
assert_eq!(
|
||||||
|
data[idx],
|
||||||
|
pt.decode_coeff_i64(TorusPrecision(base2k.as_u32()), 0)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
bit_start += digit;
|
||||||
|
|
||||||
|
if bit_start >= 32 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,6 +3,7 @@ mod and;
|
|||||||
mod fheuint;
|
mod fheuint;
|
||||||
mod ggsw_blind_rotations;
|
mod ggsw_blind_rotations;
|
||||||
mod glwe_blind_rotation;
|
mod glwe_blind_rotation;
|
||||||
|
mod glwe_blind_selection;
|
||||||
mod or;
|
mod or;
|
||||||
mod prepare;
|
mod prepare;
|
||||||
mod sll;
|
mod sll;
|
||||||
@@ -18,6 +19,7 @@ pub use and::*;
|
|||||||
pub use fheuint::*;
|
pub use fheuint::*;
|
||||||
pub use ggsw_blind_rotations::*;
|
pub use ggsw_blind_rotations::*;
|
||||||
pub use glwe_blind_rotation::*;
|
pub use glwe_blind_rotation::*;
|
||||||
|
pub use glwe_blind_selection::*;
|
||||||
pub use or::*;
|
pub use or::*;
|
||||||
use poulpy_hal::{
|
use poulpy_hal::{
|
||||||
api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow},
|
api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow},
|
||||||
|
|||||||
Reference in New Issue
Block a user