diff --git a/poulpy-core/src/layouts/lwe_secret.rs b/poulpy-core/src/layouts/lwe_secret.rs index 78ad9c7..00a7849 100644 --- a/poulpy-core/src/layouts/lwe_secret.rs +++ b/poulpy-core/src/layouts/lwe_secret.rs @@ -4,6 +4,7 @@ use poulpy_hal::{ }; use crate::{ + GetDistribution, dist::Distribution, layouts::{Base2K, Degree, LWEInfos, TorusPrecision}, }; @@ -22,6 +23,12 @@ impl LWESecret> { } } +impl GetDistribution for LWESecret { + fn dist(&self) -> &Distribution { + &self.dist + } +} + impl LWESecret { pub fn raw(&self) -> &[i64] { self.data.at(0, 0) diff --git a/poulpy-core/src/lib.rs b/poulpy-core/src/lib.rs index 78fb717..a5c5152 100644 --- a/poulpy-core/src/lib.rs +++ b/poulpy-core/src/lib.rs @@ -23,9 +23,8 @@ pub use external_product::*; pub use glwe_packing::*; pub use keyswitching::*; pub use noise::*; +pub use scratch::*; pub use encryption::SIGMA; -pub use scratch::*; - pub mod tests; diff --git a/poulpy-core/src/operations/glwe.rs b/poulpy-core/src/operations/glwe.rs index 417e5f2..c6f1818 100644 --- a/poulpy-core/src/operations/glwe.rs +++ b/poulpy-core/src/operations/glwe.rs @@ -217,6 +217,8 @@ where } } +impl GLWEMulXpMinusOne for Module where Self: ModuleN + VecZnxMulXpMinusOne + VecZnxMulXpMinusOneInplace {} + pub trait GLWEMulXpMinusOne where Self: ModuleN + VecZnxMulXpMinusOne + VecZnxMulXpMinusOneInplace, diff --git a/poulpy-hal/src/api/scratch.rs b/poulpy-hal/src/api/scratch.rs index 714a58d..9e3c484 100644 --- a/poulpy-hal/src/api/scratch.rs +++ b/poulpy-hal/src/api/scratch.rs @@ -34,13 +34,9 @@ pub trait ScratchTakeBasic where Self: TakeSlice, { - fn take_scalar_znx(&mut self, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) - { + fn take_scalar_znx(&mut self, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) { let (take_slice, rem_slice) = self.take_slice(ScalarZnx::bytes_of(n, cols)); - ( - ScalarZnx::from_data(take_slice, n, cols), - rem_slice, - ) + (ScalarZnx::from_data(take_slice, n, cols), rem_slice) } fn take_svp_ppol(&mut self, module: &M, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self) @@ -51,12 +47,9 @@ where (SvpPPol::from_data(take_slice, module.n(), cols), rem_slice) } - fn take_vec_znx(&mut self, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self){ + fn take_vec_znx(&mut self, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) { let (take_slice, rem_slice) = self.take_slice(VecZnx::bytes_of(n, cols, size)); - ( - VecZnx::from_data(take_slice, n, cols, size), - rem_slice, - ) + (VecZnx::from_data(take_slice, n, cols, size), rem_slice) } fn take_vec_znx_big(&mut self, module: &M, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self) @@ -102,7 +95,7 @@ where (slice, scratch) } - fn take_vec_znx_slice(&mut self, n: usize, len: usize, cols: usize, size: usize) -> (Vec>, &mut Self){ + fn take_vec_znx_slice(&mut self, len: usize, n: usize, cols: usize, size: usize) -> (Vec>, &mut Self) { let mut scratch: &mut Self = self; let mut slice: Vec> = Vec::with_capacity(len); for _ in 0..len { @@ -133,13 +126,12 @@ where fn take_mat_znx( &mut self, - n: usize, + n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize, - ) -> (MatZnx<&mut [u8]>, &mut Self) - { + ) -> (MatZnx<&mut [u8]>, &mut Self) { let (take_slice, rem_slice) = self.take_slice(MatZnx::bytes_of(n, rows, cols_in, cols_out, size)); ( MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size), diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block.rs index 70bd910..3e07242 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block.rs @@ -2,17 +2,19 @@ use std::marker::PhantomData; use poulpy_core::layouts::{Base2K, GLWE, GLWEInfos, GLWEPlaintextLayout, LWEInfos, Rank, TorusPrecision}; -use poulpy_core::{TakeGLWEPlaintext, layouts::prepared::GLWESecretPrepared}; +#[cfg(test)] +use poulpy_core::ScratchTakeCore; +use poulpy_core::{layouts::prepared::GLWESecretPrepared}; use poulpy_hal::api::VecZnxBigBytesOf; #[cfg(test)] use poulpy_hal::api::{ - ScratchAvailable, TakeVecZnx, VecZnxAddInplace, VecZnxAddNormal, VecZnxFillUniform, VecZnxNormalize, VecZnxSub, + VecZnxAddInplace, VecZnxAddNormal, VecZnxFillUniform, VecZnxNormalize, VecZnxSub, }; #[cfg(test)] use poulpy_hal::source::Source; use poulpy_hal::{ api::{ - TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxDftApply, + VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalizeTmpBytes, }, layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}, @@ -96,7 +98,7 @@ impl FheUintBlocks { + VecZnxAddNormal + VecZnxNormalize + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeGLWEPlaintext, + Scratch: ScratchTakeCore, { use poulpy_core::layouts::GLWEPlaintextLayout; @@ -136,7 +138,7 @@ impl FheUintBlocks { + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxBigNormalize, - Scratch: TakeVecZnxDft + TakeVecZnxBig + TakeGLWEPlaintext, + Scratch: ScratchTakeCore, { #[cfg(debug_assertions)] { @@ -186,7 +188,7 @@ impl FheUintBlocks { + VecZnxNormalizeTmpBytes + VecZnxSubInplace + VecZnxNormalizeInplace, - Scratch: TakeGLWEPlaintext + TakeVecZnxDft + TakeVecZnxBig, + Scratch: ScratchTakeCore, { #[cfg(debug_assertions)] { diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block_prepared.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block_prepared.rs index aa70910..1ed7229 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block_prepared.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block_prepared.rs @@ -144,7 +144,7 @@ impl FheUintBlocksPrep FheUintWord { + VecZnxAutomorphismInplace + VecZnxBigSubSmallNegateInplace + VecZnxRotate, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeGLWESlice, + Scratch: ScratchTakeCore, { // Repacks the GLWE ciphertexts bits let gap: usize = module.n() / T::WORD_SIZE; @@ -122,7 +120,7 @@ impl FheUintWord { + VecZnxAddNormal + VecZnxNormalize + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeGLWEPlaintext, + Scratch: ScratchTakeCore, { #[cfg(debug_assertions)] { @@ -167,7 +165,7 @@ impl FheUintWord { + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxBigNormalize, - Scratch: TakeVecZnxDft + TakeVecZnxBig + TakeGLWEPlaintext, + Scratch: ScratchTakeCore, { #[cfg(debug_assertions)] { diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs index 245fe03..953ee7a 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs @@ -1,10 +1,8 @@ use itertools::Itertools; use poulpy_core::{ - GLWEExternalProductInplace, GLWEOperations, TakeGLWESlice, layouts::{ - GLWE, GLWEToMut, LWEInfos, - prepared::{GGSWPrepared, GGSWPreparedToRef}, - }, + prepared::{GGSWPrepared, GGSWPreparedToRef}, GLWEToMut, LWEInfos, GLWE + }, GLWEExternalProduct, ScratchTakeCore }; use poulpy_hal::{ api::{VecZnxAddInplace, VecZnxCopy, VecZnxNegateInplace, VecZnxSub}, @@ -49,7 +47,7 @@ impl Circuit where Self: GetBitCircuitInfo, Module: Cmux + VecZnxCopy, - Scratch: TakeGLWESlice, + Scratch: ScratchTakeCore, { fn execute( &self, @@ -169,7 +167,7 @@ pub trait Cmux { impl Cmux for Module where - Module: GLWEExternalProductInplace + VecZnxSub + VecZnxCopy + VecZnxNegateInplace + VecZnxAddInplace, + Module: GLWEExternalProduct + VecZnxSub + VecZnxCopy + VecZnxNegateInplace + VecZnxAddInplace, { fn cmux(&self, out: &mut GLWE, t: &GLWE, f: &GLWE, s: &GGSWPrepared, scratch: &mut Scratch) where diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs index fdcf5b3..e2b8453 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs @@ -9,16 +9,13 @@ use crate::tfhe::{ }, }; use poulpy_core::{ - TakeGGSW, TakeGLWE, layouts::{ - GLWESecret, GLWEToLWEKeyLayout, GLWEToLWESwitchingKey, LWE, LWESecret, - prepared::{GLWEToLWESwitchingKeyPrepared, Prepare, PrepareAlloc}, - }, + prepared::GLWEToLWESwitchingKeyPrepared, GLWESecret, GLWEToLWEKeyLayout, GLWEToLWESwitchingKey, LWESecret + }, ScratchTakeCore, }; use poulpy_hal::{ api::{ - ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, TakeScalarZnx, - TakeSvpPPol, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, + ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, @@ -96,7 +93,7 @@ impl BDDKey, Vec, BRA> { + SvpPPolAlloc + VecZnxAutomorphism + VecZnxAutomorphismInplace, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeSvpPPol + TakeVecZnxBig, + Scratch: ScratchTakeCore, { let mut ks: GLWEToLWESwitchingKey> = GLWEToLWESwitchingKey::alloc(&infos.ks_infos()); ks.encrypt_sk(module, sk_lwe, sk_glwe, source_xa, source_xe, scratch); @@ -217,7 +214,7 @@ where + VecZnxBigNormalize + VecZnxNormalize + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeGLWE + TakeVecZnx + TakeGGSW, + Scratch: ScratchTakeCore, CircuitBootstrappingKeyPrepared: CirtuitBootstrappingExecute, { fn prepare( diff --git a/poulpy-schemes/src/tfhe/blind_rotation/cggi_algo.rs b/poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/algorithm.rs similarity index 55% rename from poulpy-schemes/src/tfhe/blind_rotation/cggi_algo.rs rename to poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/algorithm.rs index 8bf0a9e..b9ec277 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/cggi_algo.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/algorithm.rs @@ -1,158 +1,142 @@ use itertools::izip; use poulpy_hal::{ api::{ - ScratchAvailable, SvpApplyDftToDft, SvpPPolBytesOf, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, - TakeVecZnxSlice, VecZnxAddInplace, VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftApply, VecZnxDftBytesOf, - VecZnxDftSubInplace, VecZnxDftZero, VecZnxIdftApply, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpBytes, - VecZnxMulXpMinusOneInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, - VecZnxSubInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + ModuleN, ScratchTakeBasic, SvpApplyDftToDft, VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize, + VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAddInplace, VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftSubInplace, + VecZnxDftZero, VecZnxIdftApply, VecZnxIdftApplyTmpBytes, VecZnxRotate, VmpApplyDftToDft, VmpApplyDftToDftTmpBytes, }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, SvpPPol, VecZnx, ZnxView, ZnxZero}, + layouts::{Backend, DataMut, DataRef, Module, Scratch, SvpPPol, VecZnx, ZnxZero}, }; use poulpy_core::{ - Distribution, GLWEOperations, TakeGLWE, + Distribution, GLWEAdd, GLWEExternalProduct, GLWEMulXpMinusOne, GLWENormalize, ScratchTakeCore, layouts::{GGSWInfos, GLWE, GLWEInfos, GLWEToMut, LWE, LWEInfos, LWEToRef}, }; use crate::tfhe::blind_rotation::{ - BlincRotationExecute, BlindRotationKeyInfos, BlindRotationKeyPrepared, CGGI, LookUpTable, LookUpTableRotationDirection, + BlindRotationExecute, BlindRotationKeyInfos, BlindRotationKeyPrepared, CGGI, LookupTable, mod_switch_2n, }; -#[allow(clippy::too_many_arguments)] -pub fn cggi_blind_rotate_tmp_bytes( - module: &Module, - block_size: usize, - extension_factor: usize, - glwe_infos: &OUT, - brk_infos: &GGSW, -) -> usize +impl BlindRotationExecute for Module where - OUT: GLWEInfos, - GGSW: GGSWInfos, - Module: VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxNormalizeTmpBytes + Self: VecZnxDftBytesOf + VecZnxBigBytesOf - + VecZnxIdftApplyTmpBytes - + VecZnxBigNormalizeTmpBytes, -{ - let brk_size: usize = brk_infos.size(); - - if block_size > 1 { - let cols: usize = (brk_infos.rank() + 1).into(); - let dnum: usize = brk_infos.dnum().into(); - let acc_dft: usize = module.bytes_of_vec_znx_dft(cols, dnum) * extension_factor; - let acc_big: usize = module.bytes_of_vec_znx_big(1, brk_size); - let vmp_res: usize = module.bytes_of_vec_znx_dft(cols, brk_size) * extension_factor; - let vmp_xai: usize = module.bytes_of_vec_znx_dft(1, brk_size); - let acc_dft_add: usize = vmp_res; - let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes(brk_size, dnum, dnum, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2) - let acc: usize = if extension_factor > 1 { - VecZnx::bytes_of(module.n(), cols, glwe_infos.size()) * extension_factor - } else { - 0 - }; - - acc + acc_dft - + acc_dft_add - + vmp_res - + vmp_xai - + (vmp | (acc_big + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_apply_tmp_bytes()))) - } else { - GLWE::bytes_of(glwe_infos) + GLWE::external_product_inplace_tmp_bytes(module, glwe_infos, brk_infos) - } -} - -impl BlincRotationExecute for BlindRotationKeyPrepared -where - Module: VecZnxBigBytesOf - + VecZnxDftBytesOf - + SvpPPolBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxIdftApplyTmpBytes - + VecZnxIdftApply - + VecZnxDftAdd - + VecZnxDftAddInplace - + VecZnxDftApply - + VecZnxDftZero - + SvpApplyDftToDft - + VecZnxDftSubInplace - + VecZnxBigAddSmallInplace + + GLWEExternalProduct + + ModuleN + VecZnxRotate - + VecZnxAddInplace - + VecZnxSubInplace - + VecZnxNormalize - + VecZnxNormalizeInplace + + VecZnxDftApply + + VecZnxDftZero + + VmpApplyDftToDft + + SvpApplyDftToDft + + VecZnxDftAddInplace + + VecZnxDftSubInplace + + VecZnxIdftApply + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize + VecZnxCopy - + VecZnxMulXpMinusOneInplace - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxIdftApplyConsume - + VecZnxBigNormalize - + VecZnxNormalizeTmpBytes, - Scratch: TakeVecZnxDftSlice + TakeVecZnxDft + TakeVecZnxBig + TakeVecZnxSlice + TakeVecZnx + ScratchAvailable, + + GLWEMulXpMinusOne + + GLWEAdd + + GLWENormalize, + Scratch: ScratchTakeCore, { - fn execute( + fn blind_rotation_execute_tmp_bytes( + &self, + block_size: usize, + extension_factor: usize, + glwe_infos: &G, + brk_infos: &B, + ) -> usize + where + G: GLWEInfos, + B: GGSWInfos, + { + let brk_size: usize = brk_infos.size(); + + if block_size > 1 { + let cols: usize = (brk_infos.rank() + 1).into(); + let dnum: usize = brk_infos.dnum().into(); + let acc_dft: usize = self.bytes_of_vec_znx_dft(cols, dnum) * extension_factor; + let acc_big: usize = self.bytes_of_vec_znx_big(1, brk_size); + let vmp_res: usize = self.bytes_of_vec_znx_dft(cols, brk_size) * extension_factor; + let vmp_xai: usize = self.bytes_of_vec_znx_dft(1, brk_size); + let acc_dft_add: usize = vmp_res; + let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes(brk_size, dnum, dnum, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2) + let acc: usize = if extension_factor > 1 { + VecZnx::bytes_of(self.n(), cols, glwe_infos.size()) * extension_factor + } else { + 0 + }; + + acc + acc_dft + + acc_dft_add + + vmp_res + + vmp_xai + + (vmp + | (acc_big + + (self + .vec_znx_big_normalize_tmp_bytes() + .max(self.vec_znx_idft_apply_tmp_bytes())))) + } else { + GLWE::bytes_of_from_infos(glwe_infos) + GLWE::external_product_tmp_bytes(self, glwe_infos, glwe_infos, brk_infos) + } + } + + fn blind_rotation_execute( &self, - module: &Module, res: &mut GLWE, - lwe: &LWE, - lut: &LookUpTable, - scratch: &mut Scratch, - ) { - match self.dist { + lwe: &LWE
, + lut: &LookupTable, + brk: &BlindRotationKeyPrepared, + scratch: &mut Scratch, + ) where + DR: DataMut, + DL: DataRef, + DB: DataRef, + { + match brk.dist { Distribution::BinaryBlock(_) | Distribution::BinaryFixed(_) | Distribution::BinaryProb(_) | Distribution::ZERO => { if lut.extension_factor() > 1 { - execute_block_binary_extended(module, res, lwe, lut, self, scratch) - } else if self.block_size() > 1 { - execute_block_binary(module, res, lwe, lut, self, scratch); + execute_block_binary_extended(self, res, lwe, lut, brk, scratch) + } else if brk.block_size() > 1 { + execute_block_binary(self, res, lwe, lut, brk, scratch); } else { - execute_standard(module, res, lwe, lut, self, scratch); + execute_standard(self, res, lwe, lut, brk, scratch); } } - _ => panic!("invalid CGGI distribution"), + _ => panic!("invalid CGGI distribution (have you prepared the key?)"), } } } -fn execute_block_binary_extended( - module: &Module, +fn execute_block_binary_extended( + module: &M, res: &mut GLWE, lwe: &LWE, - lut: &LookUpTable, - brk: &BlindRotationKeyPrepared, - scratch: &mut Scratch, + lut: &LookupTable, + brk: &BlindRotationKeyPrepared, + scratch: &mut Scratch, ) where DataRes: DataMut, DataIn: DataRef, DataBrk: DataRef, - Module: VecZnxBigBytesOf - + VecZnxDftBytesOf - + SvpPPolBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VecZnxIdftApplyTmpBytes - + VecZnxIdftApply - + VecZnxDftAdd - + VecZnxDftAddInplace - + VecZnxDftApply - + VecZnxDftZero - + SvpApplyDftToDft - + VecZnxDftSubInplace - + VecZnxBigAddSmallInplace + M: VecZnxDftBytesOf + + ModuleN + VecZnxRotate - + VecZnxAddInplace - + VecZnxSubInplace - + VecZnxNormalize - + VecZnxNormalizeInplace + + VecZnxDftApply + + VecZnxDftZero + + VmpApplyDftToDft + + SvpApplyDftToDft + + VecZnxDftAddInplace + + VecZnxDftSubInplace + + VecZnxIdftApply + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize + VecZnxCopy - + VecZnxMulXpMinusOneInplace - + VecZnxBigNormalize - + VmpApplyDftToDft, - Scratch: TakeVecZnxDftSlice + TakeVecZnxDft + TakeVecZnxBig + TakeVecZnxSlice + ScratchAvailable + TakeVecZnx, + + VecZnxBigBytesOf, + Scratch: ScratchTakeCore, { let n_glwe: usize = brk.n_glwe().into(); let extension_factor: usize = lut.extension_factor(); @@ -161,16 +145,16 @@ fn execute_block_binary_extended( let cols: usize = (res.rank() + 1).into(); let (mut acc, scratch_1) = scratch.take_vec_znx_slice(extension_factor, n_glwe, cols, res.size()); - let (mut acc_dft, scratch_2) = scratch_1.take_vec_znx_dft_slice(extension_factor, n_glwe, cols, dnum); - let (mut vmp_res, scratch_3) = scratch_2.take_vec_znx_dft_slice(extension_factor, n_glwe, cols, brk.size()); - let (mut acc_add_dft, scratch_4) = scratch_3.take_vec_znx_dft_slice(extension_factor, n_glwe, cols, brk.size()); - let (mut vmp_xai, scratch_5) = scratch_4.take_vec_znx_dft(n_glwe, 1, brk.size()); + let (mut acc_dft, scratch_2) = scratch_1.take_vec_znx_dft_slice(module, extension_factor, cols, dnum); + let (mut vmp_res, scratch_3) = scratch_2.take_vec_znx_dft_slice(module, extension_factor, cols, brk.size()); + let (mut acc_add_dft, scratch_4) = scratch_3.take_vec_znx_dft_slice(module, extension_factor, cols, brk.size()); + let (mut vmp_xai, scratch_5) = scratch_4.take_vec_znx_dft(module, 1, brk.size()); (0..extension_factor).for_each(|i| { acc[i].zero(); }); - let x_pow_a: &Vec, B>>; + let x_pow_a: &Vec, BE>>; if let Some(b) = &brk.x_pow_a { x_pow_a = b } else { @@ -268,7 +252,7 @@ fn execute_block_binary_extended( }); { - let (mut acc_add_big, scratch7) = scratch_5.take_vec_znx_big(n_glwe, 1, brk.size()); + let (mut acc_add_big, scratch7) = scratch_5.take_vec_znx_big(module, 1, brk.size()); (0..extension_factor).for_each(|j| { (0..cols).for_each(|i| { @@ -285,41 +269,32 @@ fn execute_block_binary_extended( }); } -fn execute_block_binary( - module: &Module, +fn execute_block_binary( + module: &M, res: &mut GLWE, lwe: &LWE, - lut: &LookUpTable, - brk: &BlindRotationKeyPrepared, - scratch: &mut Scratch, + lut: &LookupTable, + brk: &BlindRotationKeyPrepared, + scratch: &mut Scratch, ) where DataRes: DataMut, DataIn: DataRef, DataBrk: DataRef, - Module: VecZnxBigBytesOf - + VecZnxDftBytesOf - + SvpPPolBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VecZnxIdftApplyTmpBytes - + VecZnxIdftApply - + VecZnxDftAdd - + VecZnxDftAddInplace - + VecZnxDftApply - + VecZnxDftZero - + SvpApplyDftToDft - + VecZnxDftSubInplace - + VecZnxBigAddSmallInplace + M: VecZnxDftBytesOf + + ModuleN + VecZnxRotate - + VecZnxAddInplace - + VecZnxSubInplace - + VecZnxNormalize - + VecZnxNormalizeInplace + + VecZnxDftApply + + VecZnxDftZero + + VmpApplyDftToDft + + SvpApplyDftToDft + + VecZnxDftAddInplace + + VecZnxDftSubInplace + + VecZnxIdftApply + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize + VecZnxCopy - + VecZnxMulXpMinusOneInplace - + VmpApplyDftToDft - + VecZnxBigNormalize, - Scratch: TakeVecZnxDftSlice + TakeVecZnxDft + TakeVecZnxBig + TakeVecZnxSlice + ScratchAvailable + TakeVecZnx, + + VecZnxBigBytesOf, + Scratch: ScratchTakeCore, { let n_glwe: usize = brk.n_glwe().into(); let mut lwe_2n: Vec = vec![0i64; (lwe.n() + 1).into()]; // TODO: from scratch space @@ -350,12 +325,12 @@ fn execute_block_binary( // ACC + [sum DFT(X^ai -1) * (DFT(ACC) x BRKi)] - let (mut acc_dft, scratch_1) = scratch.take_vec_znx_dft(n_glwe, cols, dnum); - let (mut vmp_res, scratch_2) = scratch_1.take_vec_znx_dft(n_glwe, cols, brk.size()); - let (mut acc_add_dft, scratch_3) = scratch_2.take_vec_znx_dft(n_glwe, cols, brk.size()); - let (mut vmp_xai, scratch_4) = scratch_3.take_vec_znx_dft(n_glwe, 1, brk.size()); + let (mut acc_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols, dnum); + let (mut vmp_res, scratch_2) = scratch_1.take_vec_znx_dft(module, cols, brk.size()); + let (mut acc_add_dft, scratch_3) = scratch_2.take_vec_znx_dft(module, cols, brk.size()); + let (mut vmp_xai, scratch_4) = scratch_3.take_vec_znx_dft(module, 1, brk.size()); - let x_pow_a: &Vec, B>>; + let x_pow_a: &Vec, BE>>; if let Some(b) = &brk.x_pow_a { x_pow_a = b } else { @@ -388,7 +363,7 @@ fn execute_block_binary( }); { - let (mut acc_add_big, scratch_5) = scratch_4.take_vec_znx_big(n_glwe, 1, brk.size()); + let (mut acc_add_big, scratch_5) = scratch_4.take_vec_znx_big(module, 1, brk.size()); (0..cols).for_each(|i| { module.vec_znx_idft_apply(&mut acc_add_big, 0, &acc_add_dft, i, scratch_5); @@ -407,44 +382,19 @@ fn execute_block_binary( }); } -fn execute_standard( - module: &Module, +fn execute_standard( + module: &M, res: &mut GLWE, lwe: &LWE, - lut: &LookUpTable, - brk: &BlindRotationKeyPrepared, - scratch: &mut Scratch, + lut: &LookupTable, + brk: &BlindRotationKeyPrepared, + scratch: &mut Scratch, ) where DataRes: DataMut, DataIn: DataRef, DataBrk: DataRef, - Module: VecZnxBigBytesOf - + VecZnxDftBytesOf - + SvpPPolBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VecZnxIdftApplyTmpBytes - + VecZnxIdftApply - + VecZnxDftAdd - + VecZnxDftAddInplace - + VecZnxDftApply - + VecZnxDftZero - + SvpApplyDftToDft - + VecZnxDftSubInplace - + VecZnxBigAddSmallInplace - + VecZnxRotate - + VecZnxAddInplace - + VecZnxSubInplace - + VecZnxNormalize - + VecZnxNormalizeInplace - + VecZnxCopy - + VecZnxMulXpMinusOneInplace - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxIdftApplyConsume - + VecZnxBigNormalize - + VecZnxNormalizeTmpBytes, - Scratch: TakeVecZnxDftSlice + TakeVecZnxDft + TakeVecZnxBig + TakeVecZnxSlice + ScratchAvailable + TakeVecZnx, + M: VecZnxRotate + GLWEExternalProduct + GLWEMulXpMinusOne + GLWEAdd + GLWENormalize, + Scratch: ScratchTakeCore, { #[cfg(debug_assertions)] { @@ -498,7 +448,7 @@ fn execute_standard( module.vec_znx_rotate(b, out_mut.data_mut(), 0, &lut.data[0], 0); // ACC + [sum DFT(X^ai -1) * (DFT(ACC) x BRKi)] - let (mut acc_tmp, scratch_1) = scratch.take_glwe_ct(&out_mut); + let (mut acc_tmp, scratch_1) = scratch.take_glwe(&out_mut); // TODO: see if faster by skipping normalization in external product and keeping acc in big coeffs // TODO: first iteration can be optimized to be a gglwe product @@ -507,55 +457,13 @@ fn execute_standard( acc_tmp.external_product(module, &out_mut, ski, scratch_1); // acc_tmp = (sk[i] * acc) * (X^{ai} - 1) - acc_tmp.mul_xp_minus_one_inplace(module, *ai, scratch_1); + module.glwe_mul_xp_minus_one_inplace(*ai, &mut acc_tmp, scratch_1); // acc = acc + (sk[i] * acc) * (X^{ai} - 1) - out_mut.add_inplace(module, &acc_tmp); + module.glwe_add_inplace(&mut out_mut, &acc_tmp); }); // We can normalize only at the end because we add normalized values in [-2^{base2k-1}, 2^{base2k-1}] // on top of each others, thus ~ 2^{63-base2k} additions are supported before overflow. - out_mut.normalize_inplace(module, scratch_1); -} - -pub fn mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWE<&[u8]>, rot_dir: LookUpTableRotationDirection) { - let base2k: usize = lwe.base2k().into(); - - let log2n: usize = usize::BITS as usize - (n - 1).leading_zeros() as usize + 1; - - res.copy_from_slice(lwe.data().at(0, 0)); - - match rot_dir { - LookUpTableRotationDirection::Left => { - res.iter_mut().for_each(|x| *x = -*x); - } - LookUpTableRotationDirection::Right => {} - } - - if base2k > log2n { - let diff: usize = base2k - (log2n - 1); // additional -1 because we map to [-N/2, N/2) instead of [0, N) - res.iter_mut().for_each(|x| { - *x = div_round_by_pow2(x, diff); - }) - } else { - let rem: usize = base2k - (log2n % base2k); - let size: usize = log2n.div_ceil(base2k); - (1..size).for_each(|i| { - if i == size - 1 && rem != base2k { - let k_rem: usize = base2k - rem; - izip!(lwe.data().at(0, i).iter(), res.iter_mut()).for_each(|(x, y)| { - *y = (*y << k_rem) + (x >> rem); - }); - } else { - izip!(lwe.data().at(0, i).iter(), res.iter_mut()).for_each(|(x, y)| { - *y = (*y << base2k) + x; - }); - } - }) - } -} - -#[inline(always)] -fn div_round_by_pow2(x: &i64, k: usize) -> i64 { - (x + (1 << (k - 1))) >> k + module.glwe_normalize_inplace(&mut out_mut, scratch_1); } diff --git a/poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/key.rs b/poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/key.rs new file mode 100644 index 0000000..830a829 --- /dev/null +++ b/poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/key.rs @@ -0,0 +1,79 @@ +use poulpy_hal::{ + layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxView, ZnxViewMut}, + source::Source, +}; + +use std::marker::PhantomData; + +use poulpy_core::{ + Distribution, GGSWEncryptSk, GetDistribution, ScratchTakeCore, + layouts::{GGSW, GGSWInfos, GLWEInfos, GLWESecretPreparedToRef, LWEInfos, LWESecret, LWESecretToRef}, +}; + +use crate::tfhe::blind_rotation::{ + BlindRotationKey, BlindRotationKeyEncryptSk, BlindRotationKeyFactory, BlindRotationKeyInfos, CGGI, +}; + +impl BlindRotationKeyFactory for BlindRotationKey { + fn blind_rotation_key_alloc(infos: &A) -> BlindRotationKey, CGGI> + where + A: BlindRotationKeyInfos, + { + BlindRotationKey { + keys: (0..infos.n_lwe().as_usize()) + .map(|_| GGSW::alloc_from_infos(infos)) + .collect(), + dist: Distribution::NONE, + _phantom: PhantomData, + } + } +} + +impl BlindRotationKeyEncryptSk for Module +where + Self: GGSWEncryptSk, + Scratch: ScratchTakeCore, +{ + fn blind_rotation_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize { + self.ggsw_encrypt_sk_tmp_bytes(infos) + } + + fn blind_rotation_key_encrypt_sk( + &self, + res: &mut BlindRotationKey, + sk_glwe: &S0, + sk_lwe: &S1, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + D: DataMut, + S0: GLWESecretPreparedToRef + GLWEInfos, + S1: LWESecretToRef + LWEInfos + GetDistribution, + { + assert_eq!(res.keys.len() as u32, sk_lwe.n()); + assert!(sk_glwe.n() <= self.n() as u32); + assert_eq!(sk_glwe.rank(), res.rank()); + + match sk_lwe.dist() { + Distribution::BinaryBlock(_) | Distribution::BinaryFixed(_) | Distribution::BinaryProb(_) | Distribution::ZERO => {} + _ => { + panic!("invalid GLWESecret distribution: must be BinaryBlock, BinaryFixed or BinaryProb (or ZERO for debugging)") + } + } + + { + let sk_lwe: &LWESecret<&[u8]> = &sk_lwe.to_ref(); + + res.dist = sk_lwe.dist(); + + let mut pt: ScalarZnx> = ScalarZnx::alloc(sk_glwe.n().into(), 1); + let sk_ref: ScalarZnx<&[u8]> = sk_lwe.data().to_ref(); + + for (i, ggsw) in res.keys.iter_mut().enumerate() { + pt.at_mut(0, 0)[0] = sk_ref.at(0, 0)[i]; + ggsw.encrypt_sk(self, &pt, sk_glwe, source_xa, source_xe, scratch); + } + } + } +} diff --git a/poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/key_compressed.rs b/poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/key_compressed.rs new file mode 100644 index 0000000..255f8f2 --- /dev/null +++ b/poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/key_compressed.rs @@ -0,0 +1,84 @@ +use std::marker::PhantomData; + +use poulpy_core::{ + Distribution, GGSWCompressedEncryptSk, GetDistribution, ScratchTakeCore, + layouts::{GGSWCompressed, GGSWInfos, GLWEInfos, GLWESecretPreparedToRef, LWEInfos, LWESecret, LWESecretToRef}, +}; +use poulpy_hal::{ + layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxView, ZnxViewMut}, + source::Source, +}; + +use crate::tfhe::blind_rotation::{ + BlindRotationKeyCompressed, BlindRotationKeyCompressedEncryptSk, BlindRotationKeyCompressedFactory, BlindRotationKeyInfos, + CGGI, +}; + +impl BlindRotationKeyCompressedFactory for BlindRotationKeyCompressed { + fn blind_rotation_key_compressed_alloc(infos: &A) -> BlindRotationKeyCompressed, CGGI> + where + A: BlindRotationKeyInfos, + { + let mut data: Vec>> = Vec::with_capacity(infos.n_lwe().into()); + (0..infos.n_lwe().as_usize()).for_each(|_| data.push(GGSWCompressed::alloc_from_infos(infos))); + BlindRotationKeyCompressed { + keys: data, + dist: Distribution::NONE, + _phantom: PhantomData, + } + } +} + +impl BlindRotationKeyCompressedEncryptSk for Module +where + Self: GGSWCompressedEncryptSk, + Scratch: ScratchTakeCore, +{ + fn blind_rotation_key_compressed_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGSWInfos, + { + self.ggsw_compressed_encrypt_sk_tmp_bytes(infos) + } + + fn blind_rotation_key_compressed_encrypt_sk( + &self, + res: &mut BlindRotationKeyCompressed, + sk_glwe: &S0, + sk_lwe: &S1, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + D: DataMut, + S0: GLWESecretPreparedToRef + GLWEInfos, + S1: LWESecretToRef + LWEInfos + GetDistribution, + { + assert_eq!(res.keys.len() as u32, sk_lwe.n()); + assert!(sk_glwe.n() <= self.n() as u32); + assert_eq!(sk_glwe.rank(), res.rank()); + + match sk_lwe.dist() { + Distribution::BinaryBlock(_) | Distribution::BinaryFixed(_) | Distribution::BinaryProb(_) | Distribution::ZERO => {} + _ => { + panic!("invalid GLWESecret distribution: must be BinaryBlock, BinaryFixed or BinaryProb (or ZERO for debugging)") + } + } + + { + let sk_lwe: &LWESecret<&[u8]> = &sk_lwe.to_ref(); + + let mut source_xa: Source = Source::new(seed_xa); + + res.dist = sk_lwe.dist(); + + let mut pt: ScalarZnx> = ScalarZnx::alloc(sk_glwe.n().into(), 1); + let sk_ref: ScalarZnx<&[u8]> = sk_lwe.data().to_ref(); + + for (i, ggsw) in res.keys.iter_mut().enumerate() { + pt.at_mut(0, 0)[0] = sk_ref.at(0, 0)[i]; + ggsw.encrypt_sk(self, &pt, sk_glwe, source_xa.new_seed(), source_xe, scratch); + } + } + } +} diff --git a/poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/key_prepared.rs b/poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/key_prepared.rs new file mode 100644 index 0000000..b711f16 --- /dev/null +++ b/poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/key_prepared.rs @@ -0,0 +1,69 @@ +use poulpy_hal::{ + api::{SvpPPolAlloc, SvpPrepare}, + layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, SvpPPol}, +}; + +use std::marker::PhantomData; + +use poulpy_core::{ + Distribution, + layouts::{GGSWPreparedFactory, LWEInfos, prepared::GGSWPrepared}, +}; + +use crate::tfhe::blind_rotation::{ + BlindRotationKey, BlindRotationKeyInfos, BlindRotationKeyPrepared, BlindRotationKeyPreparedFactory, CGGI, + utils::set_xai_plus_y, +}; + +impl BlindRotationKeyPreparedFactory for Module +where + Self: GGSWPreparedFactory + SvpPPolAlloc + SvpPrepare, +{ + fn blind_rotation_key_prepared_alloc(&self, infos: &A) -> BlindRotationKeyPrepared, CGGI, BE> + where + A: BlindRotationKeyInfos, + { + BlindRotationKeyPrepared { + data: (0..infos.n_lwe().as_usize()) + .map(|_| GGSWPrepared::alloc_from_infos(self, infos)) + .collect(), + dist: Distribution::NONE, + x_pow_a: None, + _phantom: PhantomData, + } + } + + fn blind_rotation_key_prepare( + &self, + res: &mut BlindRotationKeyPrepared, + other: &BlindRotationKey, + scratch: &mut Scratch, + ) where + DM: DataMut, + DR: DataRef, + { + #[cfg(debug_assertions)] + { + assert_eq!(res.data.len(), other.keys.len()); + } + + let n: usize = other.n().as_usize(); + + for (a, b) in res.data.iter_mut().zip(other.keys.iter()) { + a.prepare(self, b, scratch); + } + + res.dist = other.dist; + + if let Distribution::BinaryBlock(_) = other.dist { + let mut x_pow_a: Vec, BE>> = Vec::with_capacity(n << 1); + let mut buf: ScalarZnx> = ScalarZnx::alloc(n, 1); + (0..n << 1).for_each(|i| { + let mut res: SvpPPol, BE> = self.svp_ppol_alloc(1); + set_xai_plus_y(self, i, 0, &mut res, &mut buf); + x_pow_a.push(res); + }); + res.x_pow_a = Some(x_pow_a); + } + } +} diff --git a/poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/mod.rs b/poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/mod.rs new file mode 100644 index 0000000..d67ee45 --- /dev/null +++ b/poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/mod.rs @@ -0,0 +1,10 @@ +mod algorithm; +mod key; +mod key_compressed; +mod key_prepared; + +use crate::tfhe::blind_rotation::BlindRotationAlgo; + +#[derive(Clone)] +pub struct CGGI {} +impl BlindRotationAlgo for CGGI {} diff --git a/poulpy-schemes/src/tfhe/blind_rotation/algorithms/mod.rs b/poulpy-schemes/src/tfhe/blind_rotation/algorithms/mod.rs new file mode 100644 index 0000000..f25fb51 --- /dev/null +++ b/poulpy-schemes/src/tfhe/blind_rotation/algorithms/mod.rs @@ -0,0 +1,116 @@ +mod cggi; + +pub use cggi::*; + +use itertools::izip; +use poulpy_core::{ + ScratchTakeCore, + layouts::{GGSWInfos, GLWE, GLWEInfos, LWE, LWEInfos}, +}; +use poulpy_hal::layouts::{Backend, DataMut, DataRef, Scratch, ZnxView}; + +use crate::tfhe::blind_rotation::{BlindRotationKeyInfos, BlindRotationKeyPrepared, LookUpTableRotationDirection, LookupTable}; + +pub trait BlindRotationAlgo {} + +pub trait BlindRotationExecute { + fn blind_rotation_execute_tmp_bytes( + &self, + block_size: usize, + extension_factor: usize, + glwe_infos: &G, + brk_infos: &B, + ) -> usize + where + G: GLWEInfos, + B: GGSWInfos; + + fn blind_rotation_execute( + &self, + res: &mut GLWE, + lwe: &LWE
, + lut: &LookupTable, + brk: &BlindRotationKeyPrepared, + scratch: &mut Scratch, + ) where + DR: DataMut, + DL: DataRef, + DB: DataRef; +} + +impl BlindRotationKeyPrepared +where + Scratch: ScratchTakeCore, +{ + pub fn execute( + &self, + module: &M, + res: &mut GLWE, + lwe: &LWE, + lut: &LookupTable, + scratch: &mut Scratch, + ) where + M: BlindRotationExecute, + { + module.blind_rotation_execute(res, lwe, lut, self, scratch); + } +} + +impl BlindRotationKeyPrepared, BRA, BE> { + pub fn execute_tmp_bytes( + module: &M, + block_size: usize, + extension_factor: usize, + glwe_infos: &A, + brk_infos: &B, + ) -> usize + where + A: GLWEInfos, + B: BlindRotationKeyInfos, + M: BlindRotationExecute, + { + module.blind_rotation_execute_tmp_bytes(block_size, extension_factor, glwe_infos, brk_infos) + } +} + +pub fn mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWE<&[u8]>, rot_dir: LookUpTableRotationDirection) { + let base2k: usize = lwe.base2k().into(); + + let log2n: usize = usize::BITS as usize - (n - 1).leading_zeros() as usize + 1; + + res.copy_from_slice(lwe.data().at(0, 0)); + + match rot_dir { + LookUpTableRotationDirection::Left => { + res.iter_mut().for_each(|x| *x = -*x); + } + LookUpTableRotationDirection::Right => {} + } + + if base2k > log2n { + let diff: usize = base2k - (log2n - 1); // additional -1 because we map to [-N/2, N/2) instead of [0, N) + res.iter_mut().for_each(|x| { + *x = div_round_by_pow2(x, diff); + }) + } else { + let rem: usize = base2k - (log2n % base2k); + let size: usize = log2n.div_ceil(base2k); + (1..size).for_each(|i| { + if i == size - 1 && rem != base2k { + let k_rem: usize = base2k - rem; + izip!(lwe.data().at(0, i).iter(), res.iter_mut()).for_each(|(x, y)| { + *y = (*y << k_rem) + (x >> rem); + }); + } else { + izip!(lwe.data().at(0, i).iter(), res.iter_mut()).for_each(|(x, y)| { + *y = (*y << base2k) + x; + }); + } + }) + } +} + +#[inline(always)] +fn div_round_by_pow2(x: &i64, k: usize) -> i64 { + (x + (1 << (k - 1))) >> k +} diff --git a/poulpy-schemes/src/tfhe/blind_rotation/cggi_key.rs b/poulpy-schemes/src/tfhe/blind_rotation/cggi_key.rs deleted file mode 100644 index 3114ea1..0000000 --- a/poulpy-schemes/src/tfhe/blind_rotation/cggi_key.rs +++ /dev/null @@ -1,223 +0,0 @@ -use poulpy_hal::{ - api::{ - ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, - VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, - VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VmpPMatAlloc, VmpPrepare, - }, - layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxView, ZnxViewMut}, - source::Source, -}; - -use std::marker::PhantomData; - -use poulpy_core::{ - Distribution, - layouts::{ - GGSW, GGSWInfos, LWESecret, - compressed::GGSWCompressed, - prepared::{GGSWPrepared, GLWESecretPrepared}, - }, -}; - -use crate::tfhe::blind_rotation::{ - BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyCompressed, BlindRotationKeyEncryptSk, BlindRotationKeyInfos, - BlindRotationKeyPrepared, BlindRotationKeyPreparedAlloc, CGGI, -}; - -impl BlindRotationKeyAlloc for BlindRotationKey, CGGI> { - fn alloc(infos: &A) -> Self - where - A: BlindRotationKeyInfos, - { - let mut data: Vec>> = Vec::with_capacity(infos.n_lwe().into()); - for _ in 0..infos.n_lwe().as_usize() { - data.push(GGSW::alloc_from_infos(infos)); - } - - Self { - keys: data, - dist: Distribution::NONE, - _phantom: PhantomData, - } - } -} - -impl BlindRotationKey, CGGI> { - pub fn generate_from_sk_tmp_bytes(module: &Module, infos: &A) -> usize - where - A: GGSWInfos, - Module: VecZnxNormalizeTmpBytes + VecZnxDftBytesOf, - { - GGSW::encrypt_sk_tmp_bytes(module, infos) - } -} - -impl BlindRotationKeyEncryptSk for BlindRotationKey -where - Module: VecZnxAddScalarInplace - + VecZnxDftBytesOf - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, -{ - fn encrypt_sk( - &mut self, - module: &Module, - sk_glwe: &GLWESecretPrepared, - sk_lwe: &LWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - scratch: &mut Scratch, - ) where - DataSkGLWE: DataRef, - DataSkLWE: DataRef, - { - #[cfg(debug_assertions)] - { - use poulpy_core::layouts::{GLWEInfos, LWEInfos}; - - assert_eq!(self.keys.len() as u32, sk_lwe.n()); - assert!(sk_glwe.n() <= module.n() as u32); - assert_eq!(sk_glwe.rank(), self.rank()); - match sk_lwe.dist() { - Distribution::BinaryBlock(_) - | Distribution::BinaryFixed(_) - | Distribution::BinaryProb(_) - | Distribution::ZERO => {} - _ => panic!( - "invalid GLWESecret distribution: must be BinaryBlock, BinaryFixed or BinaryProb (or ZERO for debugging)" - ), - } - } - - self.dist = sk_lwe.dist(); - - let mut pt: ScalarZnx> = ScalarZnx::alloc(sk_glwe.n().into(), 1); - let sk_ref: ScalarZnx<&[u8]> = sk_lwe.data().to_ref(); - - self.keys.iter_mut().enumerate().for_each(|(i, ggsw)| { - pt.at_mut(0, 0)[0] = sk_ref.at(0, 0)[i]; - ggsw.encrypt_sk(module, &pt, sk_glwe, source_xa, source_xe, scratch); - }); - } -} - -impl BlindRotationKeyPreparedAlloc for BlindRotationKeyPrepared, CGGI, B> -where - Module: VmpPMatAlloc + VmpPrepare, -{ - fn alloc(module: &Module, infos: &A) -> Self - where - A: BlindRotationKeyInfos, - { - let mut data: Vec, B>> = Vec::with_capacity(infos.n_lwe().into()); - (0..infos.n_lwe().as_usize()).for_each(|_| data.push(GGSWPrepared::alloc_from_infos(module, infos))); - Self { - data, - dist: Distribution::NONE, - x_pow_a: None, - _phantom: PhantomData, - } - } -} - -impl BlindRotationKeyCompressed, CGGI> { - pub fn alloc(infos: &A) -> Self - where - A: BlindRotationKeyInfos, - { - let mut data: Vec>> = Vec::with_capacity(infos.n_lwe().into()); - (0..infos.n_lwe().as_usize()).for_each(|_| data.push(GGSWCompressed::alloc_from_infos(infos))); - Self { - keys: data, - dist: Distribution::NONE, - _phantom: PhantomData, - } - } - - pub fn generate_from_sk_tmp_bytes(module: &Module, infos: &A) -> usize - where - A: GGSWInfos, - Module: VecZnxNormalizeTmpBytes + VecZnxDftBytesOf, - { - GGSWCompressed::encrypt_sk_tmp_bytes(module, infos) - } -} - -impl BlindRotationKeyCompressed { - #[allow(clippy::too_many_arguments)] - pub fn encrypt_sk( - &mut self, - module: &Module, - sk_glwe: &GLWESecretPrepared, - sk_lwe: &LWESecret, - seed_xa: [u8; 32], - source_xe: &mut Source, - scratch: &mut Scratch, - ) where - DataSkGLWE: DataRef, - DataSkLWE: DataRef, - Module: VecZnxAddScalarInplace - + VecZnxDftBytesOf - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, - { - #[cfg(debug_assertions)] - { - use poulpy_core::layouts::{GLWEInfos, LWEInfos}; - - assert_eq!(self.n_lwe(), sk_lwe.n()); - assert!(sk_glwe.n() <= module.n() as u32); - assert_eq!(sk_glwe.rank(), self.rank()); - match sk_lwe.dist() { - Distribution::BinaryBlock(_) - | Distribution::BinaryFixed(_) - | Distribution::BinaryProb(_) - | Distribution::ZERO => {} - _ => panic!( - "invalid GLWESecret distribution: must be BinaryBlock, BinaryFixed or BinaryProb (or ZERO for debugging)" - ), - } - } - - self.dist = sk_lwe.dist(); - - let mut pt: ScalarZnx> = ScalarZnx::alloc(sk_glwe.n().into(), 1); - let sk_ref: ScalarZnx<&[u8]> = sk_lwe.data().to_ref(); - - let mut source_xa: Source = Source::new(seed_xa); - - self.keys.iter_mut().enumerate().for_each(|(i, ggsw)| { - pt.at_mut(0, 0)[0] = sk_ref.at(0, 0)[i]; - ggsw.encrypt_sk( - module, - &pt, - sk_glwe, - source_xa.new_seed(), - source_xe, - scratch, - ); - }); - } -} diff --git a/poulpy-schemes/src/tfhe/blind_rotation/encryption/key.rs b/poulpy-schemes/src/tfhe/blind_rotation/encryption/key.rs new file mode 100644 index 0000000..99a6e2e --- /dev/null +++ b/poulpy-schemes/src/tfhe/blind_rotation/encryption/key.rs @@ -0,0 +1,60 @@ +use poulpy_hal::{ + layouts::{Backend, DataMut, Scratch}, + source::Source, +}; + +use poulpy_core::{ + GetDistribution, ScratchTakeCore, + layouts::{GGSWInfos, GLWEInfos, GLWESecretPreparedToRef, LWEInfos, LWESecretToRef}, +}; + +use crate::tfhe::blind_rotation::{BlindRotationAlgo, BlindRotationKey}; + +pub trait BlindRotationKeyEncryptSk { + fn blind_rotation_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGSWInfos; + + #[allow(clippy::too_many_arguments)] + fn blind_rotation_key_encrypt_sk( + &self, + res: &mut BlindRotationKey, + sk_glwe: &S0, + sk_lwe: &S1, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + D: DataMut, + S0: GLWESecretPreparedToRef + GLWEInfos, + S1: LWESecretToRef + LWEInfos + GetDistribution; +} + +impl BlindRotationKey { + pub fn encrypt_sk( + &mut self, + module: &M, + sk_glwe: &S0, + sk_lwe: &S1, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + S0: GLWESecretPreparedToRef + GLWEInfos, + S1: LWESecretToRef + LWEInfos + GetDistribution, + Scratch: ScratchTakeCore, + M: BlindRotationKeyEncryptSk, + { + module.blind_rotation_key_encrypt_sk(self, sk_glwe, sk_lwe, source_xa, source_xe, scratch); + } +} + +impl BlindRotationKey, BRA> { + pub fn encrypt_sk_tmp_bytes(module: &M, infos: &A) -> usize + where + A: GGSWInfos, + M: BlindRotationKeyEncryptSk, + { + module.blind_rotation_key_encrypt_sk_tmp_bytes(infos) + } +} diff --git a/poulpy-schemes/src/tfhe/blind_rotation/encryption/key_compressed.rs b/poulpy-schemes/src/tfhe/blind_rotation/encryption/key_compressed.rs new file mode 100644 index 0000000..4898365 --- /dev/null +++ b/poulpy-schemes/src/tfhe/blind_rotation/encryption/key_compressed.rs @@ -0,0 +1,30 @@ +use poulpy_core::{ + GetDistribution, + layouts::{GGSWInfos, GLWEInfos, GLWESecretPreparedToRef, LWEInfos, LWESecretToRef}, +}; +use poulpy_hal::{ + layouts::{Backend, DataMut, Scratch}, + source::Source, +}; + +use crate::tfhe::blind_rotation::{BlindRotationAlgo, BlindRotationKeyCompressed}; + +pub trait BlindRotationKeyCompressedEncryptSk { + fn blind_rotation_key_compressed_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGSWInfos; + + #[allow(clippy::too_many_arguments)] + fn blind_rotation_key_compressed_encrypt_sk( + &self, + res: &mut BlindRotationKeyCompressed, + sk_glwe: &S0, + sk_lwe: &S1, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + D: DataMut, + S0: GLWESecretPreparedToRef + GLWEInfos, + S1: LWESecretToRef + LWEInfos + GetDistribution; +} diff --git a/poulpy-schemes/src/tfhe/blind_rotation/encryption/mod.rs b/poulpy-schemes/src/tfhe/blind_rotation/encryption/mod.rs new file mode 100644 index 0000000..62623fc --- /dev/null +++ b/poulpy-schemes/src/tfhe/blind_rotation/encryption/mod.rs @@ -0,0 +1,5 @@ +mod key; +mod key_compressed; + +pub use key::*; +pub use key_compressed::*; diff --git a/poulpy-schemes/src/tfhe/blind_rotation/key_prepared.rs b/poulpy-schemes/src/tfhe/blind_rotation/key_prepared.rs deleted file mode 100644 index 2001083..0000000 --- a/poulpy-schemes/src/tfhe/blind_rotation/key_prepared.rs +++ /dev/null @@ -1,130 +0,0 @@ -use poulpy_hal::{ - api::{SvpPPolAlloc, SvpPrepare, VmpPMatAlloc, VmpPrepare}, - layouts::{Backend, Data, DataMut, DataRef, Module, ScalarZnx, Scratch, SvpPPol}, -}; - -use std::marker::PhantomData; - -use poulpy_core::{ - Distribution, - layouts::{ - Base2K, Degree, Dnum, Dsize, GGSWInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision, - prepared::{GGSWPrepared, Prepare, PrepareAlloc}, - }, -}; - -use crate::tfhe::blind_rotation::{BlindRotationAlgo, BlindRotationKey, BlindRotationKeyInfos, utils::set_xai_plus_y}; - -pub trait BlindRotationKeyPreparedAlloc { - fn alloc(module: &Module, infos: &A) -> Self - where - A: BlindRotationKeyInfos; -} - -#[derive(PartialEq, Eq)] -pub struct BlindRotationKeyPrepared { - pub(crate) data: Vec>, - pub(crate) dist: Distribution, - pub(crate) x_pow_a: Option, B>>>, - pub(crate) _phantom: PhantomData, -} - -impl BlindRotationKeyInfos for BlindRotationKeyPrepared { - fn n_glwe(&self) -> Degree { - self.n() - } - - fn n_lwe(&self) -> Degree { - Degree(self.data.len() as u32) - } -} - -impl LWEInfos for BlindRotationKeyPrepared { - fn base2k(&self) -> Base2K { - self.data[0].base2k() - } - - fn k(&self) -> TorusPrecision { - self.data[0].k() - } - - fn n(&self) -> Degree { - self.data[0].n() - } - - fn size(&self) -> usize { - self.data[0].size() - } -} - -impl GLWEInfos for BlindRotationKeyPrepared { - fn rank(&self) -> Rank { - self.data[0].rank() - } -} -impl GGSWInfos for BlindRotationKeyPrepared { - fn dsize(&self) -> poulpy_core::layouts::Dsize { - Dsize(1) - } - - fn dnum(&self) -> Dnum { - self.data[0].dnum() - } -} - -impl BlindRotationKeyPrepared { - pub fn block_size(&self) -> usize { - match self.dist { - Distribution::BinaryBlock(value) => value, - _ => 1, - } - } -} - -impl PrepareAlloc, BRA, B>> - for BlindRotationKey -where - BlindRotationKeyPrepared, BRA, B>: BlindRotationKeyPreparedAlloc, - BlindRotationKeyPrepared, BRA, B>: Prepare>, -{ - fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> BlindRotationKeyPrepared, BRA, B> { - let mut brk: BlindRotationKeyPrepared, BRA, B> = BlindRotationKeyPrepared::alloc(module, self); - brk.prepare(module, self, scratch); - brk - } -} - -impl Prepare> - for BlindRotationKeyPrepared -where - Module: VmpPMatAlloc + VmpPrepare + SvpPPolAlloc + SvpPrepare, -{ - fn prepare(&mut self, module: &Module, other: &BlindRotationKey, scratch: &mut Scratch) { - #[cfg(debug_assertions)] - { - assert_eq!(self.data.len(), other.keys.len()); - } - - let n: usize = other.n().as_usize(); - - self.data - .iter_mut() - .zip(other.keys.iter()) - .for_each(|(ggsw_prepared, other)| { - ggsw_prepared.prepare(module, other, scratch); - }); - - self.dist = other.dist; - - if let Distribution::BinaryBlock(_) = other.dist { - let mut x_pow_a: Vec, B>> = Vec::with_capacity(n << 1); - let mut buf: ScalarZnx> = ScalarZnx::alloc(n, 1); - (0..n << 1).for_each(|i| { - let mut res: SvpPPol, B> = module.svp_ppol_alloc(1); - set_xai_plus_y(module, i, 0, &mut res, &mut buf); - x_pow_a.push(res); - }); - self.x_pow_a = Some(x_pow_a); - } - } -} diff --git a/poulpy-schemes/src/tfhe/blind_rotation/key.rs b/poulpy-schemes/src/tfhe/blind_rotation/layouts/key.rs similarity index 87% rename from poulpy-schemes/src/tfhe/blind_rotation/key.rs rename to poulpy-schemes/src/tfhe/blind_rotation/layouts/key.rs index d9bb0a9..182c973 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/key.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/layouts/key.rs @@ -1,5 +1,5 @@ use poulpy_hal::{ - layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, Scratch, WriterTo}, + layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo}, source::Source, }; @@ -7,10 +7,7 @@ use std::{fmt, marker::PhantomData}; use poulpy_core::{ Distribution, - layouts::{ - Base2K, Degree, Dnum, Dsize, GGSW, GGSWInfos, GLWEInfos, LWEInfos, LWESecret, Rank, TorusPrecision, - prepared::GLWESecretPrepared, - }, + layouts::{Base2K, Degree, Dnum, Dsize, GGSW, GGSWInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision}, }; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; @@ -81,21 +78,6 @@ pub trait BlindRotationKeyAlloc { A: BlindRotationKeyInfos; } -pub trait BlindRotationKeyEncryptSk { - #[allow(clippy::too_many_arguments)] - fn encrypt_sk( - &mut self, - module: &Module, - sk_glwe: &GLWESecretPrepared, - sk_lwe: &LWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - scratch: &mut Scratch, - ) where - DataSkGLWE: DataRef, - DataSkLWE: DataRef; -} - #[derive(Clone)] pub struct BlindRotationKey { pub(crate) keys: Vec>, @@ -103,6 +85,24 @@ pub struct BlindRotationKey { pub(crate) _phantom: PhantomData, } +pub trait BlindRotationKeyFactory { + fn blind_rotation_key_alloc(infos: &A) -> BlindRotationKey, BRA> + where + A: BlindRotationKeyInfos; +} + +impl BlindRotationKey, BRA> +where + Self: BlindRotationKeyFactory, +{ + pub fn alloc(infos: &A) -> BlindRotationKey, BRA> + where + A: BlindRotationKeyInfos, + { + Self::blind_rotation_key_alloc(infos) + } +} + impl fmt::Debug for BlindRotationKey { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{self}") diff --git a/poulpy-schemes/src/tfhe/blind_rotation/key_compressed.rs b/poulpy-schemes/src/tfhe/blind_rotation/layouts/key_compressed.rs similarity index 89% rename from poulpy-schemes/src/tfhe/blind_rotation/key_compressed.rs rename to poulpy-schemes/src/tfhe/blind_rotation/layouts/key_compressed.rs index 784d332..26539e1 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/key_compressed.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/layouts/key_compressed.rs @@ -20,6 +20,24 @@ pub struct BlindRotationKeyCompressed { pub(crate) _phantom: PhantomData, } +pub trait BlindRotationKeyCompressedFactory { + fn blind_rotation_key_compressed_alloc(infos: &A) -> BlindRotationKeyCompressed, BRA> + where + A: BlindRotationKeyInfos; +} + +impl BlindRotationKeyCompressed, BRA> +where + Self: BlindRotationKeyCompressedFactory, +{ + pub fn alloc(infos: &A) -> BlindRotationKeyCompressed, BRA> + where + A: BlindRotationKeyInfos, + { + Self::blind_rotation_key_compressed_alloc(infos) + } +} + impl fmt::Debug for BlindRotationKeyCompressed { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{self}") diff --git a/poulpy-schemes/src/tfhe/blind_rotation/layouts/key_prepared.rs b/poulpy-schemes/src/tfhe/blind_rotation/layouts/key_prepared.rs new file mode 100644 index 0000000..5830c84 --- /dev/null +++ b/poulpy-schemes/src/tfhe/blind_rotation/layouts/key_prepared.rs @@ -0,0 +1,108 @@ +use poulpy_hal::layouts::{Backend, Data, DataMut, DataRef, Scratch, SvpPPol}; + +use std::marker::PhantomData; + +use poulpy_core::{ + Distribution, ScratchTakeCore, + layouts::{Base2K, Degree, Dnum, Dsize, GGSWInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision, prepared::GGSWPrepared}, +}; + +use crate::tfhe::blind_rotation::{BlindRotationAlgo, BlindRotationKey, BlindRotationKeyInfos}; + +pub trait BlindRotationKeyPreparedFactory { + fn blind_rotation_key_prepared_alloc(&self, infos: &A) -> BlindRotationKeyPrepared, BRA, BE> + where + A: BlindRotationKeyInfos; + + fn blind_rotation_key_prepare( + &self, + res: &mut BlindRotationKeyPrepared, + other: &BlindRotationKey, + scratch: &mut Scratch, + ) where + DM: DataMut, + DR: DataRef, + Scratch: ScratchTakeCore; +} + +impl BlindRotationKeyPrepared, BRA, BE> { + pub fn alloc(module: &M, infos: &A) -> Self + where + A: BlindRotationKeyInfos, + M: BlindRotationKeyPreparedFactory, + { + module.blind_rotation_key_prepared_alloc(infos) + } +} + +impl BlindRotationKeyPrepared +where + Scratch: ScratchTakeCore, +{ + pub fn prepare(&mut self, module: &M, other: &BlindRotationKey, scratch: &mut Scratch) + where + M: BlindRotationKeyPreparedFactory, + { + module.blind_rotation_key_prepare(self, other, scratch); + } +} + +#[derive(PartialEq, Eq)] +pub struct BlindRotationKeyPrepared { + pub(crate) data: Vec>, + pub(crate) dist: Distribution, + pub(crate) x_pow_a: Option, B>>>, + pub(crate) _phantom: PhantomData, +} + +impl BlindRotationKeyInfos for BlindRotationKeyPrepared { + fn n_glwe(&self) -> Degree { + self.n() + } + + fn n_lwe(&self) -> Degree { + Degree(self.data.len() as u32) + } +} + +impl LWEInfos for BlindRotationKeyPrepared { + fn base2k(&self) -> Base2K { + self.data[0].base2k() + } + + fn k(&self) -> TorusPrecision { + self.data[0].k() + } + + fn n(&self) -> Degree { + self.data[0].n() + } + + fn size(&self) -> usize { + self.data[0].size() + } +} + +impl GLWEInfos for BlindRotationKeyPrepared { + fn rank(&self) -> Rank { + self.data[0].rank() + } +} +impl GGSWInfos for BlindRotationKeyPrepared { + fn dsize(&self) -> poulpy_core::layouts::Dsize { + Dsize(1) + } + + fn dnum(&self) -> Dnum { + self.data[0].dnum() + } +} + +impl BlindRotationKeyPrepared { + pub fn block_size(&self) -> usize { + match self.dist { + Distribution::BinaryBlock(value) => value, + _ => 1, + } + } +} diff --git a/poulpy-schemes/src/tfhe/blind_rotation/layouts/mod.rs b/poulpy-schemes/src/tfhe/blind_rotation/layouts/mod.rs new file mode 100644 index 0000000..f3e285f --- /dev/null +++ b/poulpy-schemes/src/tfhe/blind_rotation/layouts/mod.rs @@ -0,0 +1,6 @@ +mod key; +mod key_compressed; +mod key_prepared; +pub use key::*; +pub use key_compressed::*; +pub use key_prepared::*; diff --git a/poulpy-schemes/src/tfhe/blind_rotation/lut.rs b/poulpy-schemes/src/tfhe/blind_rotation/lut.rs index f8e9006..74c0441 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/lut.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/lut.rs @@ -1,3 +1,4 @@ +use poulpy_core::layouts::{Base2K, Degree, TorusPrecision}; use poulpy_hal::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, TakeSlice, VecZnxCopy, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, @@ -13,32 +14,97 @@ pub enum LookUpTableRotationDirection { Right, } -pub struct LookUpTable { +pub struct LookUpTableLayout { + pub n: Degree, + pub extension_factor: usize, + pub k: TorusPrecision, + pub base2k: Base2K, +} + +pub trait LookupTableInfos { + fn n(&self) -> Degree; + fn extension_factor(&self) -> usize; + fn k(&self) -> TorusPrecision; + fn base2k(&self) -> Base2K; + fn size(&self) -> usize; +} + +impl LookupTableInfos for LookUpTableLayout { + fn base2k(&self) -> Base2K { + self.base2k + } + + fn extension_factor(&self) -> usize { + self.extension_factor + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn size(&self) -> usize { + self.k().as_usize().div_ceil(self.base2k().as_usize()) + } + + fn n(&self) -> Degree { + self.n + } +} + +pub struct LookupTable { pub(crate) data: Vec>>, pub(crate) rot_dir: LookUpTableRotationDirection, - pub(crate) base2k: usize, - pub(crate) k: usize, + pub(crate) base2k: Base2K, + pub(crate) k: TorusPrecision, pub(crate) drift: usize, } -impl LookUpTable { - pub fn alloc(module: &Module, base2k: usize, k: usize, extension_factor: usize) -> Self { +impl LookupTableInfos for LookupTable { + fn base2k(&self) -> Base2K { + self.base2k + } + + fn extension_factor(&self) -> usize { + self.data.len() + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn n(&self) -> Degree { + self.data[0].n().into() + } + + fn size(&self) -> usize { + self.data[0].size() + } +} + +pub trait LookupTableFactory { + fn lookup_table_set(&self, res: &mut LookupTable, f: &[i64], k: usize); + fn lookup_table_rotate(&self, k: i64, res: &mut LookupTable); +} + +impl LookupTable { + pub fn alloc(infos: &A) -> Self + where + A: LookupTableInfos, + { #[cfg(debug_assertions)] { assert!( - extension_factor & (extension_factor - 1) == 0, - "extension_factor must be a power of two but is: {extension_factor}" + infos.extension_factor() & (infos.extension_factor() - 1) == 0, + "extension_factor must be a power of two but is: {}", + infos.extension_factor() ); } - let size: usize = k.div_ceil(base2k); - let mut data: Vec>> = Vec::with_capacity(extension_factor); - (0..extension_factor).for_each(|_| { - data.push(VecZnx::alloc(module.n(), 1, size)); - }); Self { - data, - base2k, - k, + data: (0..infos.extension_factor()) + .map(|_| VecZnx::alloc(infos.n().into(), 1, infos.size())) + .collect(), + base2k: infos.base2k(), + k: infos.k(), drift: 0, rot_dir: LookUpTableRotationDirection::Left, } @@ -68,115 +134,18 @@ impl LookUpTable { self.rot_dir = rot_dir } - pub fn set(&mut self, module: &Module, f: &[i64], k: usize) + pub fn set(&mut self, module: &M, f: &[i64], k: usize) where - Module: VecZnxRotateInplace - + VecZnxNormalizeInplace - + VecZnxNormalizeTmpBytes - + VecZnxSwitchRing - + VecZnxCopy - + VecZnxRotateInplaceTmpBytes, - ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, - Scratch: TakeSlice, + M: LookupTableFactory, { - assert!(f.len() <= module.n()); - - let base2k: usize = self.base2k; - - let mut scratch: ScratchOwned = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes() | (self.domain_size() << 3)); - - // Get the number minimum limb to store the message modulus - let limbs: usize = k.div_ceil(base2k); - - #[cfg(debug_assertions)] - { - assert!(f.len() <= module.n()); - assert!( - (max_bit_size(f) + (k % base2k) as u32) < i64::BITS, - "overflow: max(|f|) << (k%base2k) > i64::BITS" - ); - assert!(limbs <= self.data[0].size()); - } - - // Scaling factor - let mut scale = 1; - if !k.is_multiple_of(base2k) { - scale <<= base2k - (k % base2k); - } - - // #elements in lookup table - let f_len: usize = f.len(); - - // If LUT size > TakeScalarZnx - let domain_size: usize = self.domain_size(); - - let size: usize = self.k.div_ceil(self.base2k); - - // Equivalent to AUTO([f(0), -f(n-1), -f(n-2), ..., -f(1)], -1) - let mut lut_full: VecZnx> = VecZnx::alloc(domain_size, 1, size); - - let lut_at: &mut [i64] = lut_full.at_mut(0, limbs - 1); - - let step: usize = domain_size.div_round(f_len); - - f.iter().enumerate().for_each(|(i, fi)| { - let start: usize = i * step; - let end: usize = start + step; - lut_at[start..end].fill(fi * scale); - }); - - let drift: usize = step >> 1; - - // Rotates half the step to the left - - if self.extension_factor() > 1 { - let (tmp, _) = scratch.borrow().take_slice(lut_full.n()); - - for i in 0..self.extension_factor() { - module.vec_znx_switch_ring(&mut self.data[i], 0, &lut_full, 0); - if i < self.extension_factor() { - vec_znx_rotate_inplace::<_, ZnxRef>(-1, &mut lut_full, 0, tmp); - } - } - } else { - module.vec_znx_copy(&mut self.data[0], 0, &lut_full, 0); - } - - for a in self.data.iter_mut() { - module.vec_znx_normalize_inplace(self.base2k, a, 0, scratch.borrow()); - } - - self.rotate(module, -(drift as i64)); - - self.drift = drift + module.lookup_table_set(self, f, k); } - #[allow(dead_code)] - pub(crate) fn rotate(&mut self, module: &Module, k: i64) + pub(crate) fn rotate(&mut self, module: &M, k: i64) where - Module: VecZnxRotateInplace + VecZnxRotateInplaceTmpBytes, - ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + M: LookupTableFactory, { - let extension_factor: usize = self.extension_factor(); - let two_n: usize = 2 * self.data[0].n(); - let two_n_ext: usize = two_n * extension_factor; - - let mut scratch: ScratchOwned<_> = ScratchOwned::alloc(module.vec_znx_rotate_inplace_tmp_bytes()); - - let k_pos: usize = ((k + two_n_ext as i64) % two_n_ext as i64) as usize; - - let k_hi: usize = k_pos / extension_factor; - let k_lo: usize = k_pos % extension_factor; - - (0..extension_factor - k_lo).for_each(|i| { - module.vec_znx_rotate_inplace(k_hi as i64, &mut self.data[i], 0, scratch.borrow()); - }); - - (extension_factor - k_lo..extension_factor).for_each(|i| { - module.vec_znx_rotate_inplace(k_hi as i64 + 1, &mut self.data[i], 0, scratch.borrow()); - }); - - self.data.rotate_right(k_lo); + module.lookup_table_rotate(k, self); } } @@ -204,3 +173,116 @@ fn max_bit_size(vec: &[i64]) -> u32 { .max() .unwrap_or(0) } + +impl LookupTableFactory for Module +where + Self: VecZnxRotateInplace + + VecZnxNormalizeInplace + + VecZnxNormalizeTmpBytes + + VecZnxSwitchRing + + VecZnxCopy + + VecZnxRotateInplaceTmpBytes + + VecZnxRotateInplace + + VecZnxRotateInplaceTmpBytes, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: TakeSlice, +{ + fn lookup_table_set(&self, res: &mut LookupTable, f: &[i64], k: usize) { + assert!(f.len() <= self.n()); + + let base2k: usize = res.base2k.into(); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + self.vec_znx_normalize_tmp_bytes() + .max(res.domain_size() << 3), + ); + + // Get the number minimum limb to store the message modulus + let limbs: usize = k.div_ceil(base2k); + + #[cfg(debug_assertions)] + { + assert!(f.len() <= self.n()); + assert!( + (max_bit_size(f) + (k % base2k) as u32) < i64::BITS, + "overflow: max(|f|) << (k%base2k) > i64::BITS" + ); + assert!(limbs <= res.data[0].size()); + } + + // Scaling factor + let mut scale = 1; + if !k.is_multiple_of(base2k) { + scale <<= base2k - (k % base2k); + } + + // #elements in lookup table + let f_len: usize = f.len(); + + // If LUT size > TakeScalarZnx + let domain_size: usize = res.domain_size(); + + let size: usize = res.k.div_ceil(res.base2k) as usize; + + // Equivalent to AUTO([f(0), -f(n-1), -f(n-2), ..., -f(1)], -1) + let mut lut_full: VecZnx> = VecZnx::alloc(domain_size, 1, size); + + let lut_at: &mut [i64] = lut_full.at_mut(0, limbs - 1); + + let step: usize = domain_size.div_round(f_len); + + f.iter().enumerate().for_each(|(i, fi)| { + let start: usize = i * step; + let end: usize = start + step; + lut_at[start..end].fill(fi * scale); + }); + + let drift: usize = step >> 1; + + // Rotates half the step to the left + + if res.extension_factor() > 1 { + let (tmp, _) = scratch.borrow().take_slice(lut_full.n()); + + for i in 0..res.extension_factor() { + self.vec_znx_switch_ring(&mut res.data[i], 0, &lut_full, 0); + if i < res.extension_factor() { + vec_znx_rotate_inplace::<_, ZnxRef>(-1, &mut lut_full, 0, tmp); + } + } + } else { + self.vec_znx_copy(&mut res.data[0], 0, &lut_full, 0); + } + + for a in res.data.iter_mut() { + self.vec_znx_normalize_inplace(res.base2k.into(), a, 0, scratch.borrow()); + } + + res.rotate(self, -(drift as i64)); + + res.drift = drift + } + + fn lookup_table_rotate(&self, k: i64, res: &mut LookupTable) { + let extension_factor: usize = res.extension_factor(); + let two_n: usize = 2 * res.data[0].n(); + let two_n_ext: usize = two_n * extension_factor; + + let mut scratch: ScratchOwned<_> = ScratchOwned::alloc(self.vec_znx_rotate_inplace_tmp_bytes()); + + let k_pos: usize = ((k + two_n_ext as i64) % two_n_ext as i64) as usize; + + let k_hi: usize = k_pos / extension_factor; + let k_lo: usize = k_pos % extension_factor; + + (0..extension_factor - k_lo).for_each(|i| { + self.vec_znx_rotate_inplace(k_hi as i64, &mut res.data[i], 0, scratch.borrow()); + }); + + (extension_factor - k_lo..extension_factor).for_each(|i| { + self.vec_znx_rotate_inplace(k_hi as i64 + 1, &mut res.data[i], 0, scratch.borrow()); + }); + + res.data.rotate_right(k_lo); + } +} diff --git a/poulpy-schemes/src/tfhe/blind_rotation/mod.rs b/poulpy-schemes/src/tfhe/blind_rotation/mod.rs index 8cc262b..93da18b 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/mod.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/mod.rs @@ -1,35 +1,11 @@ -mod cggi_algo; -mod cggi_key; -mod key; -mod key_compressed; -mod key_prepared; +mod algorithms; +mod encryption; +mod layouts; mod lut; mod utils; -pub use cggi_algo::*; -pub use key::*; -pub use key_compressed::*; -pub use key_prepared::*; +pub use algorithms::*; +pub use encryption::*; +pub use layouts::*; pub use lut::*; - pub mod tests; - -use poulpy_core::layouts::{GLWE, LWE}; -use poulpy_hal::layouts::{Backend, DataMut, DataRef, Module, Scratch}; - -pub trait BlindRotationAlgo {} - -#[derive(Clone)] -pub struct CGGI {} -impl BlindRotationAlgo for CGGI {} - -pub trait BlincRotationExecute { - fn execute( - &self, - module: &Module, - res: &mut GLWE, - lwe: &LWE, - lut: &LookUpTable, - scratch: &mut Scratch, - ); -} diff --git a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs index 4b8131a..8d0a017 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs @@ -1,88 +1,40 @@ use poulpy_hal::{ - api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, - SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, - VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAdd, VecZnxDftAddInplace, - VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftSubInplace, VecZnxDftZero, VecZnxFillUniform, VecZnxIdftApply, - VecZnxIdftApplyConsume, VecZnxIdftApplyTmpBytes, VecZnxMulXpMinusOneInplace, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxSub, VecZnxSubInplace, - VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, ZnAddNormal, - ZnFillUniform, ZnNormalizeInplace, - }, - layouts::{Backend, Module, ScratchOwned, ZnxView}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeSliceImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl, - TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl, - }, + api::{ScratchOwnedAlloc, ScratchOwnedBorrow}, + layouts::{Backend, Scratch, ScratchOwned, ZnxView}, source::Source, }; use crate::tfhe::blind_rotation::{ - BlincRotationExecute, BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyEncryptSk, BlindRotationKeyLayout, - BlindRotationKeyPrepared, CGGI, LookUpTable, cggi_blind_rotate_tmp_bytes, mod_switch_2n, + BlindRotationAlgo, BlindRotationExecute, BlindRotationKey, BlindRotationKeyEncryptSk, BlindRotationKeyFactory, + BlindRotationKeyLayout, BlindRotationKeyPrepared, BlindRotationKeyPreparedFactory, LookUpTableLayout, LookupTable, + LookupTableFactory, mod_switch_2n, }; -use poulpy_core::layouts::{ - GLWE, GLWELayout, GLWEPlaintext, GLWESecret, LWE, LWEInfos, LWELayout, LWEPlaintext, LWESecret, LWEToRef, - prepared::{GLWESecretPrepared, PrepareAlloc}, +use poulpy_core::{ + GLWEDecrypt, LWEEncryptSk, ScratchTakeCore, + layouts::{ + GLWE, GLWELayout, GLWEPlaintext, GLWESecret, GLWESecretPreparedFactory, LWE, LWEInfos, LWELayout, LWEPlaintext, + LWESecret, LWEToRef, prepared::GLWESecretPrepared, + }, }; -pub fn test_blind_rotation(module: &Module, n_lwe: usize, block_size: usize, extension_factor: usize) -where - Module: VecZnxBigBytesOf - + VecZnxDftBytesOf - + SvpPPolBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VecZnxIdftApplyTmpBytes - + VecZnxIdftApply - + VecZnxDftAdd - + VecZnxDftAddInplace - + VecZnxDftApply - + VecZnxDftZero - + SvpApplyDftToDft - + VecZnxDftSubInplace - + VecZnxBigAddSmallInplace - + VecZnxRotate - + VecZnxAddInplace - + VecZnxSubInplace - + VecZnxNormalize - + VecZnxNormalizeInplace - + VecZnxCopy - + VecZnxMulXpMinusOneInplace - + SvpPrepare - + SvpPPolAlloc - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxBigAddInplace - + VecZnxBigNormalize - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxAddNormal - + VecZnxAddScalarInplace - + VecZnxRotateInplace - + VecZnxSwitchRing - + VecZnxSub - + VmpPMatAlloc - + VmpPrepare - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + ZnFillUniform - + ZnAddNormal - + VecZnxRotateInplaceTmpBytes - + ZnNormalizeInplace, - B: Backend - + VecZnxDftAllocBytesImpl - + VecZnxBigAllocBytesImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeVecZnxDftSliceImpl - + ScratchAvailableImpl - + TakeVecZnxImpl - + TakeVecZnxSliceImpl - + TakeSliceImpl, +pub fn test_blind_rotation( + module: &M, + n_lwe: usize, + block_size: usize, + extension_factor: usize, +) where + M: BlindRotationKeyEncryptSk + + BlindRotationKeyPreparedFactory + + BlindRotationExecute + + GLWESecretPreparedFactory + + BlindRotationExecute + + LWEEncryptSk + + LookupTableFactory + + GLWEDecrypt, + BlindRotationKey, BRA>: BlindRotationKeyFactory, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchTakeCore, { let n_glwe: usize = module.n(); let base2k: usize = 19; @@ -123,18 +75,17 @@ where base2k: base2k.into(), }; - let mut scratch: ScratchOwned = ScratchOwned::::alloc(BlindRotationKey::generate_from_sk_tmp_bytes( - module, &brk_infos, - )); + let mut scratch: ScratchOwned = ScratchOwned::::alloc(BlindRotationKey::encrypt_sk_tmp_bytes(module, &brk_infos)); let mut sk_glwe: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_infos); sk_glwe.fill_ternary_prob(0.5, &mut source_xs); - let sk_glwe_dft: GLWESecretPrepared, B> = sk_glwe.prepare_alloc(module, scratch.borrow()); + let mut sk_glwe_dft: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc_from_infos(module, &glwe_infos); + sk_glwe_dft.prepare(module, &sk_glwe); let mut sk_lwe: LWESecret> = LWESecret::alloc(n_lwe.into()); sk_lwe.fill_binary_block(block_size, &mut source_xs); - let mut scratch_br: ScratchOwned = ScratchOwned::::alloc(cggi_blind_rotate_tmp_bytes( + let mut scratch_br: ScratchOwned = ScratchOwned::::alloc(BlindRotationKeyPrepared::execute_tmp_bytes( module, block_size, extension_factor, @@ -142,7 +93,7 @@ where &brk_infos, )); - let mut brk: BlindRotationKey, CGGI> = BlindRotationKey::, CGGI>::alloc(&brk_infos); + let mut brk: BlindRotationKey, BRA> = BlindRotationKey::, BRA>::alloc(&brk_infos); brk.encrypt_sk( module, @@ -171,12 +122,20 @@ where .enumerate() .for_each(|(i, x)| *x = f(i as i64)); - let mut lut: LookUpTable = LookUpTable::alloc(module, base2k, k_lut, extension_factor); + let lut_infos = LookUpTableLayout { + n: module.n().into(), + extension_factor, + k: k_lut.into(), + base2k: base2k.into(), + }; + + let mut lut: LookupTable = LookupTable::alloc(&lut_infos); lut.set(module, &f_vec, log_message_modulus + 1); let mut res: GLWE> = GLWE::alloc_from_infos(&glwe_infos); - let brk_prepared: BlindRotationKeyPrepared, CGGI, B> = brk.prepare_alloc(module, scratch.borrow()); + let mut brk_prepared: BlindRotationKeyPrepared, BRA, BE> = BlindRotationKeyPrepared::alloc(module, &brk); + brk_prepared.prepare(module, &brk, scratch_br.borrow()); brk_prepared.execute(module, &mut res, &lwe, &lut, scratch_br.borrow()); diff --git a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_lut.rs b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_lut.rs index bb6492c..80d7663 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_lut.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_lut.rs @@ -1,25 +1,12 @@ use std::vec; -use poulpy_hal::{ - api::{ - VecZnxCopy, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, - VecZnxSwitchRing, - }, - layouts::{Backend, Module}, - oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeSliceImpl}, -}; +use poulpy_hal::api::ModuleN; -use crate::tfhe::blind_rotation::{DivRound, LookUpTable}; +use crate::tfhe::blind_rotation::{DivRound, LookUpTableLayout, LookupTable, LookupTableFactory}; -pub fn test_lut_standard(module: &Module) +pub fn test_lut_standard(module: &M) where - Module: VecZnxRotateInplace - + VecZnxNormalizeInplace - + VecZnxNormalizeTmpBytes - + VecZnxSwitchRing - + VecZnxCopy - + VecZnxRotateInplaceTmpBytes, - B: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl + TakeSliceImpl, + M: LookupTableFactory + ModuleN, { let base2k: usize = 20; let k_lut: usize = 40; @@ -33,7 +20,14 @@ where .enumerate() .for_each(|(i, x)| *x = (i as i64) - 8); - let mut lut: LookUpTable = LookUpTable::alloc(module, base2k, k_lut, extension_factor); + let lut_infos: LookUpTableLayout = LookUpTableLayout { + n: module.n().into(), + extension_factor, + k: k_lut.into(), + base2k: base2k.into(), + }; + + let mut lut: LookupTable = LookupTable::alloc(&lut_infos); lut.set(module, &f, log_scale); let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64; @@ -51,15 +45,9 @@ where }); } -pub fn test_lut_extended(module: &Module) +pub fn test_lut_extended(module: &M) where - Module: VecZnxRotateInplace - + VecZnxNormalizeInplace - + VecZnxNormalizeTmpBytes - + VecZnxSwitchRing - + VecZnxCopy - + VecZnxRotateInplaceTmpBytes, - B: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl + TakeSliceImpl, + M: LookupTableFactory + ModuleN, { let base2k: usize = 20; let k_lut: usize = 40; @@ -73,7 +61,14 @@ where .enumerate() .for_each(|(i, x)| *x = (i as i64) - 8); - let mut lut: LookUpTable = LookUpTable::alloc(module, base2k, k_lut, extension_factor); + let lut_infos: LookUpTableLayout = LookUpTableLayout { + n: module.n().into(), + extension_factor, + k: k_lut.into(), + base2k: base2k.into(), + }; + + let mut lut: LookupTable = LookupTable::alloc(&lut_infos); lut.set(module, &f, log_scale); let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64; diff --git a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_serialization.rs b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_serialization.rs index f25a236..341dd49 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_serialization.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_serialization.rs @@ -1,8 +1,6 @@ use poulpy_hal::test_suite::serialization::test_reader_writer_interface; -use crate::tfhe::blind_rotation::{ - BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyCompressed, BlindRotationKeyLayout, CGGI, -}; +use crate::tfhe::blind_rotation::{BlindRotationKey, BlindRotationKeyCompressed, BlindRotationKeyLayout, CGGI}; #[test] fn test_cggi_blind_rotation_key_serialization() { @@ -14,7 +12,6 @@ fn test_cggi_blind_rotation_key_serialization() { dnum: 2_usize.into(), rank: 2_usize.into(), }; - let original: BlindRotationKey, CGGI> = BlindRotationKey::alloc(&layout); test_reader_writer_interface(original); } @@ -29,7 +26,6 @@ fn test_cggi_blind_rotation_key_compressed_serialization() { dnum: 2_usize.into(), rank: 2_usize.into(), }; - let original: BlindRotationKeyCompressed, CGGI> = BlindRotationKeyCompressed::alloc(&layout); test_reader_writer_interface(original); } diff --git a/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/cpu_spqlios/fft64.rs b/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/cpu_spqlios/fft64.rs deleted file mode 100644 index ecb2421..0000000 --- a/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/cpu_spqlios/fft64.rs +++ /dev/null @@ -1,37 +0,0 @@ -use poulpy_backend::cpu_spqlios::FFT64Spqlios; -use poulpy_hal::{api::ModuleNew, layouts::Module}; - -use crate::tfhe::blind_rotation::tests::{ - generic_blind_rotation::test_blind_rotation, - generic_lut::{test_lut_extended, test_lut_standard}, -}; - -#[test] -fn lut_standard() { - let module: Module = Module::::new(32); - test_lut_standard(&module); -} - -#[test] -fn lut_extended() { - let module: Module = Module::::new(32); - test_lut_extended(&module); -} - -#[test] -fn standard() { - let module: Module = Module::::new(512); - test_blind_rotation(&module, 224, 1, 1); -} - -#[test] -fn block_binary() { - let module: Module = Module::::new(512); - test_blind_rotation(&module, 224, 7, 1); -} - -#[test] -fn block_binary_extended() { - let module: Module = Module::::new(512); - test_blind_rotation(&module, 224, 7, 2); -} diff --git a/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/cpu_spqlios/mod.rs b/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/cpu_spqlios/mod.rs deleted file mode 100644 index aebaafb..0000000 --- a/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/cpu_spqlios/mod.rs +++ /dev/null @@ -1 +0,0 @@ -mod fft64; diff --git a/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/fft64.rs b/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/fft64.rs new file mode 100644 index 0000000..5a471e2 --- /dev/null +++ b/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/fft64.rs @@ -0,0 +1,40 @@ +use poulpy_backend::cpu_fft64_ref::FFT64Ref; +use poulpy_hal::{api::ModuleNew, layouts::Module}; + +use crate::tfhe::blind_rotation::{ + CGGI, + tests::{ + generic_blind_rotation::test_blind_rotation, + generic_lut::{test_lut_extended, test_lut_standard}, + }, +}; + +#[test] +fn lut_standard() { + let module: Module = Module::::new(32); + test_lut_standard(&module); +} + +#[test] +fn lut_extended() { + let module: Module = Module::::new(32); + test_lut_extended(&module); +} + +#[test] +fn standard() { + let module: Module = Module::::new(512); + test_blind_rotation::(&module, 224, 1, 1); +} + +#[test] +fn block_binary() { + let module: Module = Module::::new(512); + test_blind_rotation::(&module, 224, 7, 1); +} + +#[test] +fn block_binary_extended() { + let module: Module = Module::::new(512); + test_blind_rotation::(&module, 224, 7, 2); +} diff --git a/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/mod.rs b/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/mod.rs index f2bc1d4..aebaafb 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/mod.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/mod.rs @@ -1 +1 @@ -mod cpu_spqlios; +mod fft64; diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs index a471ffc..81f99e6 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs @@ -1,14 +1,13 @@ use poulpy_core::layouts::{ AutomorphismKey, AutomorphismKeyLayout, GGLWEInfos, GGSWInfos, GLWE, GLWEInfos, GLWESecret, LWEInfos, LWESecret, TensorKey, TensorKeyLayout, - prepared::{GLWEAutomorphismKeyPrepared, GLWESecretPrepared, PrepareAlloc, TensorKeyPrepared}, + prepared::{GLWEAutomorphismKeyPrepared, GLWESecretPrepared, TensorKeyPrepared}, }; use std::collections::HashMap; use poulpy_hal::{ api::{ - ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, TakeScalarZnx, - TakeSvpPPol, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, + ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpPMatAlloc, VmpPrepare, diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/mod.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/mod.rs index 5835765..ed3b6a1 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/mod.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/mod.rs @@ -1,6 +1,8 @@ mod circuit; mod key; -pub mod tests; + +//[cfg(tests)] +//pub mod tests; pub use circuit::*; pub use key::*; diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/mod.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/mod.rs index f9bc7d9..22f8f4f 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/mod.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/mod.rs @@ -1,4 +1,3 @@ pub mod circuit_bootstrapping; -#[cfg(test)] mod implementation; diff --git a/poulpy-schemes/src/tfhe/mod.rs b/poulpy-schemes/src/tfhe/mod.rs index 85c84d4..cc2dbe9 100644 --- a/poulpy-schemes/src/tfhe/mod.rs +++ b/poulpy-schemes/src/tfhe/mod.rs @@ -1,3 +1,3 @@ -pub mod bdd_arithmetic; +// pub mod bdd_arithmetic; pub mod blind_rotation; -pub mod circuit_bootstrapping; +//pub mod circuit_bootstrapping;