mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
Add support for blind retrieval
This commit is contained in:
@@ -1,16 +1,16 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
ModuleN, ScratchTakeBasic, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftApply, VecZnxDftBytesOf,
|
||||
VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
|
||||
VmpApplyDftToDftTmpBytes,
|
||||
ModuleN, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftApply,
|
||||
VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft,
|
||||
VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
ScratchTakeCore,
|
||||
GLWENormalize, ScratchTakeCore,
|
||||
layouts::{
|
||||
GGSWInfos, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos,
|
||||
GGSWInfos, GLWE, GLWEInfos, GLWELayout, GLWEToMut, GLWEToRef, LWEInfos,
|
||||
prepared::{GGSWPrepared, GGSWPreparedToRef},
|
||||
},
|
||||
};
|
||||
@@ -67,11 +67,22 @@ pub trait GLWEExternalProduct<BE: Backend> {
|
||||
A: GLWEToRef,
|
||||
D: GGSWPreparedToRef<BE> + GGSWInfos,
|
||||
Scratch<BE>: ScratchTakeCore<BE>;
|
||||
fn glwe_external_product_add<R, A, D>(&self, res: &mut R, lhs: &A, rhs: &D, scratch: &mut Scratch<BE>)
|
||||
where
|
||||
R: GLWEToMut,
|
||||
A: GLWEToRef,
|
||||
D: GGSWPreparedToRef<BE> + GGSWInfos,
|
||||
Scratch<BE>: ScratchTakeCore<BE>;
|
||||
}
|
||||
|
||||
impl<BE: Backend> GLWEExternalProduct<BE> for Module<BE>
|
||||
where
|
||||
Self: GLWEExternalProductInternal<BE> + VecZnxDftBytesOf + VecZnxBigNormalize<BE> + VecZnxBigNormalizeTmpBytes,
|
||||
Self: GLWEExternalProductInternal<BE>
|
||||
+ VecZnxDftBytesOf
|
||||
+ VecZnxBigNormalize<BE>
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VecZnxBigAddSmallInplace<BE>
|
||||
+ GLWENormalize<BE>,
|
||||
{
|
||||
fn glwe_external_product_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
|
||||
where
|
||||
@@ -163,6 +174,80 @@ where
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn glwe_external_product_add<R, A, D>(&self, res: &mut R, a: &A, key: &D, scratch: &mut Scratch<BE>)
|
||||
where
|
||||
R: GLWEToMut,
|
||||
A: GLWEToRef,
|
||||
D: GGSWPreparedToRef<BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
|
||||
let a: &GLWE<&[u8]> = &a.to_ref();
|
||||
let key: &GGSWPrepared<&[u8], BE> = &key.to_ref();
|
||||
|
||||
assert_eq!(a.base2k(), res.base2k());
|
||||
|
||||
let res_base2k: usize = res.base2k().into();
|
||||
let key_base2k: usize = key.base2k().into();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
use poulpy_hal::api::ScratchAvailable;
|
||||
|
||||
assert_eq!(key.rank(), a.rank());
|
||||
assert_eq!(key.rank(), res.rank());
|
||||
assert_eq!(key.n(), res.n());
|
||||
assert_eq!(a.n(), res.n());
|
||||
assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, a, key));
|
||||
}
|
||||
|
||||
if res_base2k == key_base2k {
|
||||
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // Todo optimise
|
||||
let mut res_big = self.glwe_external_product_internal(res_dft, a, key, scratch_1);
|
||||
for j in 0..(res.rank() + 1).into() {
|
||||
self.vec_znx_big_add_small_inplace(&mut res_big, j, res.data(), j);
|
||||
self.vec_znx_big_normalize(
|
||||
res_base2k,
|
||||
&mut res.data,
|
||||
j,
|
||||
key_base2k,
|
||||
&res_big,
|
||||
j,
|
||||
scratch_1,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
let (mut a_conv, scratch_1) = scratch.take_glwe(&GLWELayout {
|
||||
n: a.n(),
|
||||
base2k: key.base2k(),
|
||||
k: a.k(),
|
||||
rank: a.rank(),
|
||||
});
|
||||
let (mut res_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout {
|
||||
n: res.n(),
|
||||
base2k: key.base2k(),
|
||||
k: res.k(),
|
||||
rank: res.rank(),
|
||||
});
|
||||
self.glwe_normalize(&mut a_conv, a, scratch_2);
|
||||
self.glwe_normalize(&mut res_conv, res, scratch_2);
|
||||
let (res_dft, scratch_2) = scratch_2.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // Todo optimise
|
||||
let mut res_big = self.glwe_external_product_internal(res_dft, &a_conv, key, scratch_2);
|
||||
for j in 0..(res.rank() + 1).into() {
|
||||
self.vec_znx_big_add_small_inplace(&mut res_big, j, res_conv.data(), j);
|
||||
self.vec_znx_big_normalize(
|
||||
res_base2k,
|
||||
&mut res.data,
|
||||
j,
|
||||
key_base2k,
|
||||
&res_big,
|
||||
j,
|
||||
scratch_2,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait GLWEExternalProductInternal<BE: Backend> {
|
||||
|
||||
@@ -7,7 +7,7 @@ use poulpy_hal::{
|
||||
|
||||
use crate::{
|
||||
GLWEAdd, GLWEAutomorphism, GLWECopy, GLWENormalize, GLWERotate, GLWEShift, GLWESub, GLWETrace, ScratchTakeCore,
|
||||
layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWEAutomorphismKeyHelper, GLWEInfos, GLWEToMut, GLWEToRef, GetGaloisElement},
|
||||
layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWEAutomorphismKeyHelper, GLWEInfos, GLWEToMut, GetGaloisElement},
|
||||
};
|
||||
pub trait GLWEPacking<BE: Backend> {
|
||||
/// Packs [x_0: GLWE(m_0), x_1: GLWE(m_1), ..., x_i: GLWE(m_i)]
|
||||
@@ -21,7 +21,7 @@ pub trait GLWEPacking<BE: Backend> {
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
R: GLWEToMut + GLWEInfos,
|
||||
A: GLWEToMut + GLWEToRef + GLWEInfos,
|
||||
A: GLWEToMut + GLWEInfos,
|
||||
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
|
||||
H: GLWEAutomorphismKeyHelper<K, BE>;
|
||||
}
|
||||
@@ -51,7 +51,7 @@ where
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
R: GLWEToMut + GLWEInfos,
|
||||
A: GLWEToMut + GLWEToRef + GLWEInfos,
|
||||
A: GLWEToMut + GLWEInfos,
|
||||
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
|
||||
H: GLWEAutomorphismKeyHelper<K, BE>,
|
||||
{
|
||||
@@ -97,8 +97,8 @@ fn pack_internal<M, A, B, K, BE: Backend>(
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
M: GLWEAutomorphism<BE> + GLWERotate<BE> + GLWESub + GLWEShift<BE> + GLWEAdd + GLWENormalize<BE>,
|
||||
A: GLWEToMut + GLWEToRef + GLWEInfos,
|
||||
B: GLWEToMut + GLWEToRef + GLWEInfos,
|
||||
A: GLWEToMut + GLWEInfos,
|
||||
B: GLWEToMut + GLWEInfos,
|
||||
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
|
||||
@@ -189,7 +189,7 @@ impl<D: DataRef> WriterTo for GLWE<D> {
|
||||
}
|
||||
}
|
||||
|
||||
pub trait GLWEToRef {
|
||||
pub trait GLWEToRef: Sized {
|
||||
fn to_ref(&self) -> GLWE<&[u8]>;
|
||||
}
|
||||
|
||||
@@ -203,14 +203,11 @@ impl<D: DataRef> GLWEToRef for GLWE<D> {
|
||||
}
|
||||
}
|
||||
|
||||
pub trait GLWEToMut {
|
||||
pub trait GLWEToMut: GLWEToRef {
|
||||
fn to_mut(&mut self) -> GLWE<&mut [u8]>;
|
||||
}
|
||||
|
||||
impl<D: DataMut> GLWEToMut for GLWE<D>
|
||||
where
|
||||
Self: GLWEToRef,
|
||||
{
|
||||
impl<D: DataMut> GLWEToMut for GLWE<D> {
|
||||
fn to_mut(&mut self) -> GLWE<&mut [u8]> {
|
||||
GLWE {
|
||||
k: self.k,
|
||||
|
||||
@@ -402,6 +402,84 @@ where
|
||||
self.vec_znx_normalize_tmp_bytes()
|
||||
}
|
||||
|
||||
/// Usage:
|
||||
/// let mut tmp_b: Option<GLWE<&mut [u8]>> = None;
|
||||
/// let (b_conv, scratch_1) = glwe_maybe_convert_in_place(self, b, res.base2k().as_u32(), &mut tmp_b, scratch);
|
||||
fn glwe_maybe_cross_normalize_to_ref<'a, A>(
|
||||
&self,
|
||||
glwe: &'a A,
|
||||
target_base2k: usize,
|
||||
tmp_slot: &'a mut Option<GLWE<&'a mut [u8]>>, // caller-owned scratch-backed temp
|
||||
scratch: &'a mut Scratch<BE>,
|
||||
) -> (GLWE<&'a [u8]>, &'a mut Scratch<BE>)
|
||||
where
|
||||
A: GLWEToRef + GLWEInfos,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
// No conversion: just use the original GLWE
|
||||
if glwe.base2k().as_usize() == target_base2k {
|
||||
// Drop any previous temp; it's stale for this base
|
||||
tmp_slot.take();
|
||||
return (glwe.to_ref(), scratch);
|
||||
}
|
||||
|
||||
// Conversion: allocate a temporary GLWE in scratch
|
||||
let mut layout = glwe.glwe_layout();
|
||||
layout.base2k = target_base2k.into();
|
||||
|
||||
let (tmp, scratch2) = scratch.take_glwe(&layout);
|
||||
*tmp_slot = Some(tmp);
|
||||
|
||||
// Get a mutable handle to the temp and normalize into it
|
||||
let tmp_ref: &mut GLWE<&mut [u8]> = tmp_slot
|
||||
.as_mut()
|
||||
.expect("tmp_slot just set to Some, but found None");
|
||||
|
||||
self.glwe_normalize(tmp_ref, glwe, scratch2);
|
||||
|
||||
// Return a trait-object view of the temp
|
||||
(tmp_ref.to_ref(), scratch2)
|
||||
}
|
||||
|
||||
/// Usage:
|
||||
/// let mut tmp_b: Option<GLWE<&mut [u8]>> = None;
|
||||
/// let (b_conv, scratch_1) = glwe_maybe_convert_in_place(self, b, res.base2k().as_u32(), &mut tmp_b, scratch);
|
||||
fn glwe_maybe_cross_normalize_to_mut<'a, A>(
|
||||
&self,
|
||||
glwe: &'a mut A,
|
||||
target_base2k: usize,
|
||||
tmp_slot: &'a mut Option<GLWE<&'a mut [u8]>>, // caller-owned scratch-backed temp
|
||||
scratch: &'a mut Scratch<BE>,
|
||||
) -> (GLWE<&'a mut [u8]>, &'a mut Scratch<BE>)
|
||||
where
|
||||
A: GLWEToMut + GLWEInfos,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
// No conversion: just use the original GLWE
|
||||
if glwe.base2k().as_usize() == target_base2k {
|
||||
// Drop any previous temp; it's stale for this base
|
||||
tmp_slot.take();
|
||||
return (glwe.to_mut(), scratch);
|
||||
}
|
||||
|
||||
// Conversion: allocate a temporary GLWE in scratch
|
||||
let mut layout = glwe.glwe_layout();
|
||||
layout.base2k = target_base2k.into();
|
||||
|
||||
let (tmp, scratch2) = scratch.take_glwe(&layout);
|
||||
*tmp_slot = Some(tmp);
|
||||
|
||||
// Get a mutable handle to the temp and normalize into it
|
||||
let tmp_ref: &mut GLWE<&mut [u8]> = tmp_slot
|
||||
.as_mut()
|
||||
.expect("tmp_slot just set to Some, but found None");
|
||||
|
||||
self.glwe_normalize(tmp_ref, glwe, scratch2);
|
||||
|
||||
// Return a trait-object view of the temp
|
||||
(tmp_ref.to_mut(), scratch2)
|
||||
}
|
||||
|
||||
fn glwe_normalize<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<BE>)
|
||||
where
|
||||
R: GLWEToMut,
|
||||
|
||||
221
poulpy-schemes/src/tfhe/bdd_arithmetic/blind_retrieval.rs
Normal file
221
poulpy-schemes/src/tfhe/bdd_arithmetic/blind_retrieval.rs
Normal file
@@ -0,0 +1,221 @@
|
||||
use itertools::Itertools;
|
||||
use poulpy_core::{
|
||||
GLWECopy, ScratchTakeCore,
|
||||
layouts::{GGSWInfos, GGSWPrepared, GLWE, GLWEInfos, GLWEToMut, GLWEToRef},
|
||||
};
|
||||
use poulpy_hal::layouts::{Backend, Module, Scratch};
|
||||
|
||||
use crate::tfhe::bdd_arithmetic::{Cmux, Cswap, GetGGSWBit};
|
||||
|
||||
pub struct GLWEBlindRetriever {
|
||||
accumulators: Vec<Accumulator>,
|
||||
counter: usize,
|
||||
}
|
||||
|
||||
impl GLWEBlindRetriever {
|
||||
pub fn alloc<A>(infos: &A, size: usize) -> Self
|
||||
where
|
||||
A: GLWEInfos,
|
||||
{
|
||||
let log2_max_address: usize = (u32::BITS - (size as u32 - 1).leading_zeros()) as usize;
|
||||
Self {
|
||||
accumulators: (0..log2_max_address)
|
||||
.map(|_| Accumulator::alloc(infos))
|
||||
.collect_vec(),
|
||||
counter: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn retrieve<M, R, A, S, BE: Backend>(
|
||||
&mut self,
|
||||
module: &M,
|
||||
res: &mut R,
|
||||
data: &[A],
|
||||
selector: &S,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
M: GLWECopy + Cmux<BE>,
|
||||
R: GLWEToMut,
|
||||
A: GLWEToRef,
|
||||
S: GetGGSWBit<BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
self.reset();
|
||||
|
||||
for ct in data {
|
||||
self.add(module, ct, selector, scratch);
|
||||
}
|
||||
self.flush(module, res, selector, scratch);
|
||||
}
|
||||
|
||||
pub fn add<A, S, M, BE: Backend>(&mut self, module: &M, a: &A, selector: &S, scratch: &mut Scratch<BE>)
|
||||
where
|
||||
A: GLWEToRef,
|
||||
S: GetGGSWBit<BE>,
|
||||
M: GLWECopy + Cmux<BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
assert!(
|
||||
(self.counter as u32) < 1 << self.accumulators.len(),
|
||||
"Accumulating limit of {} reached",
|
||||
1 << self.accumulators.len()
|
||||
);
|
||||
|
||||
add_core(module, a, &mut self.accumulators, 0, selector, scratch);
|
||||
self.counter += 1;
|
||||
}
|
||||
|
||||
pub fn flush<R, M, S, BE: Backend>(&mut self, module: &M, res: &mut R, selector: &S, scratch: &mut Scratch<BE>)
|
||||
where
|
||||
R: GLWEToMut,
|
||||
S: GetGGSWBit<BE>,
|
||||
M: GLWECopy + Cmux<BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
for i in 0..self.accumulators.len() - 1 {
|
||||
let (acc_prev, acc_next) = self.accumulators.split_at_mut(i + 1);
|
||||
if acc_prev[i].num != 0 {
|
||||
add_core(
|
||||
module,
|
||||
&acc_prev[i].data,
|
||||
acc_next,
|
||||
i + 1,
|
||||
selector,
|
||||
scratch,
|
||||
);
|
||||
acc_prev[0].num = 0
|
||||
}
|
||||
}
|
||||
module.glwe_copy(res, &self.accumulators.last().unwrap().data);
|
||||
self.reset()
|
||||
}
|
||||
|
||||
fn reset(&mut self) {
|
||||
for acc in self.accumulators.iter_mut() {
|
||||
acc.num = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Accumulator {
|
||||
data: GLWE<Vec<u8>>,
|
||||
num: usize, // Number of accumulated values
|
||||
}
|
||||
|
||||
impl Accumulator {
|
||||
pub fn alloc<A>(infos: &A) -> Self
|
||||
where
|
||||
A: GLWEInfos,
|
||||
{
|
||||
Self {
|
||||
data: GLWE::alloc_from_infos(infos),
|
||||
num: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn add_core<A, S, M, BE: Backend>(
|
||||
module: &M,
|
||||
a: &A,
|
||||
accumulators: &mut [Accumulator],
|
||||
i: usize,
|
||||
selector: &S,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
A: GLWEToRef,
|
||||
S: GetGGSWBit<BE>,
|
||||
M: GLWECopy + Cmux<BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
// Isolate the first accumulator
|
||||
let (acc_prev, acc_next) = accumulators.split_at_mut(1);
|
||||
|
||||
match acc_prev[0].num {
|
||||
0 => {
|
||||
module.glwe_copy(&mut acc_prev[0].data, a);
|
||||
acc_prev[0].num = 1;
|
||||
}
|
||||
1 => {
|
||||
module.cmux_inplace_neg(&mut acc_prev[0].data, a, &selector.get_bit(i), scratch);
|
||||
|
||||
if !acc_next.is_empty() {
|
||||
add_core(
|
||||
module,
|
||||
&acc_prev[0].data,
|
||||
acc_next,
|
||||
i + 1,
|
||||
selector,
|
||||
scratch,
|
||||
);
|
||||
}
|
||||
|
||||
acc_prev[0].num = 0
|
||||
}
|
||||
_ => {
|
||||
panic!("something went wrong")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<BE: Backend> GLWEBlindRetrieval<BE> for Module<BE> where Self: GLWECopy + Cmux<BE> + Cswap<BE> {}
|
||||
|
||||
pub trait GLWEBlindRetrieval<BE: Backend>
|
||||
where
|
||||
Self: GLWECopy + Cmux<BE> + Cswap<BE>,
|
||||
{
|
||||
fn glwe_blind_retrieval_tmp_bytes<R, K>(&self, res_infos: &R, k_infos: &K) -> usize
|
||||
where
|
||||
R: GLWEInfos,
|
||||
K: GGSWInfos,
|
||||
{
|
||||
self.cswap_tmp_bytes(res_infos, res_infos, k_infos)
|
||||
}
|
||||
|
||||
fn glwe_blind_retrieval_statefull<R, K>(
|
||||
&self,
|
||||
res: &mut Vec<R>,
|
||||
bits: &K,
|
||||
bit_rsh: usize,
|
||||
bit_mask: usize,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
R: GLWEToMut + GLWEInfos,
|
||||
K: GetGGSWBit<BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
for i in 0..bit_mask {
|
||||
let t: usize = 1 << (bit_mask - i - 1);
|
||||
let bit: &GGSWPrepared<&[u8], BE> = &bits.get_bit(bit_rsh + bit_mask - i - 1); // MSB -> LSB traversal
|
||||
for j in 0..t {
|
||||
if j + t < res.len() {
|
||||
let (lo, hi) = res.split_at_mut(j + t);
|
||||
self.cswap(&mut lo[j], &mut hi[0], bit, scratch);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn glwe_blind_retrieval_statefull_rev<R, K>(
|
||||
&self,
|
||||
res: &mut Vec<R>,
|
||||
bits: &K,
|
||||
bit_rsh: usize,
|
||||
bit_mask: usize,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
R: GLWEToMut + GLWEInfos,
|
||||
K: GetGGSWBit<BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
for i in (0..bit_mask).rev() {
|
||||
let t: usize = 1 << (bit_mask - i - 1);
|
||||
let bit: &GGSWPrepared<&[u8], BE> = &bits.get_bit(bit_rsh + bit_mask - i - 1); // MSB -> LSB traversal
|
||||
for j in 0..t {
|
||||
if j < res.len() && j + t < res.len() {
|
||||
let (lo, hi) = res.split_at_mut(j + t);
|
||||
self.cswap(&mut lo[j], &mut hi[0], bit, scratch);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,7 @@ use std::collections::HashMap;
|
||||
|
||||
use poulpy_core::{
|
||||
GLWECopy, GLWEDecrypt, ScratchTakeCore,
|
||||
layouts::{GGSWInfos, GGSWPrepared, GLWE, GLWEInfos, GLWEToMut, GLWEToRef},
|
||||
layouts::{GGSWInfos, GGSWPrepared, GLWE, GLWEInfos, GLWEToMut},
|
||||
};
|
||||
use poulpy_hal::layouts::{Backend, Module, Scratch, ZnxZero};
|
||||
|
||||
@@ -33,7 +33,7 @@ where
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
R: GLWEToMut,
|
||||
A: GLWEToMut + GLWEToRef,
|
||||
A: GLWEToMut,
|
||||
K: GetGGSWBit<BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
|
||||
@@ -49,6 +49,18 @@ impl<'a, T: UnsignedInteger> FheUint<&'a mut [u8], T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: UnsignedInteger> FheUint<&'a [u8], T> {
|
||||
pub fn from_glwe_to_ref<G>(glwe: &'a G) -> Self
|
||||
where
|
||||
G: GLWEToRef,
|
||||
{
|
||||
FheUint {
|
||||
bits: glwe.to_ref(),
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef, T: UnsignedInteger> LWEInfos for FheUint<D, T> {
|
||||
fn base2k(&self) -> poulpy_core::layouts::Base2K {
|
||||
self.bits.base2k()
|
||||
@@ -180,7 +192,7 @@ impl<D: DataMut, T: UnsignedInteger> FheUint<D, T> {
|
||||
/// Packs Vec<GLWE(bit[i])> into [FheUint].
|
||||
pub fn pack<G, M, K, H, BE: Backend>(&mut self, module: &M, mut bits: Vec<G>, keys: &H, scratch: &mut Scratch<BE>)
|
||||
where
|
||||
G: GLWEToMut + GLWEToRef + GLWEInfos,
|
||||
G: GLWEToMut + GLWEInfos,
|
||||
M: ModuleLogN + GLWEPacking<BE> + GLWECopy,
|
||||
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
|
||||
H: GLWEAutomorphismKeyHelper<K, BE>,
|
||||
|
||||
@@ -4,14 +4,16 @@ use std::thread;
|
||||
use itertools::Itertools;
|
||||
use poulpy_core::{
|
||||
GLWECopy, GLWEExternalProductInternal, GLWENormalize, GLWESub, ScratchTakeCore,
|
||||
layouts::{GGSWInfos, GGSWPrepared, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef},
|
||||
layouts::{
|
||||
GGSWInfos, GGSWPrepared, GLWE, GLWEInfos, GLWELayout, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef,
|
||||
},
|
||||
};
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
|
||||
VecZnxDftBytesOf,
|
||||
ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmall, VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize,
|
||||
VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallA, VecZnxDftBytesOf,
|
||||
},
|
||||
layouts::{Backend, DataMut, Module, Scratch, VecZnxBig, ZnxZero},
|
||||
layouts::{Backend, DataMut, Module, Scratch, VecZnxBig, ZnxInfos, ZnxZero},
|
||||
};
|
||||
|
||||
use crate::tfhe::bdd_arithmetic::GetGGSWBit;
|
||||
@@ -260,6 +262,204 @@ pub enum Node {
|
||||
None,
|
||||
}
|
||||
|
||||
impl<BE: Backend> Cswap<BE> for Module<BE> where
|
||||
Self: Sized
|
||||
+ GLWEExternalProductInternal<BE>
|
||||
+ GLWESub
|
||||
+ VecZnxBigAddSmallInplace<BE>
|
||||
+ GLWENormalize<BE>
|
||||
+ VecZnxDftBytesOf
|
||||
+ VecZnxBigNormalize<BE>
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ GLWENormalize<BE>
|
||||
+ VecZnxBigAddSmall<BE>
|
||||
+ VecZnxBigSubSmallA<BE>
|
||||
+ VecZnxBigBytesOf
|
||||
{
|
||||
}
|
||||
|
||||
pub trait Cswap<BE: Backend>
|
||||
where
|
||||
Self: Sized
|
||||
+ GLWEExternalProductInternal<BE>
|
||||
+ GLWESub
|
||||
+ GLWECopy
|
||||
+ VecZnxBigAddSmallInplace<BE>
|
||||
+ GLWENormalize<BE>
|
||||
+ VecZnxDftBytesOf
|
||||
+ VecZnxBigNormalize<BE>
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ GLWENormalize<BE>
|
||||
+ VecZnxBigAddSmall<BE>
|
||||
+ VecZnxBigSubSmallA<BE>
|
||||
+ VecZnxBigBytesOf,
|
||||
{
|
||||
fn cswap_tmp_bytes<R, A, S>(&self, res_a_infos: &R, res_b_infos: &A, s_infos: &S) -> usize
|
||||
where
|
||||
R: GLWEInfos,
|
||||
A: GLWEInfos,
|
||||
S: GGSWInfos,
|
||||
{
|
||||
let res_dft: usize = self.bytes_of_vec_znx_dft((s_infos.rank() + 1).into(), s_infos.size());
|
||||
let mut tot = res_dft
|
||||
+ (self.glwe_external_product_internal_tmp_bytes(res_a_infos, res_b_infos, s_infos)
|
||||
+ GLWE::bytes_of_from_infos(&GLWELayout {
|
||||
n: s_infos.n(),
|
||||
base2k: s_infos.base2k(),
|
||||
k: res_a_infos.k().max(res_b_infos.k()),
|
||||
rank: s_infos.rank(),
|
||||
}))
|
||||
.max(self.vec_znx_big_normalize_tmp_bytes());
|
||||
|
||||
if res_a_infos.base2k() != s_infos.base2k() {
|
||||
tot += GLWE::bytes_of_from_infos(&GLWELayout {
|
||||
n: res_a_infos.n(),
|
||||
base2k: s_infos.base2k(),
|
||||
k: res_a_infos.k(),
|
||||
rank: res_a_infos.rank(),
|
||||
});
|
||||
|
||||
tot += GLWE::bytes_of_from_infos(&GLWELayout {
|
||||
n: res_b_infos.n(),
|
||||
base2k: s_infos.base2k(),
|
||||
k: res_b_infos.k(),
|
||||
rank: res_b_infos.rank(),
|
||||
});
|
||||
}
|
||||
|
||||
tot += self.bytes_of_vec_znx_big(1, s_infos.size());
|
||||
|
||||
tot
|
||||
}
|
||||
|
||||
fn cswap<A, B, S>(&self, res_a: &mut A, res_b: &mut B, s: &S, scratch: &mut Scratch<BE>)
|
||||
where
|
||||
A: GLWEToMut,
|
||||
B: GLWEToMut,
|
||||
S: GGSWPreparedToRef<BE> + GGSWInfos,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
let res_a: &mut GLWE<&mut [u8]> = &mut res_a.to_mut();
|
||||
let res_b: &mut GLWE<&mut [u8]> = &mut res_b.to_mut();
|
||||
let s: &GGSWPrepared<&[u8], BE> = &s.to_ref();
|
||||
assert_eq!(res_a.base2k(), res_b.base2k());
|
||||
|
||||
let res_base2k: usize = res_a.base2k().as_usize();
|
||||
let s_base2k: usize = s.base2k().as_usize();
|
||||
|
||||
if res_base2k == s_base2k {
|
||||
let res_big: VecZnxBig<&mut [u8], BE>;
|
||||
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (s.rank() + 1).into(), s.size()); // Todo optimise
|
||||
{
|
||||
// Temporary value storing a - b
|
||||
let tmp_c_infos: GLWELayout = GLWELayout {
|
||||
n: s.n(),
|
||||
base2k: s.base2k(),
|
||||
k: res_a.k().max(res_b.k()),
|
||||
rank: s.rank(),
|
||||
};
|
||||
let (mut tmp_c, scratch_2) = scratch_1.take_glwe(&tmp_c_infos);
|
||||
self.glwe_sub(&mut tmp_c, res_b, res_a);
|
||||
res_big = self.glwe_external_product_internal(res_dft, &tmp_c, s, scratch_2);
|
||||
}
|
||||
|
||||
// Single column res_big to store temporary value before normalization
|
||||
let (mut res_big_tmp, scratch_2) = scratch_1.take_vec_znx_big::<_, BE>(self, 1, res_big.size());
|
||||
|
||||
// res_a = (b-a) * bit + a
|
||||
for j in 0..(res_a.rank() + 1).into() {
|
||||
self.vec_znx_big_add_small(&mut res_big_tmp, 0, &res_big, j, res_a.data(), j);
|
||||
self.vec_znx_big_normalize(
|
||||
res_base2k,
|
||||
res_a.data_mut(),
|
||||
j,
|
||||
s_base2k,
|
||||
&res_big_tmp,
|
||||
0,
|
||||
scratch_2,
|
||||
);
|
||||
}
|
||||
|
||||
// res_b = a - (a - b) * bit = (b - a) * bit + a
|
||||
for j in 0..(res_b.rank() + 1).into() {
|
||||
self.vec_znx_big_sub_small_a(&mut res_big_tmp, 0, res_b.data(), j, &res_big, j);
|
||||
self.vec_znx_big_normalize(
|
||||
res_base2k,
|
||||
res_b.data_mut(),
|
||||
j,
|
||||
s_base2k,
|
||||
&res_big_tmp,
|
||||
0,
|
||||
scratch_2,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
let (mut tmp_a, scratch_1) = scratch.take_glwe(&GLWELayout {
|
||||
n: res_a.n(),
|
||||
base2k: s.base2k(),
|
||||
k: res_a.k(),
|
||||
rank: res_a.rank(),
|
||||
});
|
||||
|
||||
let (mut tmp_b, scratch_2) = scratch_1.take_glwe(&GLWELayout {
|
||||
n: res_b.n(),
|
||||
base2k: s.base2k(),
|
||||
k: res_b.k(),
|
||||
rank: res_b.rank(),
|
||||
});
|
||||
|
||||
self.glwe_normalize(&mut tmp_a, res_a, scratch_2);
|
||||
self.glwe_normalize(&mut tmp_b, res_b, scratch_2);
|
||||
|
||||
let res_big: VecZnxBig<&mut [u8], BE>;
|
||||
let (res_dft, scratch_3) = scratch_2.take_vec_znx_dft(self, (s.rank() + 1).into(), s.size()); // Todo optimise
|
||||
{
|
||||
// Temporary value storing a - b
|
||||
let tmp_c_infos: GLWELayout = GLWELayout {
|
||||
n: s.n(),
|
||||
base2k: s.base2k(),
|
||||
k: res_a.k().max(res_b.k()),
|
||||
rank: s.rank(),
|
||||
};
|
||||
let (mut tmp_c, scratch_4) = scratch_3.take_glwe(&tmp_c_infos);
|
||||
self.glwe_sub(&mut tmp_c, res_b, res_a);
|
||||
res_big = self.glwe_external_product_internal(res_dft, &tmp_c, s, scratch_4);
|
||||
}
|
||||
|
||||
// Single column res_big to store temporary value before normalization
|
||||
let (mut res_big_tmp, scratch_4) = scratch_3.take_vec_znx_big::<_, BE>(self, 1, res_big.size());
|
||||
|
||||
// res_a = (b-a) * bit + a
|
||||
for j in 0..(res_a.rank() + 1).into() {
|
||||
self.vec_znx_big_add_small(&mut res_big_tmp, 0, &res_big, j, tmp_a.data(), j);
|
||||
self.vec_znx_big_normalize(
|
||||
res_base2k,
|
||||
res_a.data_mut(),
|
||||
j,
|
||||
s_base2k,
|
||||
&res_big_tmp,
|
||||
0,
|
||||
scratch_4,
|
||||
);
|
||||
}
|
||||
|
||||
// res_b = a - (a - b) * bit = (b - a) * bit + a
|
||||
for j in 0..(res_b.rank() + 1).into() {
|
||||
self.vec_znx_big_sub_small_a(&mut res_big_tmp, 0, tmp_b.data(), j, &res_big, j);
|
||||
self.vec_znx_big_normalize(
|
||||
res_base2k,
|
||||
res_b.data_mut(),
|
||||
j,
|
||||
s_base2k,
|
||||
&res_big_tmp,
|
||||
0,
|
||||
scratch_4,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Cmux<BE: Backend>
|
||||
where
|
||||
Self: Sized
|
||||
@@ -284,6 +484,7 @@ where
|
||||
.max(self.vec_znx_big_normalize_tmp_bytes())
|
||||
}
|
||||
|
||||
// res = (t - f) * s + f
|
||||
fn cmux<R, T, F, S>(&self, res: &mut R, t: &T, f: &F, s: &S, scratch: &mut Scratch<BE>)
|
||||
where
|
||||
R: GLWEToMut,
|
||||
@@ -316,6 +517,46 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
// res = (a - res) * s + res
|
||||
fn cmux_inplace_neg<R, A, S>(&self, res: &mut R, a: &A, s: &S, scratch: &mut Scratch<BE>)
|
||||
where
|
||||
R: GLWEToMut,
|
||||
A: GLWEToRef,
|
||||
S: GGSWPreparedToRef<BE> + GGSWInfos,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
|
||||
let s: &GGSWPrepared<&[u8], BE> = &s.to_ref();
|
||||
let a: &GLWE<&[u8]> = &a.to_ref();
|
||||
|
||||
assert_eq!(res.base2k(), a.base2k());
|
||||
|
||||
let res_base2k: usize = res.base2k().into();
|
||||
let ggsw_base2k: usize = s.base2k().into();
|
||||
let (mut tmp, scratch_1) = scratch.take_glwe(&GLWELayout {
|
||||
n: s.n(),
|
||||
base2k: res.base2k(),
|
||||
k: res.k().max(a.k()),
|
||||
rank: res.rank(),
|
||||
});
|
||||
self.glwe_sub(&mut tmp, a, res);
|
||||
let (res_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, (res.rank() + 1).into(), s.size()); // Todo optimise
|
||||
let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_external_product_internal(res_dft, &tmp, s, scratch_2);
|
||||
for j in 0..(res.rank() + 1).into() {
|
||||
self.vec_znx_big_add_small_inplace(&mut res_big, j, res.data(), j);
|
||||
self.vec_znx_big_normalize(
|
||||
res_base2k,
|
||||
res.data_mut(),
|
||||
j,
|
||||
ggsw_base2k,
|
||||
&res_big,
|
||||
j,
|
||||
scratch_2,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// res = (res - a) * s + a
|
||||
fn cmux_inplace<R, A, S>(&self, res: &mut R, a: &A, s: &S, scratch: &mut Scratch<BE>)
|
||||
where
|
||||
R: GLWEToMut,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
mod bdd_2w_to_1w;
|
||||
mod blind_retrieval;
|
||||
mod blind_rotation;
|
||||
mod blind_selection;
|
||||
mod ciphertexts;
|
||||
@@ -7,6 +8,7 @@ mod eval;
|
||||
mod key;
|
||||
|
||||
pub use bdd_2w_to_1w::*;
|
||||
pub use blind_retrieval::*;
|
||||
pub use blind_rotation::*;
|
||||
pub use blind_selection::*;
|
||||
pub use ciphertexts::*;
|
||||
|
||||
@@ -6,7 +6,8 @@ use crate::tfhe::{
|
||||
bdd_arithmetic::tests::test_suite::{
|
||||
TestContext, test_bdd_add, test_bdd_and, test_bdd_or, test_bdd_prepare, test_bdd_sll, test_bdd_slt, test_bdd_sltu,
|
||||
test_bdd_sra, test_bdd_srl, test_bdd_sub, test_bdd_xor, test_fhe_uint_get_bit_glwe, test_fhe_uint_sext,
|
||||
test_fhe_uint_splice_u8, test_fhe_uint_splice_u16, test_glwe_blind_selection, test_glwe_to_glwe_blind_rotation,
|
||||
test_fhe_uint_splice_u8, test_fhe_uint_splice_u16, test_fhe_uint_swap, test_glwe_blind_retrieval_statefull,
|
||||
test_glwe_blind_retriever, test_glwe_blind_selection, test_glwe_to_glwe_blind_rotation,
|
||||
test_scalar_to_ggsw_blind_rotation,
|
||||
},
|
||||
blind_rotation::CGGI,
|
||||
@@ -15,6 +16,21 @@ use crate::tfhe::{
|
||||
static TEST_CONTEXT_CGGI_FFT64_REF: LazyLock<TestContext<CGGI, FFT64Ref>> =
|
||||
LazyLock::new(|| TestContext::<CGGI, FFT64Ref>::new());
|
||||
|
||||
#[test]
|
||||
fn test_glwe_blind_retriever_fft64_ref() {
|
||||
test_glwe_blind_retriever(&TEST_CONTEXT_CGGI_FFT64_REF);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glwe_blind_retrieval_statefull_fft64_ref() {
|
||||
test_glwe_blind_retrieval_statefull(&TEST_CONTEXT_CGGI_FFT64_REF);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fhe_uint_swap_fft64_ref() {
|
||||
test_fhe_uint_swap(&TEST_CONTEXT_CGGI_FFT64_REF);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fhe_uint_get_bit_glwe_fft64_ref() {
|
||||
test_fhe_uint_get_bit_glwe(&TEST_CONTEXT_CGGI_FFT64_REF);
|
||||
|
||||
@@ -12,6 +12,7 @@ mod sltu;
|
||||
mod sra;
|
||||
mod srl;
|
||||
mod sub;
|
||||
mod swap;
|
||||
mod xor;
|
||||
|
||||
pub use add::*;
|
||||
@@ -33,6 +34,7 @@ pub use sltu::*;
|
||||
pub use sra::*;
|
||||
pub use srl::*;
|
||||
pub use sub::*;
|
||||
pub use swap::*;
|
||||
pub use xor::*;
|
||||
|
||||
use poulpy_core::{
|
||||
|
||||
203
poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/swap.rs
Normal file
203
poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/swap.rs
Normal file
@@ -0,0 +1,203 @@
|
||||
use itertools::Itertools;
|
||||
use poulpy_core::{
|
||||
GGSWEncryptSk, GLWEDecrypt, GLWEEncryptSk,
|
||||
layouts::{GGSW, GGSWPrepared, GGSWPreparedFactory, GLWELayout, GLWESecretPrepared},
|
||||
};
|
||||
use poulpy_hal::{
|
||||
api::{ScratchOwnedAlloc, ScratchOwnedBorrow},
|
||||
layouts::{Backend, Module, ScalarZnx, Scratch, ScratchOwned, ZnxViewMut},
|
||||
source::Source,
|
||||
};
|
||||
use rand::RngCore;
|
||||
|
||||
use crate::tfhe::{
|
||||
bdd_arithmetic::{
|
||||
Cswap, FheUint, FheUintPrepared, GLWEBlindRetrieval, GLWEBlindRetriever, ScratchTakeBDD,
|
||||
tests::test_suite::{TEST_GGSW_INFOS, TEST_GLWE_INFOS, TestContext},
|
||||
},
|
||||
blind_rotation::BlindRotationAlgo,
|
||||
};
|
||||
|
||||
pub fn test_fhe_uint_swap<BRA: BlindRotationAlgo, BE: Backend>(test_context: &TestContext<BRA, BE>)
|
||||
where
|
||||
Module<BE>: GLWEEncryptSk<BE> + GLWEDecrypt<BE> + Cswap<BE> + GGSWEncryptSk<BE> + GGSWPreparedFactory<BE>,
|
||||
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
|
||||
Scratch<BE>: ScratchTakeBDD<u32, BE>,
|
||||
{
|
||||
let glwe_infos: GLWELayout = TEST_GLWE_INFOS;
|
||||
let ggsw_infos: poulpy_core::layouts::GGSWLayout = TEST_GGSW_INFOS;
|
||||
|
||||
let module: &Module<BE> = &test_context.module;
|
||||
let sk: &GLWESecretPrepared<Vec<u8>, BE> = &test_context.sk_glwe;
|
||||
|
||||
let mut source_xa: Source = Source::new([2u8; 32]);
|
||||
let mut source_xe: Source = Source::new([3u8; 32]);
|
||||
|
||||
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(1 << 22);
|
||||
|
||||
let mut s: GGSW<Vec<u8>> = GGSW::alloc_from_infos(&ggsw_infos);
|
||||
let mut s_prepared: GGSWPrepared<Vec<u8>, BE> = GGSWPrepared::alloc_from_infos(module, &ggsw_infos);
|
||||
|
||||
let a: u32 = source_xa.next_u32();
|
||||
let b: u32 = source_xa.next_u32();
|
||||
|
||||
for bit in [0, 1] {
|
||||
let mut a_enc: FheUint<Vec<u8>, u32> = FheUint::<Vec<u8>, u32>::alloc_from_infos(&glwe_infos);
|
||||
let mut b_enc: FheUint<Vec<u8>, u32> = FheUint::<Vec<u8>, u32>::alloc_from_infos(&glwe_infos);
|
||||
|
||||
a_enc.encrypt_sk(
|
||||
module,
|
||||
a,
|
||||
sk,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
b_enc.encrypt_sk(
|
||||
module,
|
||||
b,
|
||||
sk,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
let mut pt: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(module.n(), 1);
|
||||
pt.raw_mut()[0] = bit;
|
||||
s.encrypt_sk(
|
||||
module,
|
||||
&pt,
|
||||
sk,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
scratch.borrow(),
|
||||
);
|
||||
s_prepared.prepare(module, &s, scratch.borrow());
|
||||
|
||||
module.cswap(&mut a_enc, &mut b_enc, &s_prepared, scratch.borrow());
|
||||
|
||||
let (a_want, b_want) = if bit == 0 { (a, b) } else { (b, a) };
|
||||
|
||||
assert_eq!(a_want, a_enc.decrypt(module, sk, scratch.borrow()));
|
||||
assert_eq!(b_want, b_enc.decrypt(module, sk, scratch.borrow()));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn test_glwe_blind_retrieval_statefull<BRA: BlindRotationAlgo, BE: Backend>(test_context: &TestContext<BRA, BE>)
|
||||
where
|
||||
Module<BE>: GLWEEncryptSk<BE> + GLWEDecrypt<BE> + GLWEBlindRetrieval<BE> + GGSWEncryptSk<BE> + GGSWPreparedFactory<BE>,
|
||||
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
|
||||
Scratch<BE>: ScratchTakeBDD<u32, BE>,
|
||||
{
|
||||
let glwe_infos: GLWELayout = TEST_GLWE_INFOS;
|
||||
let ggsw_infos: poulpy_core::layouts::GGSWLayout = TEST_GGSW_INFOS;
|
||||
|
||||
let module: &Module<BE> = &test_context.module;
|
||||
let sk: &GLWESecretPrepared<Vec<u8>, BE> = &test_context.sk_glwe;
|
||||
|
||||
let mut source_xa: Source = Source::new([2u8; 32]);
|
||||
let mut source_xe: Source = Source::new([3u8; 32]);
|
||||
|
||||
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(1 << 22);
|
||||
|
||||
let data: Vec<u32> = (0..25).map(|i| i as u32).collect_vec();
|
||||
|
||||
let mut data_enc: Vec<FheUint<Vec<u8>, u32>> = (0..data.len())
|
||||
.map(|i| {
|
||||
let mut ct: FheUint<Vec<u8>, u32> = FheUint::<Vec<u8>, u32>::alloc_from_infos(&glwe_infos);
|
||||
ct.encrypt_sk(
|
||||
module,
|
||||
data[i],
|
||||
sk,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
scratch.borrow(),
|
||||
);
|
||||
ct
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
for idx in 0..data.len() as u32 {
|
||||
let mut idx_enc = FheUintPrepared::alloc_from_infos(module, &ggsw_infos);
|
||||
idx_enc.encrypt_sk(
|
||||
module,
|
||||
idx,
|
||||
sk,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
module.glwe_blind_retrieval_statefull(&mut data_enc, &idx_enc, 0, 5, scratch.borrow());
|
||||
|
||||
assert_eq!(
|
||||
data[idx as usize],
|
||||
data_enc[0].decrypt(module, sk, scratch.borrow())
|
||||
);
|
||||
|
||||
module.glwe_blind_retrieval_statefull_rev(&mut data_enc, &idx_enc, 0, 5, scratch.borrow());
|
||||
|
||||
for i in 0..data.len() {
|
||||
assert_eq!(data[i], data_enc[i].decrypt(module, sk, scratch.borrow()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn test_glwe_blind_retriever<BRA: BlindRotationAlgo, BE: Backend>(test_context: &TestContext<BRA, BE>)
|
||||
where
|
||||
Module<BE>: GLWEEncryptSk<BE> + GLWEDecrypt<BE> + GLWEBlindRetrieval<BE> + GGSWEncryptSk<BE> + GGSWPreparedFactory<BE>,
|
||||
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
|
||||
Scratch<BE>: ScratchTakeBDD<u32, BE>,
|
||||
{
|
||||
let glwe_infos: GLWELayout = TEST_GLWE_INFOS;
|
||||
let ggsw_infos: poulpy_core::layouts::GGSWLayout = TEST_GGSW_INFOS;
|
||||
|
||||
let module: &Module<BE> = &test_context.module;
|
||||
let sk: &GLWESecretPrepared<Vec<u8>, BE> = &test_context.sk_glwe;
|
||||
|
||||
let mut source_xa: Source = Source::new([2u8; 32]);
|
||||
let mut source_xe: Source = Source::new([3u8; 32]);
|
||||
|
||||
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(1 << 22);
|
||||
|
||||
let data: Vec<u32> = (0..25).map(|i| i as u32).collect_vec();
|
||||
|
||||
let data_enc: Vec<FheUint<Vec<u8>, u32>> = (0..data.len())
|
||||
.map(|i| {
|
||||
let mut ct: FheUint<Vec<u8>, u32> = FheUint::<Vec<u8>, u32>::alloc_from_infos(&glwe_infos);
|
||||
ct.encrypt_sk(
|
||||
module,
|
||||
data[i],
|
||||
sk,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
scratch.borrow(),
|
||||
);
|
||||
ct
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
for idx in 0..data.len() as u32 {
|
||||
let mut idx_enc: FheUintPrepared<Vec<u8>, u32, BE> = FheUintPrepared::alloc_from_infos(module, &ggsw_infos);
|
||||
idx_enc.encrypt_sk(
|
||||
module,
|
||||
idx,
|
||||
sk,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
let mut retriever: GLWEBlindRetriever = GLWEBlindRetriever::alloc(&glwe_infos, 25);
|
||||
let mut res: FheUint<Vec<u8>, u32> = FheUint::alloc_from_infos(&glwe_infos);
|
||||
retriever.retrieve(module, &mut res, &data_enc, &idx_enc, scratch.borrow());
|
||||
|
||||
println!("{}", res.decrypt(module, sk, scratch.borrow()));
|
||||
|
||||
assert_eq!(
|
||||
data[idx as usize],
|
||||
res.decrypt(module, sk, scratch.borrow())
|
||||
);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user