Add support for blind retrieval

This commit is contained in:
Pro7ech
2025-11-15 22:41:11 +01:00
parent 28102b684f
commit b062c722a0
12 changed files with 882 additions and 25 deletions

View File

@@ -0,0 +1,221 @@
use itertools::Itertools;
use poulpy_core::{
GLWECopy, ScratchTakeCore,
layouts::{GGSWInfos, GGSWPrepared, GLWE, GLWEInfos, GLWEToMut, GLWEToRef},
};
use poulpy_hal::layouts::{Backend, Module, Scratch};
use crate::tfhe::bdd_arithmetic::{Cmux, Cswap, GetGGSWBit};
pub struct GLWEBlindRetriever {
accumulators: Vec<Accumulator>,
counter: usize,
}
impl GLWEBlindRetriever {
pub fn alloc<A>(infos: &A, size: usize) -> Self
where
A: GLWEInfos,
{
let log2_max_address: usize = (u32::BITS - (size as u32 - 1).leading_zeros()) as usize;
Self {
accumulators: (0..log2_max_address)
.map(|_| Accumulator::alloc(infos))
.collect_vec(),
counter: 0,
}
}
pub fn retrieve<M, R, A, S, BE: Backend>(
&mut self,
module: &M,
res: &mut R,
data: &[A],
selector: &S,
scratch: &mut Scratch<BE>,
) where
M: GLWECopy + Cmux<BE>,
R: GLWEToMut,
A: GLWEToRef,
S: GetGGSWBit<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
self.reset();
for ct in data {
self.add(module, ct, selector, scratch);
}
self.flush(module, res, selector, scratch);
}
pub fn add<A, S, M, BE: Backend>(&mut self, module: &M, a: &A, selector: &S, scratch: &mut Scratch<BE>)
where
A: GLWEToRef,
S: GetGGSWBit<BE>,
M: GLWECopy + Cmux<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
assert!(
(self.counter as u32) < 1 << self.accumulators.len(),
"Accumulating limit of {} reached",
1 << self.accumulators.len()
);
add_core(module, a, &mut self.accumulators, 0, selector, scratch);
self.counter += 1;
}
pub fn flush<R, M, S, BE: Backend>(&mut self, module: &M, res: &mut R, selector: &S, scratch: &mut Scratch<BE>)
where
R: GLWEToMut,
S: GetGGSWBit<BE>,
M: GLWECopy + Cmux<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
for i in 0..self.accumulators.len() - 1 {
let (acc_prev, acc_next) = self.accumulators.split_at_mut(i + 1);
if acc_prev[i].num != 0 {
add_core(
module,
&acc_prev[i].data,
acc_next,
i + 1,
selector,
scratch,
);
acc_prev[0].num = 0
}
}
module.glwe_copy(res, &self.accumulators.last().unwrap().data);
self.reset()
}
fn reset(&mut self) {
for acc in self.accumulators.iter_mut() {
acc.num = 0;
}
}
}
struct Accumulator {
data: GLWE<Vec<u8>>,
num: usize, // Number of accumulated values
}
impl Accumulator {
pub fn alloc<A>(infos: &A) -> Self
where
A: GLWEInfos,
{
Self {
data: GLWE::alloc_from_infos(infos),
num: 0,
}
}
}
fn add_core<A, S, M, BE: Backend>(
module: &M,
a: &A,
accumulators: &mut [Accumulator],
i: usize,
selector: &S,
scratch: &mut Scratch<BE>,
) where
A: GLWEToRef,
S: GetGGSWBit<BE>,
M: GLWECopy + Cmux<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
// Isolate the first accumulator
let (acc_prev, acc_next) = accumulators.split_at_mut(1);
match acc_prev[0].num {
0 => {
module.glwe_copy(&mut acc_prev[0].data, a);
acc_prev[0].num = 1;
}
1 => {
module.cmux_inplace_neg(&mut acc_prev[0].data, a, &selector.get_bit(i), scratch);
if !acc_next.is_empty() {
add_core(
module,
&acc_prev[0].data,
acc_next,
i + 1,
selector,
scratch,
);
}
acc_prev[0].num = 0
}
_ => {
panic!("something went wrong")
}
}
}
impl<BE: Backend> GLWEBlindRetrieval<BE> for Module<BE> where Self: GLWECopy + Cmux<BE> + Cswap<BE> {}
pub trait GLWEBlindRetrieval<BE: Backend>
where
Self: GLWECopy + Cmux<BE> + Cswap<BE>,
{
fn glwe_blind_retrieval_tmp_bytes<R, K>(&self, res_infos: &R, k_infos: &K) -> usize
where
R: GLWEInfos,
K: GGSWInfos,
{
self.cswap_tmp_bytes(res_infos, res_infos, k_infos)
}
fn glwe_blind_retrieval_statefull<R, K>(
&self,
res: &mut Vec<R>,
bits: &K,
bit_rsh: usize,
bit_mask: usize,
scratch: &mut Scratch<BE>,
) where
R: GLWEToMut + GLWEInfos,
K: GetGGSWBit<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
for i in 0..bit_mask {
let t: usize = 1 << (bit_mask - i - 1);
let bit: &GGSWPrepared<&[u8], BE> = &bits.get_bit(bit_rsh + bit_mask - i - 1); // MSB -> LSB traversal
for j in 0..t {
if j + t < res.len() {
let (lo, hi) = res.split_at_mut(j + t);
self.cswap(&mut lo[j], &mut hi[0], bit, scratch);
}
}
}
}
fn glwe_blind_retrieval_statefull_rev<R, K>(
&self,
res: &mut Vec<R>,
bits: &K,
bit_rsh: usize,
bit_mask: usize,
scratch: &mut Scratch<BE>,
) where
R: GLWEToMut + GLWEInfos,
K: GetGGSWBit<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
for i in (0..bit_mask).rev() {
let t: usize = 1 << (bit_mask - i - 1);
let bit: &GGSWPrepared<&[u8], BE> = &bits.get_bit(bit_rsh + bit_mask - i - 1); // MSB -> LSB traversal
for j in 0..t {
if j < res.len() && j + t < res.len() {
let (lo, hi) = res.split_at_mut(j + t);
self.cswap(&mut lo[j], &mut hi[0], bit, scratch);
}
}
}
}
}

