diff --git a/poulpy-core/src/external_product/glwe.rs b/poulpy-core/src/external_product/glwe.rs index 303499d..2c9fe12 100644 --- a/poulpy-core/src/external_product/glwe.rs +++ b/poulpy-core/src/external_product/glwe.rs @@ -1,16 +1,16 @@ use poulpy_hal::{ api::{ - ModuleN, ScratchTakeBasic, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftApply, VecZnxDftBytesOf, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, + ModuleN, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftApply, + VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, + VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft}, }; use crate::{ - ScratchTakeCore, + GLWENormalize, ScratchTakeCore, layouts::{ - GGSWInfos, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, + GGSWInfos, GLWE, GLWEInfos, GLWELayout, GLWEToMut, GLWEToRef, LWEInfos, prepared::{GGSWPrepared, GGSWPreparedToRef}, }, }; @@ -67,11 +67,22 @@ pub trait GLWEExternalProduct { A: GLWEToRef, D: GGSWPreparedToRef + GGSWInfos, Scratch: ScratchTakeCore; + fn glwe_external_product_add(&self, res: &mut R, lhs: &A, rhs: &D, scratch: &mut Scratch) + where + R: GLWEToMut, + A: GLWEToRef, + D: GGSWPreparedToRef + GGSWInfos, + Scratch: ScratchTakeCore; } impl GLWEExternalProduct for Module where - Self: GLWEExternalProductInternal + VecZnxDftBytesOf + VecZnxBigNormalize + VecZnxBigNormalizeTmpBytes, + Self: GLWEExternalProductInternal + + VecZnxDftBytesOf + + VecZnxBigNormalize + + VecZnxBigNormalizeTmpBytes + + VecZnxBigAddSmallInplace + + GLWENormalize, { fn glwe_external_product_tmp_bytes(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize where @@ -163,6 +174,80 @@ where ); } } + + fn glwe_external_product_add(&self, res: &mut R, a: &A, key: &D, scratch: &mut Scratch) + where + R: GLWEToMut, + A: GLWEToRef, + D: GGSWPreparedToRef, + Scratch: ScratchTakeCore, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); + let key: &GGSWPrepared<&[u8], BE> = &key.to_ref(); + + assert_eq!(a.base2k(), res.base2k()); + + let res_base2k: usize = res.base2k().into(); + let key_base2k: usize = key.base2k().into(); + + #[cfg(debug_assertions)] + { + use poulpy_hal::api::ScratchAvailable; + + assert_eq!(key.rank(), a.rank()); + assert_eq!(key.rank(), res.rank()); + assert_eq!(key.n(), res.n()); + assert_eq!(a.n(), res.n()); + assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, a, key)); + } + + if res_base2k == key_base2k { + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // Todo optimise + let mut res_big = self.glwe_external_product_internal(res_dft, a, key, scratch_1); + for j in 0..(res.rank() + 1).into() { + self.vec_znx_big_add_small_inplace(&mut res_big, j, res.data(), j); + self.vec_znx_big_normalize( + res_base2k, + &mut res.data, + j, + key_base2k, + &res_big, + j, + scratch_1, + ); + } + } else { + let (mut a_conv, scratch_1) = scratch.take_glwe(&GLWELayout { + n: a.n(), + base2k: key.base2k(), + k: a.k(), + rank: a.rank(), + }); + let (mut res_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout { + n: res.n(), + base2k: key.base2k(), + k: res.k(), + rank: res.rank(), + }); + self.glwe_normalize(&mut a_conv, a, scratch_2); + self.glwe_normalize(&mut res_conv, res, scratch_2); + let (res_dft, scratch_2) = scratch_2.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // Todo optimise + let mut res_big = self.glwe_external_product_internal(res_dft, &a_conv, key, scratch_2); + for j in 0..(res.rank() + 1).into() { + self.vec_znx_big_add_small_inplace(&mut res_big, j, res_conv.data(), j); + self.vec_znx_big_normalize( + res_base2k, + &mut res.data, + j, + key_base2k, + &res_big, + j, + scratch_2, + ); + } + } + } } pub trait GLWEExternalProductInternal { diff --git a/poulpy-core/src/glwe_packing.rs b/poulpy-core/src/glwe_packing.rs index e729929..1ddb565 100644 --- a/poulpy-core/src/glwe_packing.rs +++ b/poulpy-core/src/glwe_packing.rs @@ -7,7 +7,7 @@ use poulpy_hal::{ use crate::{ GLWEAdd, GLWEAutomorphism, GLWECopy, GLWENormalize, GLWERotate, GLWEShift, GLWESub, GLWETrace, ScratchTakeCore, - layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWEAutomorphismKeyHelper, GLWEInfos, GLWEToMut, GLWEToRef, GetGaloisElement}, + layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWEAutomorphismKeyHelper, GLWEInfos, GLWEToMut, GetGaloisElement}, }; pub trait GLWEPacking { /// Packs [x_0: GLWE(m_0), x_1: GLWE(m_1), ..., x_i: GLWE(m_i)] @@ -21,7 +21,7 @@ pub trait GLWEPacking { scratch: &mut Scratch, ) where R: GLWEToMut + GLWEInfos, - A: GLWEToMut + GLWEToRef + GLWEInfos, + A: GLWEToMut + GLWEInfos, K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, H: GLWEAutomorphismKeyHelper; } @@ -51,7 +51,7 @@ where scratch: &mut Scratch, ) where R: GLWEToMut + GLWEInfos, - A: GLWEToMut + GLWEToRef + GLWEInfos, + A: GLWEToMut + GLWEInfos, K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, H: GLWEAutomorphismKeyHelper, { @@ -97,8 +97,8 @@ fn pack_internal( scratch: &mut Scratch, ) where M: GLWEAutomorphism + GLWERotate + GLWESub + GLWEShift + GLWEAdd + GLWENormalize, - A: GLWEToMut + GLWEToRef + GLWEInfos, - B: GLWEToMut + GLWEToRef + GLWEInfos, + A: GLWEToMut + GLWEInfos, + B: GLWEToMut + GLWEInfos, K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, Scratch: ScratchTakeCore, { diff --git a/poulpy-core/src/layouts/glwe.rs b/poulpy-core/src/layouts/glwe.rs index 0f0e7f0..ba03b0e 100644 --- a/poulpy-core/src/layouts/glwe.rs +++ b/poulpy-core/src/layouts/glwe.rs @@ -189,7 +189,7 @@ impl WriterTo for GLWE { } } -pub trait GLWEToRef { +pub trait GLWEToRef: Sized { fn to_ref(&self) -> GLWE<&[u8]>; } @@ -203,14 +203,11 @@ impl GLWEToRef for GLWE { } } -pub trait GLWEToMut { +pub trait GLWEToMut: GLWEToRef { fn to_mut(&mut self) -> GLWE<&mut [u8]>; } -impl GLWEToMut for GLWE -where - Self: GLWEToRef, -{ +impl GLWEToMut for GLWE { fn to_mut(&mut self) -> GLWE<&mut [u8]> { GLWE { k: self.k, diff --git a/poulpy-core/src/operations/glwe.rs b/poulpy-core/src/operations/glwe.rs index 3a2c3b4..3d839b8 100644 --- a/poulpy-core/src/operations/glwe.rs +++ b/poulpy-core/src/operations/glwe.rs @@ -402,6 +402,84 @@ where self.vec_znx_normalize_tmp_bytes() } + /// Usage: + /// let mut tmp_b: Option> = None; + /// let (b_conv, scratch_1) = glwe_maybe_convert_in_place(self, b, res.base2k().as_u32(), &mut tmp_b, scratch); + fn glwe_maybe_cross_normalize_to_ref<'a, A>( + &self, + glwe: &'a A, + target_base2k: usize, + tmp_slot: &'a mut Option>, // caller-owned scratch-backed temp + scratch: &'a mut Scratch, + ) -> (GLWE<&'a [u8]>, &'a mut Scratch) + where + A: GLWEToRef + GLWEInfos, + Scratch: ScratchTakeCore, + { + // No conversion: just use the original GLWE + if glwe.base2k().as_usize() == target_base2k { + // Drop any previous temp; it's stale for this base + tmp_slot.take(); + return (glwe.to_ref(), scratch); + } + + // Conversion: allocate a temporary GLWE in scratch + let mut layout = glwe.glwe_layout(); + layout.base2k = target_base2k.into(); + + let (tmp, scratch2) = scratch.take_glwe(&layout); + *tmp_slot = Some(tmp); + + // Get a mutable handle to the temp and normalize into it + let tmp_ref: &mut GLWE<&mut [u8]> = tmp_slot + .as_mut() + .expect("tmp_slot just set to Some, but found None"); + + self.glwe_normalize(tmp_ref, glwe, scratch2); + + // Return a trait-object view of the temp + (tmp_ref.to_ref(), scratch2) + } + + /// Usage: + /// let mut tmp_b: Option> = None; + /// let (b_conv, scratch_1) = glwe_maybe_convert_in_place(self, b, res.base2k().as_u32(), &mut tmp_b, scratch); + fn glwe_maybe_cross_normalize_to_mut<'a, A>( + &self, + glwe: &'a mut A, + target_base2k: usize, + tmp_slot: &'a mut Option>, // caller-owned scratch-backed temp + scratch: &'a mut Scratch, + ) -> (GLWE<&'a mut [u8]>, &'a mut Scratch) + where + A: GLWEToMut + GLWEInfos, + Scratch: ScratchTakeCore, + { + // No conversion: just use the original GLWE + if glwe.base2k().as_usize() == target_base2k { + // Drop any previous temp; it's stale for this base + tmp_slot.take(); + return (glwe.to_mut(), scratch); + } + + // Conversion: allocate a temporary GLWE in scratch + let mut layout = glwe.glwe_layout(); + layout.base2k = target_base2k.into(); + + let (tmp, scratch2) = scratch.take_glwe(&layout); + *tmp_slot = Some(tmp); + + // Get a mutable handle to the temp and normalize into it + let tmp_ref: &mut GLWE<&mut [u8]> = tmp_slot + .as_mut() + .expect("tmp_slot just set to Some, but found None"); + + self.glwe_normalize(tmp_ref, glwe, scratch2); + + // Return a trait-object view of the temp + (tmp_ref.to_mut(), scratch2) + } + fn glwe_normalize(&self, res: &mut R, a: &A, scratch: &mut Scratch) where R: GLWEToMut, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/blind_retrieval.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/blind_retrieval.rs new file mode 100644 index 0000000..87b8852 --- /dev/null +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/blind_retrieval.rs @@ -0,0 +1,221 @@ +use itertools::Itertools; +use poulpy_core::{ + GLWECopy, ScratchTakeCore, + layouts::{GGSWInfos, GGSWPrepared, GLWE, GLWEInfos, GLWEToMut, GLWEToRef}, +}; +use poulpy_hal::layouts::{Backend, Module, Scratch}; + +use crate::tfhe::bdd_arithmetic::{Cmux, Cswap, GetGGSWBit}; + +pub struct GLWEBlindRetriever { + accumulators: Vec, + counter: usize, +} + +impl GLWEBlindRetriever { + pub fn alloc(infos: &A, size: usize) -> Self + where + A: GLWEInfos, + { + let log2_max_address: usize = (u32::BITS - (size as u32 - 1).leading_zeros()) as usize; + Self { + accumulators: (0..log2_max_address) + .map(|_| Accumulator::alloc(infos)) + .collect_vec(), + counter: 0, + } + } + + pub fn retrieve( + &mut self, + module: &M, + res: &mut R, + data: &[A], + selector: &S, + scratch: &mut Scratch, + ) where + M: GLWECopy + Cmux, + R: GLWEToMut, + A: GLWEToRef, + S: GetGGSWBit, + Scratch: ScratchTakeCore, + { + self.reset(); + + for ct in data { + self.add(module, ct, selector, scratch); + } + self.flush(module, res, selector, scratch); + } + + pub fn add(&mut self, module: &M, a: &A, selector: &S, scratch: &mut Scratch) + where + A: GLWEToRef, + S: GetGGSWBit, + M: GLWECopy + Cmux, + Scratch: ScratchTakeCore, + { + assert!( + (self.counter as u32) < 1 << self.accumulators.len(), + "Accumulating limit of {} reached", + 1 << self.accumulators.len() + ); + + add_core(module, a, &mut self.accumulators, 0, selector, scratch); + self.counter += 1; + } + + pub fn flush(&mut self, module: &M, res: &mut R, selector: &S, scratch: &mut Scratch) + where + R: GLWEToMut, + S: GetGGSWBit, + M: GLWECopy + Cmux, + Scratch: ScratchTakeCore, + { + for i in 0..self.accumulators.len() - 1 { + let (acc_prev, acc_next) = self.accumulators.split_at_mut(i + 1); + if acc_prev[i].num != 0 { + add_core( + module, + &acc_prev[i].data, + acc_next, + i + 1, + selector, + scratch, + ); + acc_prev[0].num = 0 + } + } + module.glwe_copy(res, &self.accumulators.last().unwrap().data); + self.reset() + } + + fn reset(&mut self) { + for acc in self.accumulators.iter_mut() { + acc.num = 0; + } + } +} + +struct Accumulator { + data: GLWE>, + num: usize, // Number of accumulated values +} + +impl Accumulator { + pub fn alloc(infos: &A) -> Self + where + A: GLWEInfos, + { + Self { + data: GLWE::alloc_from_infos(infos), + num: 0, + } + } +} + +fn add_core( + module: &M, + a: &A, + accumulators: &mut [Accumulator], + i: usize, + selector: &S, + scratch: &mut Scratch, +) where + A: GLWEToRef, + S: GetGGSWBit, + M: GLWECopy + Cmux, + Scratch: ScratchTakeCore, +{ + // Isolate the first accumulator + let (acc_prev, acc_next) = accumulators.split_at_mut(1); + + match acc_prev[0].num { + 0 => { + module.glwe_copy(&mut acc_prev[0].data, a); + acc_prev[0].num = 1; + } + 1 => { + module.cmux_inplace_neg(&mut acc_prev[0].data, a, &selector.get_bit(i), scratch); + + if !acc_next.is_empty() { + add_core( + module, + &acc_prev[0].data, + acc_next, + i + 1, + selector, + scratch, + ); + } + + acc_prev[0].num = 0 + } + _ => { + panic!("something went wrong") + } + } +} + +impl GLWEBlindRetrieval for Module where Self: GLWECopy + Cmux + Cswap {} + +pub trait GLWEBlindRetrieval +where + Self: GLWECopy + Cmux + Cswap, +{ + fn glwe_blind_retrieval_tmp_bytes(&self, res_infos: &R, k_infos: &K) -> usize + where + R: GLWEInfos, + K: GGSWInfos, + { + self.cswap_tmp_bytes(res_infos, res_infos, k_infos) + } + + fn glwe_blind_retrieval_statefull( + &self, + res: &mut Vec, + bits: &K, + bit_rsh: usize, + bit_mask: usize, + scratch: &mut Scratch, + ) where + R: GLWEToMut + GLWEInfos, + K: GetGGSWBit, + Scratch: ScratchTakeCore, + { + for i in 0..bit_mask { + let t: usize = 1 << (bit_mask - i - 1); + let bit: &GGSWPrepared<&[u8], BE> = &bits.get_bit(bit_rsh + bit_mask - i - 1); // MSB -> LSB traversal + for j in 0..t { + if j + t < res.len() { + let (lo, hi) = res.split_at_mut(j + t); + self.cswap(&mut lo[j], &mut hi[0], bit, scratch); + } + } + } + } + + fn glwe_blind_retrieval_statefull_rev( + &self, + res: &mut Vec, + bits: &K, + bit_rsh: usize, + bit_mask: usize, + scratch: &mut Scratch, + ) where + R: GLWEToMut + GLWEInfos, + K: GetGGSWBit, + Scratch: ScratchTakeCore, + { + for i in (0..bit_mask).rev() { + let t: usize = 1 << (bit_mask - i - 1); + let bit: &GGSWPrepared<&[u8], BE> = &bits.get_bit(bit_rsh + bit_mask - i - 1); // MSB -> LSB traversal + for j in 0..t { + if j < res.len() && j + t < res.len() { + let (lo, hi) = res.split_at_mut(j + t); + self.cswap(&mut lo[j], &mut hi[0], bit, scratch); + } + } + } + } +} diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/blind_selection.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/blind_selection.rs index 3242b2e..3684ff8 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/blind_selection.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/blind_selection.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use poulpy_core::{ GLWECopy, GLWEDecrypt, ScratchTakeCore, - layouts::{GGSWInfos, GGSWPrepared, GLWE, GLWEInfos, GLWEToMut, GLWEToRef}, + layouts::{GGSWInfos, GGSWPrepared, GLWE, GLWEInfos, GLWEToMut}, }; use poulpy_hal::layouts::{Backend, Module, Scratch, ZnxZero}; @@ -33,7 +33,7 @@ where scratch: &mut Scratch, ) where R: GLWEToMut, - A: GLWEToMut + GLWEToRef, + A: GLWEToMut, K: GetGGSWBit, Scratch: ScratchTakeCore, { diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint.rs index 1d60027..adb4def 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint.rs @@ -49,6 +49,18 @@ impl<'a, T: UnsignedInteger> FheUint<&'a mut [u8], T> { } } +impl<'a, T: UnsignedInteger> FheUint<&'a [u8], T> { + pub fn from_glwe_to_ref(glwe: &'a G) -> Self + where + G: GLWEToRef, + { + FheUint { + bits: glwe.to_ref(), + _phantom: PhantomData, + } + } +} + impl LWEInfos for FheUint { fn base2k(&self) -> poulpy_core::layouts::Base2K { self.bits.base2k() @@ -180,7 +192,7 @@ impl FheUint { /// Packs Vec into [FheUint]. pub fn pack(&mut self, module: &M, mut bits: Vec, keys: &H, scratch: &mut Scratch) where - G: GLWEToMut + GLWEToRef + GLWEInfos, + G: GLWEToMut + GLWEInfos, M: ModuleLogN + GLWEPacking + GLWECopy, K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, H: GLWEAutomorphismKeyHelper, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs index 90dd62e..75ba7cf 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs @@ -4,14 +4,16 @@ use std::thread; use itertools::Itertools; use poulpy_core::{ GLWECopy, GLWEExternalProductInternal, GLWENormalize, GLWESub, ScratchTakeCore, - layouts::{GGSWInfos, GGSWPrepared, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef}, + layouts::{ + GGSWInfos, GGSWPrepared, GLWE, GLWEInfos, GLWELayout, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef, + }, }; use poulpy_hal::{ api::{ - ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxDftBytesOf, + ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmall, VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize, + VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallA, VecZnxDftBytesOf, }, - layouts::{Backend, DataMut, Module, Scratch, VecZnxBig, ZnxZero}, + layouts::{Backend, DataMut, Module, Scratch, VecZnxBig, ZnxInfos, ZnxZero}, }; use crate::tfhe::bdd_arithmetic::GetGGSWBit; @@ -260,6 +262,204 @@ pub enum Node { None, } +impl Cswap for Module where + Self: Sized + + GLWEExternalProductInternal + + GLWESub + + VecZnxBigAddSmallInplace + + GLWENormalize + + VecZnxDftBytesOf + + VecZnxBigNormalize + + VecZnxBigNormalizeTmpBytes + + GLWENormalize + + VecZnxBigAddSmall + + VecZnxBigSubSmallA + + VecZnxBigBytesOf +{ +} + +pub trait Cswap +where + Self: Sized + + GLWEExternalProductInternal + + GLWESub + + GLWECopy + + VecZnxBigAddSmallInplace + + GLWENormalize + + VecZnxDftBytesOf + + VecZnxBigNormalize + + VecZnxBigNormalizeTmpBytes + + GLWENormalize + + VecZnxBigAddSmall + + VecZnxBigSubSmallA + + VecZnxBigBytesOf, +{ + fn cswap_tmp_bytes(&self, res_a_infos: &R, res_b_infos: &A, s_infos: &S) -> usize + where + R: GLWEInfos, + A: GLWEInfos, + S: GGSWInfos, + { + let res_dft: usize = self.bytes_of_vec_znx_dft((s_infos.rank() + 1).into(), s_infos.size()); + let mut tot = res_dft + + (self.glwe_external_product_internal_tmp_bytes(res_a_infos, res_b_infos, s_infos) + + GLWE::bytes_of_from_infos(&GLWELayout { + n: s_infos.n(), + base2k: s_infos.base2k(), + k: res_a_infos.k().max(res_b_infos.k()), + rank: s_infos.rank(), + })) + .max(self.vec_znx_big_normalize_tmp_bytes()); + + if res_a_infos.base2k() != s_infos.base2k() { + tot += GLWE::bytes_of_from_infos(&GLWELayout { + n: res_a_infos.n(), + base2k: s_infos.base2k(), + k: res_a_infos.k(), + rank: res_a_infos.rank(), + }); + + tot += GLWE::bytes_of_from_infos(&GLWELayout { + n: res_b_infos.n(), + base2k: s_infos.base2k(), + k: res_b_infos.k(), + rank: res_b_infos.rank(), + }); + } + + tot += self.bytes_of_vec_znx_big(1, s_infos.size()); + + tot + } + + fn cswap(&self, res_a: &mut A, res_b: &mut B, s: &S, scratch: &mut Scratch) + where + A: GLWEToMut, + B: GLWEToMut, + S: GGSWPreparedToRef + GGSWInfos, + Scratch: ScratchTakeCore, + { + let res_a: &mut GLWE<&mut [u8]> = &mut res_a.to_mut(); + let res_b: &mut GLWE<&mut [u8]> = &mut res_b.to_mut(); + let s: &GGSWPrepared<&[u8], BE> = &s.to_ref(); + assert_eq!(res_a.base2k(), res_b.base2k()); + + let res_base2k: usize = res_a.base2k().as_usize(); + let s_base2k: usize = s.base2k().as_usize(); + + if res_base2k == s_base2k { + let res_big: VecZnxBig<&mut [u8], BE>; + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (s.rank() + 1).into(), s.size()); // Todo optimise + { + // Temporary value storing a - b + let tmp_c_infos: GLWELayout = GLWELayout { + n: s.n(), + base2k: s.base2k(), + k: res_a.k().max(res_b.k()), + rank: s.rank(), + }; + let (mut tmp_c, scratch_2) = scratch_1.take_glwe(&tmp_c_infos); + self.glwe_sub(&mut tmp_c, res_b, res_a); + res_big = self.glwe_external_product_internal(res_dft, &tmp_c, s, scratch_2); + } + + // Single column res_big to store temporary value before normalization + let (mut res_big_tmp, scratch_2) = scratch_1.take_vec_znx_big::<_, BE>(self, 1, res_big.size()); + + // res_a = (b-a) * bit + a + for j in 0..(res_a.rank() + 1).into() { + self.vec_znx_big_add_small(&mut res_big_tmp, 0, &res_big, j, res_a.data(), j); + self.vec_znx_big_normalize( + res_base2k, + res_a.data_mut(), + j, + s_base2k, + &res_big_tmp, + 0, + scratch_2, + ); + } + + // res_b = a - (a - b) * bit = (b - a) * bit + a + for j in 0..(res_b.rank() + 1).into() { + self.vec_znx_big_sub_small_a(&mut res_big_tmp, 0, res_b.data(), j, &res_big, j); + self.vec_znx_big_normalize( + res_base2k, + res_b.data_mut(), + j, + s_base2k, + &res_big_tmp, + 0, + scratch_2, + ); + } + } else { + let (mut tmp_a, scratch_1) = scratch.take_glwe(&GLWELayout { + n: res_a.n(), + base2k: s.base2k(), + k: res_a.k(), + rank: res_a.rank(), + }); + + let (mut tmp_b, scratch_2) = scratch_1.take_glwe(&GLWELayout { + n: res_b.n(), + base2k: s.base2k(), + k: res_b.k(), + rank: res_b.rank(), + }); + + self.glwe_normalize(&mut tmp_a, res_a, scratch_2); + self.glwe_normalize(&mut tmp_b, res_b, scratch_2); + + let res_big: VecZnxBig<&mut [u8], BE>; + let (res_dft, scratch_3) = scratch_2.take_vec_znx_dft(self, (s.rank() + 1).into(), s.size()); // Todo optimise + { + // Temporary value storing a - b + let tmp_c_infos: GLWELayout = GLWELayout { + n: s.n(), + base2k: s.base2k(), + k: res_a.k().max(res_b.k()), + rank: s.rank(), + }; + let (mut tmp_c, scratch_4) = scratch_3.take_glwe(&tmp_c_infos); + self.glwe_sub(&mut tmp_c, res_b, res_a); + res_big = self.glwe_external_product_internal(res_dft, &tmp_c, s, scratch_4); + } + + // Single column res_big to store temporary value before normalization + let (mut res_big_tmp, scratch_4) = scratch_3.take_vec_znx_big::<_, BE>(self, 1, res_big.size()); + + // res_a = (b-a) * bit + a + for j in 0..(res_a.rank() + 1).into() { + self.vec_znx_big_add_small(&mut res_big_tmp, 0, &res_big, j, tmp_a.data(), j); + self.vec_znx_big_normalize( + res_base2k, + res_a.data_mut(), + j, + s_base2k, + &res_big_tmp, + 0, + scratch_4, + ); + } + + // res_b = a - (a - b) * bit = (b - a) * bit + a + for j in 0..(res_b.rank() + 1).into() { + self.vec_znx_big_sub_small_a(&mut res_big_tmp, 0, tmp_b.data(), j, &res_big, j); + self.vec_znx_big_normalize( + res_base2k, + res_b.data_mut(), + j, + s_base2k, + &res_big_tmp, + 0, + scratch_4, + ); + } + } + } +} + pub trait Cmux where Self: Sized @@ -284,6 +484,7 @@ where .max(self.vec_znx_big_normalize_tmp_bytes()) } + // res = (t - f) * s + f fn cmux(&self, res: &mut R, t: &T, f: &F, s: &S, scratch: &mut Scratch) where R: GLWEToMut, @@ -316,6 +517,46 @@ where } } + // res = (a - res) * s + res + fn cmux_inplace_neg(&self, res: &mut R, a: &A, s: &S, scratch: &mut Scratch) + where + R: GLWEToMut, + A: GLWEToRef, + S: GGSWPreparedToRef + GGSWInfos, + Scratch: ScratchTakeCore, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let s: &GGSWPrepared<&[u8], BE> = &s.to_ref(); + let a: &GLWE<&[u8]> = &a.to_ref(); + + assert_eq!(res.base2k(), a.base2k()); + + let res_base2k: usize = res.base2k().into(); + let ggsw_base2k: usize = s.base2k().into(); + let (mut tmp, scratch_1) = scratch.take_glwe(&GLWELayout { + n: s.n(), + base2k: res.base2k(), + k: res.k().max(a.k()), + rank: res.rank(), + }); + self.glwe_sub(&mut tmp, a, res); + let (res_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, (res.rank() + 1).into(), s.size()); // Todo optimise + let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_external_product_internal(res_dft, &tmp, s, scratch_2); + for j in 0..(res.rank() + 1).into() { + self.vec_znx_big_add_small_inplace(&mut res_big, j, res.data(), j); + self.vec_znx_big_normalize( + res_base2k, + res.data_mut(), + j, + ggsw_base2k, + &res_big, + j, + scratch_2, + ); + } + } + + // res = (res - a) * s + a fn cmux_inplace(&self, res: &mut R, a: &A, s: &S, scratch: &mut Scratch) where R: GLWEToMut, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs index 7b61bf2..7ef1921 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs @@ -1,4 +1,5 @@ mod bdd_2w_to_1w; +mod blind_retrieval; mod blind_rotation; mod blind_selection; mod ciphertexts; @@ -7,6 +8,7 @@ mod eval; mod key; pub use bdd_2w_to_1w::*; +pub use blind_retrieval::*; pub use blind_rotation::*; pub use blind_selection::*; pub use ciphertexts::*; diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs index cd316d0..736f641 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs @@ -6,7 +6,8 @@ use crate::tfhe::{ bdd_arithmetic::tests::test_suite::{ TestContext, test_bdd_add, test_bdd_and, test_bdd_or, test_bdd_prepare, test_bdd_sll, test_bdd_slt, test_bdd_sltu, test_bdd_sra, test_bdd_srl, test_bdd_sub, test_bdd_xor, test_fhe_uint_get_bit_glwe, test_fhe_uint_sext, - test_fhe_uint_splice_u8, test_fhe_uint_splice_u16, test_glwe_blind_selection, test_glwe_to_glwe_blind_rotation, + test_fhe_uint_splice_u8, test_fhe_uint_splice_u16, test_fhe_uint_swap, test_glwe_blind_retrieval_statefull, + test_glwe_blind_retriever, test_glwe_blind_selection, test_glwe_to_glwe_blind_rotation, test_scalar_to_ggsw_blind_rotation, }, blind_rotation::CGGI, @@ -15,6 +16,21 @@ use crate::tfhe::{ static TEST_CONTEXT_CGGI_FFT64_REF: LazyLock> = LazyLock::new(|| TestContext::::new()); +#[test] +fn test_glwe_blind_retriever_fft64_ref() { + test_glwe_blind_retriever(&TEST_CONTEXT_CGGI_FFT64_REF); +} + +#[test] +fn test_glwe_blind_retrieval_statefull_fft64_ref() { + test_glwe_blind_retrieval_statefull(&TEST_CONTEXT_CGGI_FFT64_REF); +} + +#[test] +fn test_fhe_uint_swap_fft64_ref() { + test_fhe_uint_swap(&TEST_CONTEXT_CGGI_FFT64_REF); +} + #[test] fn test_fhe_uint_get_bit_glwe_fft64_ref() { test_fhe_uint_get_bit_glwe(&TEST_CONTEXT_CGGI_FFT64_REF); diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs index 08de35a..74fb6ce 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs @@ -12,6 +12,7 @@ mod sltu; mod sra; mod srl; mod sub; +mod swap; mod xor; pub use add::*; @@ -33,6 +34,7 @@ pub use sltu::*; pub use sra::*; pub use srl::*; pub use sub::*; +pub use swap::*; pub use xor::*; use poulpy_core::{ diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/swap.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/swap.rs new file mode 100644 index 0000000..a9e5d0c --- /dev/null +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/swap.rs @@ -0,0 +1,203 @@ +use itertools::Itertools; +use poulpy_core::{ + GGSWEncryptSk, GLWEDecrypt, GLWEEncryptSk, + layouts::{GGSW, GGSWPrepared, GGSWPreparedFactory, GLWELayout, GLWESecretPrepared}, +}; +use poulpy_hal::{ + api::{ScratchOwnedAlloc, ScratchOwnedBorrow}, + layouts::{Backend, Module, ScalarZnx, Scratch, ScratchOwned, ZnxViewMut}, + source::Source, +}; +use rand::RngCore; + +use crate::tfhe::{ + bdd_arithmetic::{ + Cswap, FheUint, FheUintPrepared, GLWEBlindRetrieval, GLWEBlindRetriever, ScratchTakeBDD, + tests::test_suite::{TEST_GGSW_INFOS, TEST_GLWE_INFOS, TestContext}, + }, + blind_rotation::BlindRotationAlgo, +}; + +pub fn test_fhe_uint_swap(test_context: &TestContext) +where + Module: GLWEEncryptSk + GLWEDecrypt + Cswap + GGSWEncryptSk + GGSWPreparedFactory, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchTakeBDD, +{ + let glwe_infos: GLWELayout = TEST_GLWE_INFOS; + let ggsw_infos: poulpy_core::layouts::GGSWLayout = TEST_GGSW_INFOS; + + let module: &Module = &test_context.module; + let sk: &GLWESecretPrepared, BE> = &test_context.sk_glwe; + + let mut source_xa: Source = Source::new([2u8; 32]); + let mut source_xe: Source = Source::new([3u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(1 << 22); + + let mut s: GGSW> = GGSW::alloc_from_infos(&ggsw_infos); + let mut s_prepared: GGSWPrepared, BE> = GGSWPrepared::alloc_from_infos(module, &ggsw_infos); + + let a: u32 = source_xa.next_u32(); + let b: u32 = source_xa.next_u32(); + + for bit in [0, 1] { + let mut a_enc: FheUint, u32> = FheUint::, u32>::alloc_from_infos(&glwe_infos); + let mut b_enc: FheUint, u32> = FheUint::, u32>::alloc_from_infos(&glwe_infos); + + a_enc.encrypt_sk( + module, + a, + sk, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + b_enc.encrypt_sk( + module, + b, + sk, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + let mut pt: ScalarZnx> = ScalarZnx::alloc(module.n(), 1); + pt.raw_mut()[0] = bit; + s.encrypt_sk( + module, + &pt, + sk, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + s_prepared.prepare(module, &s, scratch.borrow()); + + module.cswap(&mut a_enc, &mut b_enc, &s_prepared, scratch.borrow()); + + let (a_want, b_want) = if bit == 0 { (a, b) } else { (b, a) }; + + assert_eq!(a_want, a_enc.decrypt(module, sk, scratch.borrow())); + assert_eq!(b_want, b_enc.decrypt(module, sk, scratch.borrow())); + } +} + +pub fn test_glwe_blind_retrieval_statefull(test_context: &TestContext) +where + Module: GLWEEncryptSk + GLWEDecrypt + GLWEBlindRetrieval + GGSWEncryptSk + GGSWPreparedFactory, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchTakeBDD, +{ + let glwe_infos: GLWELayout = TEST_GLWE_INFOS; + let ggsw_infos: poulpy_core::layouts::GGSWLayout = TEST_GGSW_INFOS; + + let module: &Module = &test_context.module; + let sk: &GLWESecretPrepared, BE> = &test_context.sk_glwe; + + let mut source_xa: Source = Source::new([2u8; 32]); + let mut source_xe: Source = Source::new([3u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(1 << 22); + + let data: Vec = (0..25).map(|i| i as u32).collect_vec(); + + let mut data_enc: Vec, u32>> = (0..data.len()) + .map(|i| { + let mut ct: FheUint, u32> = FheUint::, u32>::alloc_from_infos(&glwe_infos); + ct.encrypt_sk( + module, + data[i], + sk, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + ct + }) + .collect_vec(); + + for idx in 0..data.len() as u32 { + let mut idx_enc = FheUintPrepared::alloc_from_infos(module, &ggsw_infos); + idx_enc.encrypt_sk( + module, + idx, + sk, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + module.glwe_blind_retrieval_statefull(&mut data_enc, &idx_enc, 0, 5, scratch.borrow()); + + assert_eq!( + data[idx as usize], + data_enc[0].decrypt(module, sk, scratch.borrow()) + ); + + module.glwe_blind_retrieval_statefull_rev(&mut data_enc, &idx_enc, 0, 5, scratch.borrow()); + + for i in 0..data.len() { + assert_eq!(data[i], data_enc[i].decrypt(module, sk, scratch.borrow())) + } + } +} + +pub fn test_glwe_blind_retriever(test_context: &TestContext) +where + Module: GLWEEncryptSk + GLWEDecrypt + GLWEBlindRetrieval + GGSWEncryptSk + GGSWPreparedFactory, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchTakeBDD, +{ + let glwe_infos: GLWELayout = TEST_GLWE_INFOS; + let ggsw_infos: poulpy_core::layouts::GGSWLayout = TEST_GGSW_INFOS; + + let module: &Module = &test_context.module; + let sk: &GLWESecretPrepared, BE> = &test_context.sk_glwe; + + let mut source_xa: Source = Source::new([2u8; 32]); + let mut source_xe: Source = Source::new([3u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(1 << 22); + + let data: Vec = (0..25).map(|i| i as u32).collect_vec(); + + let data_enc: Vec, u32>> = (0..data.len()) + .map(|i| { + let mut ct: FheUint, u32> = FheUint::, u32>::alloc_from_infos(&glwe_infos); + ct.encrypt_sk( + module, + data[i], + sk, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + ct + }) + .collect_vec(); + + for idx in 0..data.len() as u32 { + let mut idx_enc: FheUintPrepared, u32, BE> = FheUintPrepared::alloc_from_infos(module, &ggsw_infos); + idx_enc.encrypt_sk( + module, + idx, + sk, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + let mut retriever: GLWEBlindRetriever = GLWEBlindRetriever::alloc(&glwe_infos, 25); + let mut res: FheUint, u32> = FheUint::alloc_from_infos(&glwe_infos); + retriever.retrieve(module, &mut res, &data_enc, &idx_enc, scratch.borrow()); + + println!("{}", res.decrypt(module, sk, scratch.borrow())); + + assert_eq!( + data[idx as usize], + res.decrypt(module, sk, scratch.borrow()) + ); + } +}