View File

@@ -2,7 +2,7 @@ use std::collections::HashMap;
use poulpy_core::{
GLWECopy, GLWEDecrypt, ScratchTakeCore,
layouts::{GGSWInfos, GGSWPrepared, GLWE, GLWEInfos, GLWEToMut, GLWEToRef},
layouts::{GGSWInfos, GGSWPrepared, GLWE, GLWEInfos, GLWEToMut},
};
use poulpy_hal::layouts::{Backend, Module, Scratch, ZnxZero};
@@ -33,7 +33,7 @@ where
scratch: &mut Scratch<BE>,
) where
R: GLWEToMut,
A: GLWEToMut + GLWEToRef,
A: GLWEToMut,
K: GetGGSWBit<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{

View File

@@ -49,6 +49,18 @@ impl<'a, T: UnsignedInteger> FheUint<&'a mut [u8], T> {
}
}
impl<'a, T: UnsignedInteger> FheUint<&'a [u8], T> {
pub fn from_glwe_to_ref<G>(glwe: &'a G) -> Self
where
G: GLWEToRef,
{
FheUint {
bits: glwe.to_ref(),
_phantom: PhantomData,
}
}
}
impl<D: DataRef, T: UnsignedInteger> LWEInfos for FheUint<D, T> {
fn base2k(&self) -> poulpy_core::layouts::Base2K {
self.bits.base2k()
@@ -180,7 +192,7 @@ impl<D: DataMut, T: UnsignedInteger> FheUint<D, T> {
/// Packs Vec<GLWE(bit[i])> into [FheUint].
pub fn pack<G, M, K, H, BE: Backend>(&mut self, module: &M, mut bits: Vec<G>, keys: &H, scratch: &mut Scratch<BE>)
where
G: GLWEToMut + GLWEToRef + GLWEInfos,
G: GLWEToMut + GLWEInfos,
M: ModuleLogN + GLWEPacking<BE> + GLWECopy,
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
H: GLWEAutomorphismKeyHelper<K, BE>,

View File

@@ -4,14 +4,16 @@ use std::thread;
use itertools::Itertools;
use poulpy_core::{
GLWECopy, GLWEExternalProductInternal, GLWENormalize, GLWESub, ScratchTakeCore,
layouts::{GGSWInfos, GGSWPrepared, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef},
layouts::{
GGSWInfos, GGSWPrepared, GLWE, GLWEInfos, GLWELayout, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef,
},
};
use poulpy_hal::{
api::{
ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
VecZnxDftBytesOf,
ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmall, VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize,
VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallA, VecZnxDftBytesOf,
},
layouts::{Backend, DataMut, Module, Scratch, VecZnxBig, ZnxZero},
layouts::{Backend, DataMut, Module, Scratch, VecZnxBig, ZnxInfos, ZnxZero},
};
use crate::tfhe::bdd_arithmetic::GetGGSWBit;
@@ -260,6 +262,204 @@ pub enum Node {
None,
}
impl<BE: Backend> Cswap<BE> for Module<BE> where
Self: Sized
+ GLWEExternalProductInternal<BE>
+ GLWESub
+ VecZnxBigAddSmallInplace<BE>
+ GLWENormalize<BE>
+ VecZnxDftBytesOf
+ VecZnxBigNormalize<BE>
+ VecZnxBigNormalizeTmpBytes
+ GLWENormalize<BE>
+ VecZnxBigAddSmall<BE>
+ VecZnxBigSubSmallA<BE>
+ VecZnxBigBytesOf
{
}
pub trait Cswap<BE: Backend>
where
Self: Sized
+ GLWEExternalProductInternal<BE>
+ GLWESub
+ GLWECopy
+ VecZnxBigAddSmallInplace<BE>
+ GLWENormalize<BE>
+ VecZnxDftBytesOf
+ VecZnxBigNormalize<BE>
+ VecZnxBigNormalizeTmpBytes
+ GLWENormalize<BE>
+ VecZnxBigAddSmall<BE>
+ VecZnxBigSubSmallA<BE>
+ VecZnxBigBytesOf,
{
fn cswap_tmp_bytes<R, A, S>(&self, res_a_infos: &R, res_b_infos: &A, s_infos: &S) -> usize
where
R: GLWEInfos,
A: GLWEInfos,
S: GGSWInfos,
{
let res_dft: usize = self.bytes_of_vec_znx_dft((s_infos.rank() + 1).into(), s_infos.size());
let mut tot = res_dft
+ (self.glwe_external_product_internal_tmp_bytes(res_a_infos, res_b_infos, s_infos)
+ GLWE::bytes_of_from_infos(&GLWELayout {
n: s_infos.n(),
base2k: s_infos.base2k(),
k: res_a_infos.k().max(res_b_infos.k()),
rank: s_infos.rank(),
}))
.max(self.vec_znx_big_normalize_tmp_bytes());
if res_a_infos.base2k() != s_infos.base2k() {
tot += GLWE::bytes_of_from_infos(&GLWELayout {
n: res_a_infos.n(),
base2k: s_infos.base2k(),
k: res_a_infos.k(),
rank: res_a_infos.rank(),
});
tot += GLWE::bytes_of_from_infos(&GLWELayout {
n: res_b_infos.n(),
base2k: s_infos.base2k(),
k: res_b_infos.k(),
rank: res_b_infos.rank(),
});
}
tot += self.bytes_of_vec_znx_big(1, s_infos.size());
tot
}
fn cswap<A, B, S>(&self, res_a: &mut A, res_b: &mut B, s: &S, scratch: &mut Scratch<BE>)
where
A: GLWEToMut,
B: GLWEToMut,
S: GGSWPreparedToRef<BE> + GGSWInfos,
Scratch<BE>: ScratchTakeCore<BE>,
{
let res_a: &mut GLWE<&mut [u8]> = &mut res_a.to_mut();
let res_b: &mut GLWE<&mut [u8]> = &mut res_b.to_mut();
let s: &GGSWPrepared<&[u8], BE> = &s.to_ref();
assert_eq!(res_a.base2k(), res_b.base2k());
let res_base2k: usize = res_a.base2k().as_usize();
let s_base2k: usize = s.base2k().as_usize();
if res_base2k == s_base2k {
let res_big: VecZnxBig<&mut [u8], BE>;
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (s.rank() + 1).into(), s.size()); // Todo optimise
{
// Temporary value storing a - b
let tmp_c_infos: GLWELayout = GLWELayout {
n: s.n(),
base2k: s.base2k(),
k: res_a.k().max(res_b.k()),
rank: s.rank(),
};
let (mut tmp_c, scratch_2) = scratch_1.take_glwe(&tmp_c_infos);
self.glwe_sub(&mut tmp_c, res_b, res_a);
res_big = self.glwe_external_product_internal(res_dft, &tmp_c, s, scratch_2);
}
// Single column res_big to store temporary value before normalization
let (mut res_big_tmp, scratch_2) = scratch_1.take_vec_znx_big::<_, BE>(self, 1, res_big.size());
// res_a = (b-a) * bit + a
for j in 0..(res_a.rank() + 1).into() {
self.vec_znx_big_add_small(&mut res_big_tmp, 0, &res_big, j, res_a.data(), j);
self.vec_znx_big_normalize(
res_base2k,
res_a.data_mut(),
j,
s_base2k,
&res_big_tmp,
0,
scratch_2,
);
}
// res_b = a - (a - b) * bit = (b - a) * bit + a
for j in 0..(res_b.rank() + 1).into() {
self.vec_znx_big_sub_small_a(&mut res_big_tmp, 0, res_b.data(), j, &res_big, j);
self.vec_znx_big_normalize(
res_base2k,
res_b.data_mut(),
j,
s_base2k,
&res_big_tmp,
0,
scratch_2,
);
}
} else {
let (mut tmp_a, scratch_1) = scratch.take_glwe(&GLWELayout {
n: res_a.n(),
base2k: s.base2k(),
k: res_a.k(),
rank: res_a.rank(),
});
let (mut tmp_b, scratch_2) = scratch_1.take_glwe(&GLWELayout {
n: res_b.n(),
base2k: s.base2k(),
k: res_b.k(),
rank: res_b.rank(),
});
self.glwe_normalize(&mut tmp_a, res_a, scratch_2);
self.glwe_normalize(&mut tmp_b, res_b, scratch_2);
let res_big: VecZnxBig<&mut [u8], BE>;
let (res_dft, scratch_3) = scratch_2.take_vec_znx_dft(self, (s.rank() + 1).into(), s.size()); // Todo optimise
{
// Temporary value storing a - b
let tmp_c_infos: GLWELayout = GLWELayout {
n: s.n(),
base2k: s.base2k(),
k: res_a.k().max(res_b.k()),
rank: s.rank(),
};
let (mut tmp_c, scratch_4) = scratch_3.take_glwe(&tmp_c_infos);
self.glwe_sub(&mut tmp_c, res_b, res_a);
res_big = self.glwe_external_product_internal(res_dft, &tmp_c, s, scratch_4);
}
// Single column res_big to store temporary value before normalization
let (mut res_big_tmp, scratch_4) = scratch_3.take_vec_znx_big::<_, BE>(self, 1, res_big.size());
// res_a = (b-a) * bit + a
for j in 0..(res_a.rank() + 1).into() {
self.vec_znx_big_add_small(&mut res_big_tmp, 0, &res_big, j, tmp_a.data(), j);
self.vec_znx_big_normalize(
res_base2k,
res_a.data_mut(),
j,
s_base2k,
&res_big_tmp,
0,
scratch_4,
);
}
// res_b = a - (a - b) * bit = (b - a) * bit + a
for j in 0..(res_b.rank() + 1).into() {
self.vec_znx_big_sub_small_a(&mut res_big_tmp, 0, tmp_b.data(), j, &res_big, j);
self.vec_znx_big_normalize(
res_base2k,
res_b.data_mut(),
j,
s_base2k,
&res_big_tmp,
0,
scratch_4,
);
}
}
}
}
pub trait Cmux<BE: Backend>
where
Self: Sized
@@ -284,6 +484,7 @@ where
.max(self.vec_znx_big_normalize_tmp_bytes())
}
// res = (t - f) * s + f
fn cmux<R, T, F, S>(&self, res: &mut R, t: &T, f: &F, s: &S, scratch: &mut Scratch<BE>)
where
R: GLWEToMut,
@@ -316,6 +517,46 @@ where
}
}
// res = (a - res) * s + res
fn cmux_inplace_neg<R, A, S>(&self, res: &mut R, a: &A, s: &S, scratch: &mut Scratch<BE>)
where
R: GLWEToMut,
A: GLWEToRef,
S: GGSWPreparedToRef<BE> + GGSWInfos,
Scratch<BE>: ScratchTakeCore<BE>,
{
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
let s: &GGSWPrepared<&[u8], BE> = &s.to_ref();
let a: &GLWE<&[u8]> = &a.to_ref();
assert_eq!(res.base2k(), a.base2k());
let res_base2k: usize = res.base2k().into();
let ggsw_base2k: usize = s.base2k().into();
let (mut tmp, scratch_1) = scratch.take_glwe(&GLWELayout {
n: s.n(),
base2k: res.base2k(),
k: res.k().max(a.k()),
rank: res.rank(),
});
self.glwe_sub(&mut tmp, a, res);
let (res_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, (res.rank() + 1).into(), s.size()); // Todo optimise
let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_external_product_internal(res_dft, &tmp, s, scratch_2);
for j in 0..(res.rank() + 1).into() {
self.vec_znx_big_add_small_inplace(&mut res_big, j, res.data(), j);
self.vec_znx_big_normalize(
res_base2k,
res.data_mut(),
j,
ggsw_base2k,
&res_big,
j,
scratch_2,
);
}
}
// res = (res - a) * s + a
fn cmux_inplace<R, A, S>(&self, res: &mut R, a: &A, s: &S, scratch: &mut Scratch<BE>)
where
R: GLWEToMut,

View File

@@ -1,4 +1,5 @@
mod bdd_2w_to_1w;
mod blind_retrieval;
mod blind_rotation;
mod blind_selection;
mod ciphertexts;
@@ -7,6 +8,7 @@ mod eval;
mod key;
pub use bdd_2w_to_1w::*;
pub use blind_retrieval::*;
pub use blind_rotation::*;
pub use blind_selection::*;
pub use ciphertexts::*;

View File

@@ -6,7 +6,8 @@ use crate::tfhe::{
bdd_arithmetic::tests::test_suite::{
TestContext, test_bdd_add, test_bdd_and, test_bdd_or, test_bdd_prepare, test_bdd_sll, test_bdd_slt, test_bdd_sltu,
test_bdd_sra, test_bdd_srl, test_bdd_sub, test_bdd_xor, test_fhe_uint_get_bit_glwe, test_fhe_uint_sext,
test_fhe_uint_splice_u8, test_fhe_uint_splice_u16, test_glwe_blind_selection, test_glwe_to_glwe_blind_rotation,
test_fhe_uint_splice_u8, test_fhe_uint_splice_u16, test_fhe_uint_swap, test_glwe_blind_retrieval_statefull,
test_glwe_blind_retriever, test_glwe_blind_selection, test_glwe_to_glwe_blind_rotation,
test_scalar_to_ggsw_blind_rotation,
},
blind_rotation::CGGI,
@@ -15,6 +16,21 @@ use crate::tfhe::{
static TEST_CONTEXT_CGGI_FFT64_REF: LazyLock<TestContext<CGGI, FFT64Ref>> =
LazyLock::new(|| TestContext::<CGGI, FFT64Ref>::new());
#[test]
fn test_glwe_blind_retriever_fft64_ref() {
test_glwe_blind_retriever(&TEST_CONTEXT_CGGI_FFT64_REF);
}
#[test]
fn test_glwe_blind_retrieval_statefull_fft64_ref() {
test_glwe_blind_retrieval_statefull(&TEST_CONTEXT_CGGI_FFT64_REF);
}
#[test]
fn test_fhe_uint_swap_fft64_ref() {
test_fhe_uint_swap(&TEST_CONTEXT_CGGI_FFT64_REF);
}
#[test]
fn test_fhe_uint_get_bit_glwe_fft64_ref() {
test_fhe_uint_get_bit_glwe(&TEST_CONTEXT_CGGI_FFT64_REF);

View File

@@ -12,6 +12,7 @@ mod sltu;
mod sra;
mod srl;
mod sub;
mod swap;
mod xor;
pub use add::*;
@@ -33,6 +34,7 @@ pub use sltu::*;
pub use sra::*;
pub use srl::*;
pub use sub::*;
pub use swap::*;
pub use xor::*;
use poulpy_core::{

View File

@@ -0,0 +1,203 @@
use itertools::Itertools;
use poulpy_core::{
GGSWEncryptSk, GLWEDecrypt, GLWEEncryptSk,
layouts::{GGSW, GGSWPrepared, GGSWPreparedFactory, GLWELayout, GLWESecretPrepared},
};
use poulpy_hal::{
api::{ScratchOwnedAlloc, ScratchOwnedBorrow},
layouts::{Backend, Module, ScalarZnx, Scratch, ScratchOwned, ZnxViewMut},
source::Source,
};
use rand::RngCore;
use crate::tfhe::{
bdd_arithmetic::{
Cswap, FheUint, FheUintPrepared, GLWEBlindRetrieval, GLWEBlindRetriever, ScratchTakeBDD,
tests::test_suite::{TEST_GGSW_INFOS, TEST_GLWE_INFOS, TestContext},
},
blind_rotation::BlindRotationAlgo,
};
pub fn test_fhe_uint_swap<BRA: BlindRotationAlgo, BE: Backend>(test_context: &TestContext<BRA, BE>)
where
Module<BE>: GLWEEncryptSk<BE> + GLWEDecrypt<BE> + Cswap<BE> + GGSWEncryptSk<BE> + GGSWPreparedFactory<BE>,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchTakeBDD<u32, BE>,
{
let glwe_infos: GLWELayout = TEST_GLWE_INFOS;
let ggsw_infos: poulpy_core::layouts::GGSWLayout = TEST_GGSW_INFOS;
let module: &Module<BE> = &test_context.module;
let sk: &GLWESecretPrepared<Vec<u8>, BE> = &test_context.sk_glwe;
let mut source_xa: Source = Source::new([2u8; 32]);
let mut source_xe: Source = Source::new([3u8; 32]);
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(1 << 22);
let mut s: GGSW<Vec<u8>> = GGSW::alloc_from_infos(&ggsw_infos);
let mut s_prepared: GGSWPrepared<Vec<u8>, BE> = GGSWPrepared::alloc_from_infos(module, &ggsw_infos);
let a: u32 = source_xa.next_u32();
let b: u32 = source_xa.next_u32();
for bit in [0, 1] {
let mut a_enc: FheUint<Vec<u8>, u32> = FheUint::<Vec<u8>, u32>::alloc_from_infos(&glwe_infos);
let mut b_enc: FheUint<Vec<u8>, u32> = FheUint::<Vec<u8>, u32>::alloc_from_infos(&glwe_infos);
a_enc.encrypt_sk(
module,
a,
sk,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
);
b_enc.encrypt_sk(
module,
b,
sk,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
);
let mut pt: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(module.n(), 1);
pt.raw_mut()[0] = bit;
s.encrypt_sk(
module,
&pt,
sk,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
);
s_prepared.prepare(module, &s, scratch.borrow());
module.cswap(&mut a_enc, &mut b_enc, &s_prepared, scratch.borrow());
let (a_want, b_want) = if bit == 0 { (a, b) } else { (b, a) };
assert_eq!(a_want, a_enc.decrypt(module, sk, scratch.borrow()));
assert_eq!(b_want, b_enc.decrypt(module, sk, scratch.borrow()));
}
}
pub fn test_glwe_blind_retrieval_statefull<BRA: BlindRotationAlgo, BE: Backend>(test_context: &TestContext<BRA, BE>)
where
Module<BE>: GLWEEncryptSk<BE> + GLWEDecrypt<BE> + GLWEBlindRetrieval<BE> + GGSWEncryptSk<BE> + GGSWPreparedFactory<BE>,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchTakeBDD<u32, BE>,
{
let glwe_infos: GLWELayout = TEST_GLWE_INFOS;
let ggsw_infos: poulpy_core::layouts::GGSWLayout = TEST_GGSW_INFOS;
let module: &Module<BE> = &test_context.module;
let sk: &GLWESecretPrepared<Vec<u8>, BE> = &test_context.sk_glwe;
let mut source_xa: Source = Source::new([2u8; 32]);
let mut source_xe: Source = Source::new([3u8; 32]);
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(1 << 22);
let data: Vec<u32> = (0..25).map(|i| i as u32).collect_vec();
let mut data_enc: Vec<FheUint<Vec<u8>, u32>> = (0..data.len())
.map(|i| {
let mut ct: FheUint<Vec<u8>, u32> = FheUint::<Vec<u8>, u32>::alloc_from_infos(&glwe_infos);
ct.encrypt_sk(
module,
data[i],
sk,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
);
ct
})
.collect_vec();
for idx in 0..data.len() as u32 {
let mut idx_enc = FheUintPrepared::alloc_from_infos(module, &ggsw_infos);
idx_enc.encrypt_sk(
module,
idx,
sk,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
);
module.glwe_blind_retrieval_statefull(&mut data_enc, &idx_enc, 0, 5, scratch.borrow());
assert_eq!(
data[idx as usize],
data_enc[0].decrypt(module, sk, scratch.borrow())
);
module.glwe_blind_retrieval_statefull_rev(&mut data_enc, &idx_enc, 0, 5, scratch.borrow());
for i in 0..data.len() {
assert_eq!(data[i], data_enc[i].decrypt(module, sk, scratch.borrow()))
}
}
}
pub fn test_glwe_blind_retriever<BRA: BlindRotationAlgo, BE: Backend>(test_context: &TestContext<BRA, BE>)
where
Module<BE>: GLWEEncryptSk<BE> + GLWEDecrypt<BE> + GLWEBlindRetrieval<BE> + GGSWEncryptSk<BE> + GGSWPreparedFactory<BE>,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchTakeBDD<u32, BE>,
{
let glwe_infos: GLWELayout = TEST_GLWE_INFOS;
let ggsw_infos: poulpy_core::layouts::GGSWLayout = TEST_GGSW_INFOS;
let module: &Module<BE> = &test_context.module;
let sk: &GLWESecretPrepared<Vec<u8>, BE> = &test_context.sk_glwe;
let mut source_xa: Source = Source::new([2u8; 32]);
let mut source_xe: Source = Source::new([3u8; 32]);
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(1 << 22);
let data: Vec<u32> = (0..25).map(|i| i as u32).collect_vec();
let data_enc: Vec<FheUint<Vec<u8>, u32>> = (0..data.len())
.map(|i| {
let mut ct: FheUint<Vec<u8>, u32> = FheUint::<Vec<u8>, u32>::alloc_from_infos(&glwe_infos);
ct.encrypt_sk(
module,
data[i],
sk,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
);
ct
})
.collect_vec();
for idx in 0..data.len() as u32 {
let mut idx_enc: FheUintPrepared<Vec<u8>, u32, BE> = FheUintPrepared::alloc_from_infos(module, &ggsw_infos);
idx_enc.encrypt_sk(
module,
idx,
sk,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
);
let mut retriever: GLWEBlindRetriever = GLWEBlindRetriever::alloc(&glwe_infos, 25);
let mut res: FheUint<Vec<u8>, u32> = FheUint::alloc_from_infos(&glwe_infos);
retriever.retrieve(module, &mut res, &data_enc, &idx_enc, scratch.borrow());
println!("{}", res.decrypt(module, sk, scratch.borrow()));
assert_eq!(
data[idx as usize],
res.decrypt(module, sk, scratch.borrow())
);
}
}