fix blind rotation

Pro7ech
2025-10-21 14:26:53 +02:00
parent fef2a2fc27
commit 0926913001
37 changed files with 1106 additions and 961 deletions

View File

@@ -4,6 +4,7 @@ use poulpy_hal::{
 };
 use crate::{
+    GetDistribution,
     dist::Distribution,
     layouts::{Base2K, Degree, LWEInfos, TorusPrecision},
 };
@@ -22,6 +23,12 @@ impl LWESecret<Vec<u8>> {
     }
 }

+impl<D: DataRef> GetDistribution for LWESecret<D> {
+    fn dist(&self) -> &Distribution {
+        &self.dist
+    }
+}
+
 impl<D: DataRef> LWESecret<D> {
     pub fn raw(&self) -> &[i64] {
         self.data.at(0, 0)

View File

@@ -23,9 +23,8 @@ pub use external_product::*;
 pub use glwe_packing::*;
 pub use keyswitching::*;
 pub use noise::*;
-pub use scratch::*;

 pub use encryption::SIGMA;
+pub use scratch::*;

 pub mod tests;

View File

@@ -217,6 +217,8 @@ where
     }
 }

+impl<BE: Backend> GLWEMulXpMinusOne<BE> for Module<BE> where Self: ModuleN + VecZnxMulXpMinusOne + VecZnxMulXpMinusOneInplace<BE> {}
+
 pub trait GLWEMulXpMinusOne<BE: Backend>
 where
     Self: ModuleN + VecZnxMulXpMinusOne + VecZnxMulXpMinusOneInplace<BE>,

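The one-line impl added above is the alias-trait pattern: GLWEMulXpMinusOne carries no methods of its own, and the blanket impl grants it to any Module<BE> that already satisfies the underlying HAL bounds, so downstream code can name one bound instead of three. A minimal, self-contained sketch of the pattern with toy traits (illustrative names, not poulpy APIs):

// Toy sketch of the alias-trait pattern (illustrative names, not poulpy APIs):
// the trait re-states its bounds, and a blanket impl grants it to every type
// that already satisfies them.
trait VecRotate {
    fn rotate(&mut self, k: usize);
}
trait VecNegate {
    fn negate(&mut self);
}

// Alias trait: no methods of its own, just a named bundle of bounds.
trait MulXkMinusOne: VecRotate + VecNegate {}

// Blanket impl: anything with the bounds gets the alias for free.
impl<T: VecRotate + VecNegate> MulXkMinusOne for T {}

struct Poly(Vec<i64>);
impl VecRotate for Poly {
    fn rotate(&mut self, k: usize) {
        self.0.rotate_right(k % self.0.len());
    }
}
impl VecNegate for Poly {
    fn negate(&mut self) {
        self.0.iter_mut().for_each(|c| *c = -*c);
    }
}

// Callers write one bound instead of repeating the whole list.
fn step<M: MulXkMinusOne>(m: &mut M) {
    m.rotate(1);
    m.negate();
}

fn main() {
    let mut p = Poly(vec![1, 2, 3, 4]);
    step(&mut p); // Poly implements MulXkMinusOne via the blanket impl
    assert_eq!(p.0, vec![-4, -1, -2, -3]);
}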
View File

@@ -34,13 +34,9 @@ pub trait ScratchTakeBasic
 where
     Self: TakeSlice,
 {
-    fn take_scalar_znx(&mut self, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self)
-    {
+    fn take_scalar_znx(&mut self, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) {
         let (take_slice, rem_slice) = self.take_slice(ScalarZnx::bytes_of(n, cols));
-        (
-            ScalarZnx::from_data(take_slice, n, cols),
-            rem_slice,
-        )
+        (ScalarZnx::from_data(take_slice, n, cols), rem_slice)
     }

     fn take_svp_ppol<M, B: Backend>(&mut self, module: &M, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self)
@@ -51,12 +47,9 @@ where
         (SvpPPol::from_data(take_slice, module.n(), cols), rem_slice)
     }

-    fn take_vec_znx(&mut self, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self){
+    fn take_vec_znx(&mut self, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) {
         let (take_slice, rem_slice) = self.take_slice(VecZnx::bytes_of(n, cols, size));
-        (
-            VecZnx::from_data(take_slice, n, cols, size),
-            rem_slice,
-        )
+        (VecZnx::from_data(take_slice, n, cols, size), rem_slice)
     }

     fn take_vec_znx_big<M, B: Backend>(&mut self, module: &M, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self)
@@ -102,7 +95,7 @@ where
         (slice, scratch)
     }

-    fn take_vec_znx_slice(&mut self, n: usize, len: usize, cols: usize, size: usize) -> (Vec<VecZnx<&mut [u8]>>, &mut Self){
+    fn take_vec_znx_slice(&mut self, len: usize, n: usize, cols: usize, size: usize) -> (Vec<VecZnx<&mut [u8]>>, &mut Self) {
         let mut scratch: &mut Self = self;
         let mut slice: Vec<VecZnx<&mut [u8]>> = Vec::with_capacity(len);
         for _ in 0..len {
@@ -133,13 +126,12 @@ where
     fn take_mat_znx(
         &mut self,
         n: usize,
         rows: usize,
         cols_in: usize,
         cols_out: usize,
         size: usize,
-    ) -> (MatZnx<&mut [u8]>, &mut Self)
-    {
+    ) -> (MatZnx<&mut [u8]>, &mut Self) {
         let (take_slice, rem_slice) = self.take_slice(MatZnx::bytes_of(n, rows, cols_in, cols_out, size));
         (
             MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size),

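The take_* helpers above all follow the same scratch-splitting discipline: take N bytes off the front of the scratch buffer, wrap them in a typed view, and hand back the remainder so further takes can chain. A stripped-down sketch of that idea using plain slices (simplified types, not the actual poulpy_hal API):

// Minimal sketch of the scratch "take" pattern (simplified, not the
// poulpy_hal API): split a byte buffer into a typed view plus the
// untouched remainder, so takes can be chained.
struct Scratch {
    data: Vec<u8>,
}

fn take_slice(buf: &mut [u8], nbytes: usize) -> (&mut [u8], &mut [u8]) {
    assert!(buf.len() >= nbytes, "scratch space exhausted");
    buf.split_at_mut(nbytes)
}

fn main() {
    let mut scratch = Scratch { data: vec![0u8; 64] };
    // First take: 16 bytes for one object, remainder stays usable.
    let (a, rest) = take_slice(&mut scratch.data, 16);
    // Second take chains off the remainder of the first.
    let (b, rest) = take_slice(rest, 32);
    assert_eq!(a.len(), 16);
    assert_eq!(b.len(), 32);
    assert_eq!(rest.len(), 16);
}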
View File

@@ -2,17 +2,19 @@ use std::marker::PhantomData;
 use poulpy_core::layouts::{Base2K, GLWE, GLWEInfos, GLWEPlaintextLayout, LWEInfos, Rank, TorusPrecision};
-use poulpy_core::{TakeGLWEPlaintext, layouts::prepared::GLWESecretPrepared};
+#[cfg(test)]
+use poulpy_core::ScratchTakeCore;
+use poulpy_core::{layouts::prepared::GLWESecretPrepared};
 use poulpy_hal::api::VecZnxBigBytesOf;
 #[cfg(test)]
 use poulpy_hal::api::{
-    ScratchAvailable, TakeVecZnx, VecZnxAddInplace, VecZnxAddNormal, VecZnxFillUniform, VecZnxNormalize, VecZnxSub,
+    VecZnxAddInplace, VecZnxAddNormal, VecZnxFillUniform, VecZnxNormalize, VecZnxSub,
 };
 #[cfg(test)]
 use poulpy_hal::source::Source;
 use poulpy_hal::{
     api::{
-        TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxDftApply,
+        VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxDftApply,
         VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalizeTmpBytes,
     },
     layouts::{Backend, Data, DataMut, DataRef, Module, Scratch},
@@ -96,7 +98,7 @@ impl<D: DataMut, T: UnsignedInteger + ToBits> FheUintBlocks<D, T> {
         + VecZnxAddNormal
         + VecZnxNormalize<BE>
         + VecZnxSub,
-    Scratch<BE>: TakeVecZnxDft<BE> + ScratchAvailable + TakeVecZnx + TakeGLWEPlaintext<BE>,
+    Scratch<BE>: ScratchTakeCore<BE>,
 {
     use poulpy_core::layouts::GLWEPlaintextLayout;
@@ -136,7 +138,7 @@ impl<D: DataRef, T: UnsignedInteger + FromBits + ToBits> FheUintBlocks<D, T> {
         + VecZnxBigAddInplace<BE>
         + VecZnxBigAddSmallInplace<BE>
         + VecZnxBigNormalize<BE>,
-    Scratch<BE>: TakeVecZnxDft<BE> + TakeVecZnxBig<BE> + TakeGLWEPlaintext<BE>,
+    Scratch<BE>: ScratchTakeCore<BE>,
 {
     #[cfg(debug_assertions)]
     {
@@ -186,7 +188,7 @@ impl<D: DataRef, T: UnsignedInteger + FromBits + ToBits> FheUintBlocks<D, T> {
         + VecZnxNormalizeTmpBytes
         + VecZnxSubInplace
         + VecZnxNormalizeInplace<BE>,
-    Scratch<BE>: TakeGLWEPlaintext<BE> + TakeVecZnxDft<BE> + TakeVecZnxBig<BE>,
+    Scratch<BE>: ScratchTakeCore<BE>,
 {
     #[cfg(debug_assertions)]
     {

View File

@@ -144,7 +144,7 @@ impl<D: DataMut, T: UnsignedInteger + ToBits, BE: Backend> FheUintBlocksPrep<D,
         assert_eq!(sk.n(), module.n() as u32);
     }

-    let (mut tmp_ggsw, scratch_1) = scratch.take_ggsw(module, self);
+    let (mut tmp_ggsw, scratch_1) = scratch.take_ggsw(self);
     let (mut pt, scratch_2) = scratch_1.take_scalar_znx(module.n(), 1);

     for i in 0..T::WORD_SIZE {

View File

@@ -1,14 +1,12 @@
 use itertools::Itertools;
 use poulpy_core::{
-    GLWEOperations, TakeGLWEPlaintext, TakeGLWESlice, glwe_packing,
     layouts::{
-        GLWE, GLWEInfos, GLWEPlaintextLayout, LWEInfos, TorusPrecision,
-        prepared::{GLWEAutomorphismKeyPrepared, GLWESecretPrepared},
-    },
+        GLWEInfos, GLWEPlaintextLayout, LWEInfos, TorusPrecision, GLWE,
+        prepared::{GLWEAutomorphismKeyPrepared, GLWESecretPrepared},
+    }, ScratchTakeCore,
 };
 use poulpy_hal::{
     api::{
-        ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal,
+        ScratchAvailable, SvpApplyDftToDftInplace, VecZnxAddInplace, VecZnxAddNormal,
         VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace,
         VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy,
         VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftCopy, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA,
@@ -62,7 +60,7 @@ impl<D: DataMut, T: UnsignedInteger> FheUintWord<D, T> {
         + VecZnxAutomorphismInplace<BE>
         + VecZnxBigSubSmallNegateInplace<BE>
         + VecZnxRotate,
-    Scratch<BE>: TakeVecZnxDft<BE> + ScratchAvailable + TakeVecZnx + TakeGLWESlice,
+    Scratch<BE>: ScratchTakeCore<BE>,
 {
     // Repacks the GLWE ciphertexts bits
     let gap: usize = module.n() / T::WORD_SIZE;
@@ -122,7 +120,7 @@ impl<D: DataMut, T: UnsignedInteger + ToBits> FheUintWord<D, T> {
         + VecZnxAddNormal
         + VecZnxNormalize<BE>
         + VecZnxSub,
-    Scratch<BE>: TakeVecZnxDft<BE> + ScratchAvailable + TakeVecZnx + TakeGLWEPlaintext<BE>,
+    Scratch<BE>: ScratchTakeCore<BE>,
 {
     #[cfg(debug_assertions)]
     {
@@ -167,7 +165,7 @@ impl<D: DataRef, T: UnsignedInteger + FromBits> FheUintWord<D, T> {
         + VecZnxBigAddInplace<BE>
         + VecZnxBigAddSmallInplace<BE>
         + VecZnxBigNormalize<BE>,
-    Scratch<BE>: TakeVecZnxDft<BE> + TakeVecZnxBig<BE> + TakeGLWEPlaintext<BE>,
+    Scratch<BE>: ScratchTakeCore<BE>,
 {
     #[cfg(debug_assertions)]
     {

View File

@@ -1,10 +1,8 @@
 use itertools::Itertools;
 use poulpy_core::{
-    GLWEExternalProductInplace, GLWEOperations, TakeGLWESlice,
     layouts::{
-        GLWE, GLWEToMut, LWEInfos,
+        GLWEToMut, LWEInfos, GLWE,
         prepared::{GGSWPrepared, GGSWPreparedToRef},
-    },
+    }, GLWEExternalProduct, ScratchTakeCore,
 };
 use poulpy_hal::{
     api::{VecZnxAddInplace, VecZnxCopy, VecZnxNegateInplace, VecZnxSub},
@@ -49,7 +47,7 @@ impl<C: BitCircuitInfo, const N: usize, T: UnsignedInteger, BE: Backend> Circuit
 where
     Self: GetBitCircuitInfo<T>,
     Module<BE>: Cmux<BE> + VecZnxCopy,
-    Scratch<BE>: TakeGLWESlice,
+    Scratch<BE>: ScratchTakeCore<BE>,
 {
     fn execute<O>(
         &self,
@@ -169,7 +167,7 @@ pub trait Cmux<BE: Backend> {
 impl<BE: Backend> Cmux<BE> for Module<BE>
 where
-    Module<BE>: GLWEExternalProductInplace<BE> + VecZnxSub + VecZnxCopy + VecZnxNegateInplace + VecZnxAddInplace,
+    Module<BE>: GLWEExternalProduct<BE> + VecZnxSub + VecZnxCopy + VecZnxNegateInplace + VecZnxAddInplace,
 {
     fn cmux<O, T, F, S>(&self, out: &mut GLWE<O>, t: &GLWE<T>, f: &GLWE<F>, s: &GGSWPrepared<S, BE>, scratch: &mut Scratch<BE>)
     where

View File

@@ -9,16 +9,13 @@ use crate::tfhe::{
     },
 };
 use poulpy_core::{
-    TakeGGSW, TakeGLWE,
     layouts::{
-        GLWESecret, GLWEToLWEKeyLayout, GLWEToLWESwitchingKey, LWE, LWESecret,
-        prepared::{GLWEToLWESwitchingKeyPrepared, Prepare, PrepareAlloc},
-    },
+        GLWESecret, GLWEToLWEKeyLayout, GLWEToLWESwitchingKey, LWESecret,
+        prepared::GLWEToLWESwitchingKeyPrepared,
+    }, ScratchTakeCore,
 };
 use poulpy_hal::{
     api::{
-        ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, TakeScalarZnx,
-        TakeSvpPPol, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace,
+        ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace,
         VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
         VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize,
         VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing,
@@ -96,7 +93,7 @@ impl<BRA: BlindRotationAlgo> BDDKey<Vec<u8>, Vec<u8>, BRA> {
         + SvpPPolAlloc<BE>
         + VecZnxAutomorphism
         + VecZnxAutomorphismInplace<BE>,
-    Scratch<BE>: TakeVecZnxDft<BE> + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeSvpPPol<BE> + TakeVecZnxBig<BE>,
+    Scratch<BE>: ScratchTakeCore<BE>,
 {
     let mut ks: GLWEToLWESwitchingKey<Vec<u8>> = GLWEToLWESwitchingKey::alloc(&infos.ks_infos());
     ks.encrypt_sk(module, sk_lwe, sk_glwe, source_xa, source_xe, scratch);
@@ -217,7 +214,7 @@ where
         + VecZnxBigNormalize<BE>
         + VecZnxNormalize<BE>
         + VecZnxNormalizeTmpBytes,
-    Scratch<BE>: ScratchAvailable + TakeVecZnxDft<BE> + TakeGLWE + TakeVecZnx + TakeGGSW,
+    Scratch<BE>: ScratchTakeCore<BE>,
     CircuitBootstrappingKeyPrepared<CBT, BRA, BE>: CirtuitBootstrappingExecute<BE>,
 {
     fn prepare(

View File

@@ -1,158 +1,142 @@
 use itertools::izip;
 use poulpy_hal::{
     api::{
-        ScratchAvailable, SvpApplyDftToDft, SvpPPolBytesOf, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice,
-        TakeVecZnxSlice, VecZnxAddInplace, VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize,
-        VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftApply, VecZnxDftBytesOf,
-        VecZnxDftSubInplace, VecZnxDftZero, VecZnxIdftApply, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpBytes,
-        VecZnxMulXpMinusOneInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate,
-        VecZnxSubInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
+        ModuleN, ScratchTakeBasic, SvpApplyDftToDft, VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize,
+        VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAddInplace, VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftSubInplace,
+        VecZnxDftZero, VecZnxIdftApply, VecZnxIdftApplyTmpBytes, VecZnxRotate, VmpApplyDftToDft, VmpApplyDftToDftTmpBytes,
     },
-    layouts::{Backend, DataMut, DataRef, Module, Scratch, SvpPPol, VecZnx, ZnxView, ZnxZero},
+    layouts::{Backend, DataMut, DataRef, Module, Scratch, SvpPPol, VecZnx, ZnxZero},
 };

 use poulpy_core::{
-    Distribution, GLWEOperations, TakeGLWE,
+    Distribution, GLWEAdd, GLWEExternalProduct, GLWEMulXpMinusOne, GLWENormalize, ScratchTakeCore,
     layouts::{GGSWInfos, GLWE, GLWEInfos, GLWEToMut, LWE, LWEInfos, LWEToRef},
 };

 use crate::tfhe::blind_rotation::{
-    BlincRotationExecute, BlindRotationKeyInfos, BlindRotationKeyPrepared, CGGI, LookUpTable, LookUpTableRotationDirection,
+    BlindRotationExecute, BlindRotationKeyInfos, BlindRotationKeyPrepared, CGGI, LookupTable, mod_switch_2n,
 };

-#[allow(clippy::too_many_arguments)]
-pub fn cggi_blind_rotate_tmp_bytes<B: Backend, OUT, GGSW>(
-    module: &Module<B>,
-    block_size: usize,
-    extension_factor: usize,
-    glwe_infos: &OUT,
-    brk_infos: &GGSW,
-) -> usize
-where
-    OUT: GLWEInfos,
-    GGSW: GGSWInfos,
-    Module<B>: VecZnxDftBytesOf
-        + VmpApplyDftToDftTmpBytes
-        + VecZnxNormalizeTmpBytes
-        + VecZnxBigBytesOf
-        + VecZnxIdftApplyTmpBytes
-        + VecZnxBigNormalizeTmpBytes,
-{
-    let brk_size: usize = brk_infos.size();
-    if block_size > 1 {
-        let cols: usize = (brk_infos.rank() + 1).into();
-        let dnum: usize = brk_infos.dnum().into();
-        let acc_dft: usize = module.bytes_of_vec_znx_dft(cols, dnum) * extension_factor;
-        let acc_big: usize = module.bytes_of_vec_znx_big(1, brk_size);
-        let vmp_res: usize = module.bytes_of_vec_znx_dft(cols, brk_size) * extension_factor;
-        let vmp_xai: usize = module.bytes_of_vec_znx_dft(1, brk_size);
-        let acc_dft_add: usize = vmp_res;
-        let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes(brk_size, dnum, dnum, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2)
-        let acc: usize = if extension_factor > 1 {
-            VecZnx::bytes_of(module.n(), cols, glwe_infos.size()) * extension_factor
-        } else {
-            0
-        };
-        acc + acc_dft
-            + acc_dft_add
-            + vmp_res
-            + vmp_xai
-            + (vmp | (acc_big + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_apply_tmp_bytes())))
-    } else {
-        GLWE::bytes_of(glwe_infos) + GLWE::external_product_inplace_tmp_bytes(module, glwe_infos, brk_infos)
-    }
-}
-
-impl<D: DataRef, B: Backend> BlincRotationExecute<B> for BlindRotationKeyPrepared<D, CGGI, B>
+impl<BE: Backend> BlindRotationExecute<CGGI, BE> for Module<BE>
 where
-    Module<B>: VecZnxBigBytesOf
-        + VecZnxDftBytesOf
-        + SvpPPolBytesOf
-        + VmpApplyDftToDftTmpBytes
-        + VecZnxBigNormalizeTmpBytes
-        + VecZnxIdftApplyTmpBytes
-        + VecZnxIdftApply<B>
-        + VecZnxDftAdd<B>
-        + VecZnxDftAddInplace<B>
-        + VecZnxDftApply<B>
-        + VecZnxDftZero<B>
-        + SvpApplyDftToDft<B>
-        + VecZnxDftSubInplace<B>
-        + VecZnxBigAddSmallInplace<B>
-        + VecZnxRotate
-        + VecZnxAddInplace
-        + VecZnxSubInplace
-        + VecZnxNormalize<B>
-        + VecZnxNormalizeInplace<B>
-        + VecZnxCopy
-        + VecZnxMulXpMinusOneInplace<B>
-        + VmpApplyDftToDft<B>
-        + VmpApplyDftToDftAdd<B>
-        + VecZnxIdftApplyConsume<B>
-        + VecZnxBigNormalize<B>
-        + VecZnxNormalizeTmpBytes,
-    Scratch<B>: TakeVecZnxDftSlice<B> + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeVecZnxSlice + TakeVecZnx + ScratchAvailable,
+    Self: VecZnxDftBytesOf
+        + VecZnxBigBytesOf
+        + VmpApplyDftToDftTmpBytes
+        + VecZnxBigNormalizeTmpBytes
+        + VecZnxIdftApplyTmpBytes
+        + GLWEExternalProduct<BE>
+        + ModuleN
+        + VecZnxRotate
+        + VecZnxDftApply<BE>
+        + VecZnxDftZero<BE>
+        + VmpApplyDftToDft<BE>
+        + SvpApplyDftToDft<BE>
+        + VecZnxDftAddInplace<BE>
+        + VecZnxDftSubInplace<BE>
+        + VecZnxIdftApply<BE>
+        + VecZnxBigAddSmallInplace<BE>
+        + VecZnxBigNormalize<BE>
+        + VecZnxCopy
+        + GLWEMulXpMinusOne<BE>
+        + GLWEAdd
+        + GLWENormalize<BE>,
+    Scratch<BE>: ScratchTakeCore<BE>,
 {
-    fn execute<DR: DataMut, DI: DataRef>(
+    fn blind_rotation_execute_tmp_bytes<G, B>(
         &self,
-        module: &Module<B>,
-        res: &mut GLWE<DR>,
-        lwe: &LWE<DI>,
-        lut: &LookUpTable,
-        scratch: &mut Scratch<B>,
-    ) {
-        match self.dist {
+        block_size: usize,
+        extension_factor: usize,
+        glwe_infos: &G,
+        brk_infos: &B,
+    ) -> usize
+    where
+        G: GLWEInfos,
+        B: GGSWInfos,
+    {
+        let brk_size: usize = brk_infos.size();
+        if block_size > 1 {
+            let cols: usize = (brk_infos.rank() + 1).into();
+            let dnum: usize = brk_infos.dnum().into();
+            let acc_dft: usize = self.bytes_of_vec_znx_dft(cols, dnum) * extension_factor;
+            let acc_big: usize = self.bytes_of_vec_znx_big(1, brk_size);
+            let vmp_res: usize = self.bytes_of_vec_znx_dft(cols, brk_size) * extension_factor;
+            let vmp_xai: usize = self.bytes_of_vec_znx_dft(1, brk_size);
+            let acc_dft_add: usize = vmp_res;
+            let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes(brk_size, dnum, dnum, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2)
+            let acc: usize = if extension_factor > 1 {
+                VecZnx::bytes_of(self.n(), cols, glwe_infos.size()) * extension_factor
+            } else {
+                0
+            };
+            acc + acc_dft
+                + acc_dft_add
+                + vmp_res
+                + vmp_xai
+                + (vmp
+                    | (acc_big
+                        + (self
+                            .vec_znx_big_normalize_tmp_bytes()
+                            .max(self.vec_znx_idft_apply_tmp_bytes()))))
+        } else {
+            GLWE::bytes_of_from_infos(glwe_infos) + GLWE::external_product_tmp_bytes(self, glwe_infos, glwe_infos, brk_infos)
+        }
+    }
+
+    fn blind_rotation_execute<DR, DL, DB>(
+        &self,
+        res: &mut GLWE<DR>,
+        lwe: &LWE<DL>,
+        lut: &LookupTable,
+        brk: &BlindRotationKeyPrepared<DB, CGGI, BE>,
+        scratch: &mut Scratch<BE>,
+    ) where
+        DR: DataMut,
+        DL: DataRef,
+        DB: DataRef,
+    {
+        match brk.dist {
             Distribution::BinaryBlock(_) | Distribution::BinaryFixed(_) | Distribution::BinaryProb(_) | Distribution::ZERO => {
                 if lut.extension_factor() > 1 {
-                    execute_block_binary_extended(module, res, lwe, lut, self, scratch)
-                } else if self.block_size() > 1 {
-                    execute_block_binary(module, res, lwe, lut, self, scratch);
+                    execute_block_binary_extended(self, res, lwe, lut, brk, scratch)
+                } else if brk.block_size() > 1 {
+                    execute_block_binary(self, res, lwe, lut, brk, scratch);
                 } else {
-                    execute_standard(module, res, lwe, lut, self, scratch);
+                    execute_standard(self, res, lwe, lut, brk, scratch);
                 }
             }
-            _ => panic!("invalid CGGI distribution"),
+            _ => panic!("invalid CGGI distribution (have you prepared the key?)"),
         }
     }
 }

-fn execute_block_binary_extended<DataRes, DataIn, DataBrk, B: Backend>(
-    module: &Module<B>,
+fn execute_block_binary_extended<DataRes, DataIn, DataBrk, M, BE: Backend>(
+    module: &M,
     res: &mut GLWE<DataRes>,
     lwe: &LWE<DataIn>,
-    lut: &LookUpTable,
-    brk: &BlindRotationKeyPrepared<DataBrk, CGGI, B>,
-    scratch: &mut Scratch<B>,
+    lut: &LookupTable,
+    brk: &BlindRotationKeyPrepared<DataBrk, CGGI, BE>,
+    scratch: &mut Scratch<BE>,
 ) where
     DataRes: DataMut,
     DataIn: DataRef,
     DataBrk: DataRef,
-    Module<B>: VecZnxBigBytesOf
-        + VecZnxDftBytesOf
-        + SvpPPolBytesOf
-        + VmpApplyDftToDftTmpBytes
-        + VecZnxBigNormalizeTmpBytes
-        + VecZnxIdftApplyTmpBytes
-        + VecZnxIdftApply<B>
-        + VecZnxDftAdd<B>
-        + VecZnxDftAddInplace<B>
-        + VecZnxDftApply<B>
-        + VecZnxDftZero<B>
-        + SvpApplyDftToDft<B>
-        + VecZnxDftSubInplace<B>
-        + VecZnxBigAddSmallInplace<B>
-        + VecZnxRotate
-        + VecZnxAddInplace
-        + VecZnxSubInplace
-        + VecZnxNormalize<B>
-        + VecZnxNormalizeInplace<B>
-        + VecZnxCopy
-        + VecZnxMulXpMinusOneInplace<B>
-        + VecZnxBigNormalize<B>
-        + VmpApplyDftToDft<B>,
-    Scratch<B>: TakeVecZnxDftSlice<B> + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeVecZnxSlice + ScratchAvailable + TakeVecZnx,
+    M: VecZnxDftBytesOf
+        + ModuleN
+        + VecZnxRotate
+        + VecZnxDftApply<BE>
+        + VecZnxDftZero<BE>
+        + VmpApplyDftToDft<BE>
+        + SvpApplyDftToDft<BE>
+        + VecZnxDftAddInplace<BE>
+        + VecZnxDftSubInplace<BE>
+        + VecZnxIdftApply<BE>
+        + VecZnxBigAddSmallInplace<BE>
+        + VecZnxBigNormalize<BE>
+        + VecZnxCopy
+        + VecZnxBigBytesOf,
+    Scratch<BE>: ScratchTakeCore<BE>,
 {
     let n_glwe: usize = brk.n_glwe().into();
     let extension_factor: usize = lut.extension_factor();
@@ -161,16 +145,16 @@ fn execute_block_binary_extended<DataRes, DataIn, DataBrk, B: Backend>(
     let cols: usize = (res.rank() + 1).into();

     let (mut acc, scratch_1) = scratch.take_vec_znx_slice(extension_factor, n_glwe, cols, res.size());
-    let (mut acc_dft, scratch_2) = scratch_1.take_vec_znx_dft_slice(extension_factor, n_glwe, cols, dnum);
-    let (mut vmp_res, scratch_3) = scratch_2.take_vec_znx_dft_slice(extension_factor, n_glwe, cols, brk.size());
-    let (mut acc_add_dft, scratch_4) = scratch_3.take_vec_znx_dft_slice(extension_factor, n_glwe, cols, brk.size());
-    let (mut vmp_xai, scratch_5) = scratch_4.take_vec_znx_dft(n_glwe, 1, brk.size());
+    let (mut acc_dft, scratch_2) = scratch_1.take_vec_znx_dft_slice(module, extension_factor, cols, dnum);
+    let (mut vmp_res, scratch_3) = scratch_2.take_vec_znx_dft_slice(module, extension_factor, cols, brk.size());
+    let (mut acc_add_dft, scratch_4) = scratch_3.take_vec_znx_dft_slice(module, extension_factor, cols, brk.size());
+    let (mut vmp_xai, scratch_5) = scratch_4.take_vec_znx_dft(module, 1, brk.size());

     (0..extension_factor).for_each(|i| {
         acc[i].zero();
     });

-    let x_pow_a: &Vec<SvpPPol<Vec<u8>, B>>;
+    let x_pow_a: &Vec<SvpPPol<Vec<u8>, BE>>;
     if let Some(b) = &brk.x_pow_a {
         x_pow_a = b
     } else {
@@ -268,7 +252,7 @@ fn execute_block_binary_extended<DataRes, DataIn, DataBrk, B: Backend>(
     });

     {
-        let (mut acc_add_big, scratch7) = scratch_5.take_vec_znx_big(n_glwe, 1, brk.size());
+        let (mut acc_add_big, scratch7) = scratch_5.take_vec_znx_big(module, 1, brk.size());

         (0..extension_factor).for_each(|j| {
             (0..cols).for_each(|i| {
@@ -285,41 +269,32 @@ fn execute_block_binary_extended<DataRes, DataIn, DataBrk, B: Backend>(
     });
 }

-fn execute_block_binary<DataRes, DataIn, DataBrk, B: Backend>(
-    module: &Module<B>,
+fn execute_block_binary<DataRes, DataIn, DataBrk, M, BE: Backend>(
+    module: &M,
     res: &mut GLWE<DataRes>,
     lwe: &LWE<DataIn>,
-    lut: &LookUpTable,
-    brk: &BlindRotationKeyPrepared<DataBrk, CGGI, B>,
-    scratch: &mut Scratch<B>,
+    lut: &LookupTable,
+    brk: &BlindRotationKeyPrepared<DataBrk, CGGI, BE>,
+    scratch: &mut Scratch<BE>,
 ) where
     DataRes: DataMut,
     DataIn: DataRef,
     DataBrk: DataRef,
-    Module<B>: VecZnxBigBytesOf
-        + VecZnxDftBytesOf
-        + SvpPPolBytesOf
-        + VmpApplyDftToDftTmpBytes
-        + VecZnxBigNormalizeTmpBytes
-        + VecZnxIdftApplyTmpBytes
-        + VecZnxIdftApply<B>
-        + VecZnxDftAdd<B>
-        + VecZnxDftAddInplace<B>
-        + VecZnxDftApply<B>
-        + VecZnxDftZero<B>
-        + SvpApplyDftToDft<B>
-        + VecZnxDftSubInplace<B>
-        + VecZnxBigAddSmallInplace<B>
-        + VecZnxRotate
-        + VecZnxAddInplace
-        + VecZnxSubInplace
-        + VecZnxNormalize<B>
-        + VecZnxNormalizeInplace<B>
-        + VecZnxCopy
-        + VecZnxMulXpMinusOneInplace<B>
-        + VmpApplyDftToDft<B>
-        + VecZnxBigNormalize<B>,
-    Scratch<B>: TakeVecZnxDftSlice<B> + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeVecZnxSlice + ScratchAvailable + TakeVecZnx,
+    M: VecZnxDftBytesOf
+        + ModuleN
+        + VecZnxRotate
+        + VecZnxDftApply<BE>
+        + VecZnxDftZero<BE>
+        + VmpApplyDftToDft<BE>
+        + SvpApplyDftToDft<BE>
+        + VecZnxDftAddInplace<BE>
+        + VecZnxDftSubInplace<BE>
+        + VecZnxIdftApply<BE>
+        + VecZnxBigAddSmallInplace<BE>
+        + VecZnxBigNormalize<BE>
+        + VecZnxCopy
+        + VecZnxBigBytesOf,
+    Scratch<BE>: ScratchTakeCore<BE>,
 {
     let n_glwe: usize = brk.n_glwe().into();
     let mut lwe_2n: Vec<i64> = vec![0i64; (lwe.n() + 1).into()]; // TODO: from scratch space
@@ -350,12 +325,12 @@ fn execute_block_binary<DataRes, DataIn, DataBrk, B: Backend>(
     // ACC + [sum DFT(X^ai -1) * (DFT(ACC) x BRKi)]
-    let (mut acc_dft, scratch_1) = scratch.take_vec_znx_dft(n_glwe, cols, dnum);
-    let (mut vmp_res, scratch_2) = scratch_1.take_vec_znx_dft(n_glwe, cols, brk.size());
-    let (mut acc_add_dft, scratch_3) = scratch_2.take_vec_znx_dft(n_glwe, cols, brk.size());
-    let (mut vmp_xai, scratch_4) = scratch_3.take_vec_znx_dft(n_glwe, 1, brk.size());
+    let (mut acc_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols, dnum);
+    let (mut vmp_res, scratch_2) = scratch_1.take_vec_znx_dft(module, cols, brk.size());
+    let (mut acc_add_dft, scratch_3) = scratch_2.take_vec_znx_dft(module, cols, brk.size());
+    let (mut vmp_xai, scratch_4) = scratch_3.take_vec_znx_dft(module, 1, brk.size());

-    let x_pow_a: &Vec<SvpPPol<Vec<u8>, B>>;
+    let x_pow_a: &Vec<SvpPPol<Vec<u8>, BE>>;
     if let Some(b) = &brk.x_pow_a {
         x_pow_a = b
     } else {
@@ -388,7 +363,7 @@ fn execute_block_binary<DataRes, DataIn, DataBrk, B: Backend>(
     });

     {
-        let (mut acc_add_big, scratch_5) = scratch_4.take_vec_znx_big(n_glwe, 1, brk.size());
+        let (mut acc_add_big, scratch_5) = scratch_4.take_vec_znx_big(module, 1, brk.size());
         (0..cols).for_each(|i| {
             module.vec_znx_idft_apply(&mut acc_add_big, 0, &acc_add_dft, i, scratch_5);
@@ -407,44 +382,19 @@ fn execute_block_binary<DataRes, DataIn, DataBrk, B: Backend>(
     });
 }

-fn execute_standard<DataRes, DataIn, DataBrk, B: Backend>(
-    module: &Module<B>,
+fn execute_standard<DataRes, DataIn, DataBrk, M, BE: Backend>(
+    module: &M,
     res: &mut GLWE<DataRes>,
     lwe: &LWE<DataIn>,
-    lut: &LookUpTable,
-    brk: &BlindRotationKeyPrepared<DataBrk, CGGI, B>,
-    scratch: &mut Scratch<B>,
+    lut: &LookupTable,
+    brk: &BlindRotationKeyPrepared<DataBrk, CGGI, BE>,
+    scratch: &mut Scratch<BE>,
 ) where
     DataRes: DataMut,
     DataIn: DataRef,
     DataBrk: DataRef,
-    Module<B>: VecZnxBigBytesOf
-        + VecZnxDftBytesOf
-        + SvpPPolBytesOf
-        + VmpApplyDftToDftTmpBytes
-        + VecZnxBigNormalizeTmpBytes
-        + VecZnxIdftApplyTmpBytes
-        + VecZnxIdftApply<B>
-        + VecZnxDftAdd<B>
-        + VecZnxDftAddInplace<B>
-        + VecZnxDftApply<B>
-        + VecZnxDftZero<B>
-        + SvpApplyDftToDft<B>
-        + VecZnxDftSubInplace<B>
-        + VecZnxBigAddSmallInplace<B>
-        + VecZnxRotate
-        + VecZnxAddInplace
-        + VecZnxSubInplace
-        + VecZnxNormalize<B>
-        + VecZnxNormalizeInplace<B>
-        + VecZnxCopy
-        + VecZnxMulXpMinusOneInplace<B>
-        + VmpApplyDftToDft<B>
-        + VmpApplyDftToDftAdd<B>
-        + VecZnxIdftApplyConsume<B>
-        + VecZnxBigNormalize<B>
-        + VecZnxNormalizeTmpBytes,
-    Scratch<B>: TakeVecZnxDftSlice<B> + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeVecZnxSlice + ScratchAvailable + TakeVecZnx,
+    M: VecZnxRotate + GLWEExternalProduct<BE> + GLWEMulXpMinusOne<BE> + GLWEAdd + GLWENormalize<BE>,
+    Scratch<BE>: ScratchTakeCore<BE>,
 {
     #[cfg(debug_assertions)]
     {
@@ -498,7 +448,7 @@ fn execute_standard<DataRes, DataIn, DataBrk, B: Backend>(
     module.vec_znx_rotate(b, out_mut.data_mut(), 0, &lut.data[0], 0);

     // ACC + [sum DFT(X^ai -1) * (DFT(ACC) x BRKi)]
-    let (mut acc_tmp, scratch_1) = scratch.take_glwe_ct(&out_mut);
+    let (mut acc_tmp, scratch_1) = scratch.take_glwe(&out_mut);

     // TODO: see if faster by skipping normalization in external product and keeping acc in big coeffs
     // TODO: first iteration can be optimized to be a gglwe product
@@ -507,55 +457,13 @@ fn execute_standard<DataRes, DataIn, DataBrk, B: Backend>(
         acc_tmp.external_product(module, &out_mut, ski, scratch_1);

         // acc_tmp = (sk[i] * acc) * (X^{ai} - 1)
-        acc_tmp.mul_xp_minus_one_inplace(module, *ai, scratch_1);
+        module.glwe_mul_xp_minus_one_inplace(*ai, &mut acc_tmp, scratch_1);

         // acc = acc + (sk[i] * acc) * (X^{ai} - 1)
-        out_mut.add_inplace(module, &acc_tmp);
+        module.glwe_add_inplace(&mut out_mut, &acc_tmp);
     });

     // We can normalize only at the end because we add normalized values in [-2^{base2k-1}, 2^{base2k-1}]
     // on top of each others, thus ~ 2^{63-base2k} additions are supported before overflow.
-    out_mut.normalize_inplace(module, scratch_1);
-}
-
-pub fn mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWE<&[u8]>, rot_dir: LookUpTableRotationDirection) {
-    let base2k: usize = lwe.base2k().into();
-    let log2n: usize = usize::BITS as usize - (n - 1).leading_zeros() as usize + 1;
-    res.copy_from_slice(lwe.data().at(0, 0));
-
-    match rot_dir {
-        LookUpTableRotationDirection::Left => {
-            res.iter_mut().for_each(|x| *x = -*x);
-        }
-        LookUpTableRotationDirection::Right => {}
-    }
-
-    if base2k > log2n {
-        let diff: usize = base2k - (log2n - 1); // additional -1 because we map to [-N/2, N/2) instead of [0, N)
-        res.iter_mut().for_each(|x| {
-            *x = div_round_by_pow2(x, diff);
-        })
-    } else {
-        let rem: usize = base2k - (log2n % base2k);
-        let size: usize = log2n.div_ceil(base2k);
-        (1..size).for_each(|i| {
-            if i == size - 1 && rem != base2k {
-                let k_rem: usize = base2k - rem;
-                izip!(lwe.data().at(0, i).iter(), res.iter_mut()).for_each(|(x, y)| {
-                    *y = (*y << k_rem) + (x >> rem);
-                });
-            } else {
-                izip!(lwe.data().at(0, i).iter(), res.iter_mut()).for_each(|(x, y)| {
-                    *y = (*y << base2k) + x;
-                });
-            }
-        })
-    }
-}
-
-#[inline(always)]
-fn div_round_by_pow2(x: &i64, k: usize) -> i64 {
-    (x + (1 << (k - 1))) >> k
-}
+    module.glwe_normalize_inplace(&mut out_mut, scratch_1);
 }

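The rewritten execute_standard loop realizes acc <- acc + (sk_i * acc) * (X^{a_i} - 1), which for binary sk_i collapses to acc * X^{a_i * sk_i}: the added term vanishes when sk_i = 0 and telescopes to a rotation when sk_i = 1. A plaintext sanity check of that identity on a toy negacyclic ring (illustrative only; the real loop performs it homomorphically through GGSW external products):

// Plaintext check of the CGGI update rule acc += s * acc * (X^a - 1)
// on Z[X]/(X^N + 1): for s in {0, 1} the result is acc * X^(a*s).
fn rotate(p: &[i64], a: usize) -> Vec<i64> {
    // Multiply by X^a in the negacyclic ring: wrapped coefficients flip sign.
    let n = p.len();
    let mut out = vec![0i64; n];
    for (i, &c) in p.iter().enumerate() {
        let j = (i + a) % (2 * n);
        if j < n { out[j] += c } else { out[j - n] -= c }
    }
    out
}

fn main() {
    let acc = vec![3i64, -1, 4, 1];
    let a = 3;
    for s in [0i64, 1] {
        // acc + s * (acc * X^a - acc)
        let rot = rotate(&acc, a);
        let updated: Vec<i64> = acc
            .iter()
            .zip(&rot)
            .map(|(&x, &r)| x + s * (r - x))
            .collect();
        let expected = if s == 1 { rot.clone() } else { acc.clone() };
        assert_eq!(updated, expected);
    }
    println!("CMUX identity holds");
}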
View File

@@ -0,0 +1,79 @@
use poulpy_hal::{
    layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxView, ZnxViewMut},
    source::Source,
};
use std::marker::PhantomData;

use poulpy_core::{
    Distribution, GGSWEncryptSk, GetDistribution, ScratchTakeCore,
    layouts::{GGSW, GGSWInfos, GLWEInfos, GLWESecretPreparedToRef, LWEInfos, LWESecret, LWESecretToRef},
};

use crate::tfhe::blind_rotation::{
    BlindRotationKey, BlindRotationKeyEncryptSk, BlindRotationKeyFactory, BlindRotationKeyInfos, CGGI,
};

impl<D: DataRef> BlindRotationKeyFactory<CGGI> for BlindRotationKey<D, CGGI> {
    fn blind_rotation_key_alloc<A>(infos: &A) -> BlindRotationKey<Vec<u8>, CGGI>
    where
        A: BlindRotationKeyInfos,
    {
        BlindRotationKey {
            keys: (0..infos.n_lwe().as_usize())
                .map(|_| GGSW::alloc_from_infos(infos))
                .collect(),
            dist: Distribution::NONE,
            _phantom: PhantomData,
        }
    }
}

impl<BE: Backend> BlindRotationKeyEncryptSk<BE, CGGI> for Module<BE>
where
    Self: GGSWEncryptSk<BE>,
    Scratch<BE>: ScratchTakeCore<BE>,
{
    fn blind_rotation_key_encrypt_sk_tmp_bytes<A: GGSWInfos>(&self, infos: &A) -> usize {
        self.ggsw_encrypt_sk_tmp_bytes(infos)
    }

    fn blind_rotation_key_encrypt_sk<D, S0, S1>(
        &self,
        res: &mut BlindRotationKey<D, CGGI>,
        sk_glwe: &S0,
        sk_lwe: &S1,
        source_xa: &mut Source,
        source_xe: &mut Source,
        scratch: &mut Scratch<BE>,
    ) where
        D: DataMut,
        S0: GLWESecretPreparedToRef<BE> + GLWEInfos,
        S1: LWESecretToRef + LWEInfos + GetDistribution,
    {
        assert_eq!(res.keys.len() as u32, sk_lwe.n());
        assert!(sk_glwe.n() <= self.n() as u32);
        assert_eq!(sk_glwe.rank(), res.rank());

        match sk_lwe.dist() {
            Distribution::BinaryBlock(_) | Distribution::BinaryFixed(_) | Distribution::BinaryProb(_) | Distribution::ZERO => {}
            _ => {
                panic!("invalid GLWESecret distribution: must be BinaryBlock, BinaryFixed or BinaryProb (or ZERO for debugging)")
            }
        }

        {
            let sk_lwe: &LWESecret<&[u8]> = &sk_lwe.to_ref();
            res.dist = sk_lwe.dist();
            let mut pt: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(sk_glwe.n().into(), 1);
            let sk_ref: ScalarZnx<&[u8]> = sk_lwe.data().to_ref();
            for (i, ggsw) in res.keys.iter_mut().enumerate() {
                pt.at_mut(0, 0)[0] = sk_ref.at(0, 0)[i];
                ggsw.encrypt_sk(self, &pt, sk_glwe, source_xa, source_xe, scratch);
            }
        }
    }
}

View File

@@ -0,0 +1,84 @@
use std::marker::PhantomData;

use poulpy_core::{
    Distribution, GGSWCompressedEncryptSk, GetDistribution, ScratchTakeCore,
    layouts::{GGSWCompressed, GGSWInfos, GLWEInfos, GLWESecretPreparedToRef, LWEInfos, LWESecret, LWESecretToRef},
};
use poulpy_hal::{
    layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxView, ZnxViewMut},
    source::Source,
};

use crate::tfhe::blind_rotation::{
    BlindRotationKeyCompressed, BlindRotationKeyCompressedEncryptSk, BlindRotationKeyCompressedFactory, BlindRotationKeyInfos,
    CGGI,
};

impl<D: DataRef> BlindRotationKeyCompressedFactory<CGGI> for BlindRotationKeyCompressed<D, CGGI> {
    fn blind_rotation_key_compressed_alloc<A>(infos: &A) -> BlindRotationKeyCompressed<Vec<u8>, CGGI>
    where
        A: BlindRotationKeyInfos,
    {
        let mut data: Vec<GGSWCompressed<Vec<u8>>> = Vec::with_capacity(infos.n_lwe().into());
        (0..infos.n_lwe().as_usize()).for_each(|_| data.push(GGSWCompressed::alloc_from_infos(infos)));
        BlindRotationKeyCompressed {
            keys: data,
            dist: Distribution::NONE,
            _phantom: PhantomData,
        }
    }
}

impl<BE: Backend> BlindRotationKeyCompressedEncryptSk<BE, CGGI> for Module<BE>
where
    Self: GGSWCompressedEncryptSk<BE>,
    Scratch<BE>: ScratchTakeCore<BE>,
{
    fn blind_rotation_key_compressed_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
    where
        A: GGSWInfos,
    {
        self.ggsw_compressed_encrypt_sk_tmp_bytes(infos)
    }

    fn blind_rotation_key_compressed_encrypt_sk<D, S0, S1>(
        &self,
        res: &mut BlindRotationKeyCompressed<D, CGGI>,
        sk_glwe: &S0,
        sk_lwe: &S1,
        seed_xa: [u8; 32],
        source_xe: &mut Source,
        scratch: &mut Scratch<BE>,
    ) where
        D: DataMut,
        S0: GLWESecretPreparedToRef<BE> + GLWEInfos,
        S1: LWESecretToRef + LWEInfos + GetDistribution,
    {
        assert_eq!(res.keys.len() as u32, sk_lwe.n());
        assert!(sk_glwe.n() <= self.n() as u32);
        assert_eq!(sk_glwe.rank(), res.rank());

        match sk_lwe.dist() {
            Distribution::BinaryBlock(_) | Distribution::BinaryFixed(_) | Distribution::BinaryProb(_) | Distribution::ZERO => {}
            _ => {
                panic!("invalid GLWESecret distribution: must be BinaryBlock, BinaryFixed or BinaryProb (or ZERO for debugging)")
            }
        }

        {
            let sk_lwe: &LWESecret<&[u8]> = &sk_lwe.to_ref();
            let mut source_xa: Source = Source::new(seed_xa);
            res.dist = sk_lwe.dist();
            let mut pt: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(sk_glwe.n().into(), 1);
            let sk_ref: ScalarZnx<&[u8]> = sk_lwe.data().to_ref();
            for (i, ggsw) in res.keys.iter_mut().enumerate() {
                pt.at_mut(0, 0)[0] = sk_ref.at(0, 0)[i];
                ggsw.encrypt_sk(self, &pt, sk_glwe, source_xa.new_seed(), source_xe, scratch);
            }
        }
    }
}

View File

@@ -0,0 +1,69 @@
use poulpy_hal::{
    api::{SvpPPolAlloc, SvpPrepare},
    layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, SvpPPol},
};
use std::marker::PhantomData;

use poulpy_core::{
    Distribution,
    layouts::{GGSWPreparedFactory, LWEInfos, prepared::GGSWPrepared},
};

use crate::tfhe::blind_rotation::{
    BlindRotationKey, BlindRotationKeyInfos, BlindRotationKeyPrepared, BlindRotationKeyPreparedFactory, CGGI,
    utils::set_xai_plus_y,
};

impl<BE: Backend> BlindRotationKeyPreparedFactory<BE, CGGI> for Module<BE>
where
    Self: GGSWPreparedFactory<BE> + SvpPPolAlloc<BE> + SvpPrepare<BE>,
{
    fn blind_rotation_key_prepared_alloc<A>(&self, infos: &A) -> BlindRotationKeyPrepared<Vec<u8>, CGGI, BE>
    where
        A: BlindRotationKeyInfos,
    {
        BlindRotationKeyPrepared {
            data: (0..infos.n_lwe().as_usize())
                .map(|_| GGSWPrepared::alloc_from_infos(self, infos))
                .collect(),
            dist: Distribution::NONE,
            x_pow_a: None,
            _phantom: PhantomData,
        }
    }

    fn blind_rotation_key_prepare<DM, DR>(
        &self,
        res: &mut BlindRotationKeyPrepared<DM, CGGI, BE>,
        other: &BlindRotationKey<DR, CGGI>,
        scratch: &mut Scratch<BE>,
    ) where
        DM: DataMut,
        DR: DataRef,
    {
        #[cfg(debug_assertions)]
        {
            assert_eq!(res.data.len(), other.keys.len());
        }

        let n: usize = other.n().as_usize();

        for (a, b) in res.data.iter_mut().zip(other.keys.iter()) {
            a.prepare(self, b, scratch);
        }

        res.dist = other.dist;

        if let Distribution::BinaryBlock(_) = other.dist {
            let mut x_pow_a: Vec<SvpPPol<Vec<u8>, BE>> = Vec::with_capacity(n << 1);
            let mut buf: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, 1);
            (0..n << 1).for_each(|i| {
                let mut res: SvpPPol<Vec<u8>, BE> = self.svp_ppol_alloc(1);
                set_xai_plus_y(self, i, 0, &mut res, &mut buf);
                x_pow_a.push(res);
            });
            res.x_pow_a = Some(x_pow_a);
        }
    }
}

View File

@@ -0,0 +1,10 @@
mod algorithm;
mod key;
mod key_compressed;
mod key_prepared;
use crate::tfhe::blind_rotation::BlindRotationAlgo;
#[derive(Clone)]
pub struct CGGI {}
impl BlindRotationAlgo for CGGI {}

View File

@@ -0,0 +1,116 @@
mod cggi;

pub use cggi::*;
use itertools::izip;
use poulpy_core::{
    ScratchTakeCore,
    layouts::{GGSWInfos, GLWE, GLWEInfos, LWE, LWEInfos},
};
use poulpy_hal::layouts::{Backend, DataMut, DataRef, Scratch, ZnxView};

use crate::tfhe::blind_rotation::{BlindRotationKeyInfos, BlindRotationKeyPrepared, LookUpTableRotationDirection, LookupTable};

pub trait BlindRotationAlgo {}

pub trait BlindRotationExecute<BRA: BlindRotationAlgo, BE: Backend> {
    fn blind_rotation_execute_tmp_bytes<G, B>(
        &self,
        block_size: usize,
        extension_factor: usize,
        glwe_infos: &G,
        brk_infos: &B,
    ) -> usize
    where
        G: GLWEInfos,
        B: GGSWInfos;

    fn blind_rotation_execute<DR, DL, DB>(
        &self,
        res: &mut GLWE<DR>,
        lwe: &LWE<DL>,
        lut: &LookupTable,
        brk: &BlindRotationKeyPrepared<DB, BRA, BE>,
        scratch: &mut Scratch<BE>,
    ) where
        DR: DataMut,
        DL: DataRef,
        DB: DataRef;
}

impl<D: DataRef, BRA: BlindRotationAlgo, BE: Backend> BlindRotationKeyPrepared<D, BRA, BE>
where
    Scratch<BE>: ScratchTakeCore<BE>,
{
    pub fn execute<DR: DataMut, DI: DataRef, M>(
        &self,
        module: &M,
        res: &mut GLWE<DR>,
        lwe: &LWE<DI>,
        lut: &LookupTable,
        scratch: &mut Scratch<BE>,
    ) where
        M: BlindRotationExecute<BRA, BE>,
    {
        module.blind_rotation_execute(res, lwe, lut, self, scratch);
    }
}

impl<BE: Backend, BRA: BlindRotationAlgo> BlindRotationKeyPrepared<Vec<u8>, BRA, BE> {
    pub fn execute_tmp_bytes<A, B, M>(
        module: &M,
        block_size: usize,
        extension_factor: usize,
        glwe_infos: &A,
        brk_infos: &B,
    ) -> usize
    where
        A: GLWEInfos,
        B: BlindRotationKeyInfos,
        M: BlindRotationExecute<BRA, BE>,
    {
        module.blind_rotation_execute_tmp_bytes(block_size, extension_factor, glwe_infos, brk_infos)
    }
}

pub fn mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWE<&[u8]>, rot_dir: LookUpTableRotationDirection) {
    let base2k: usize = lwe.base2k().into();
    let log2n: usize = usize::BITS as usize - (n - 1).leading_zeros() as usize + 1;
    res.copy_from_slice(lwe.data().at(0, 0));

    match rot_dir {
        LookUpTableRotationDirection::Left => {
            res.iter_mut().for_each(|x| *x = -*x);
        }
        LookUpTableRotationDirection::Right => {}
    }

    if base2k > log2n {
        let diff: usize = base2k - (log2n - 1); // additional -1 because we map to [-N/2, N/2) instead of [0, N)
        res.iter_mut().for_each(|x| {
            *x = div_round_by_pow2(x, diff);
        })
    } else {
        let rem: usize = base2k - (log2n % base2k);
        let size: usize = log2n.div_ceil(base2k);
        (1..size).for_each(|i| {
            if i == size - 1 && rem != base2k {
                let k_rem: usize = base2k - rem;
                izip!(lwe.data().at(0, i).iter(), res.iter_mut()).for_each(|(x, y)| {
                    *y = (*y << k_rem) + (x >> rem);
                });
            } else {
                izip!(lwe.data().at(0, i).iter(), res.iter_mut()).for_each(|(x, y)| {
                    *y = (*y << base2k) + x;
                });
            }
        })
    }
}

#[inline(always)]
fn div_round_by_pow2(x: &i64, k: usize) -> i64 {
    (x + (1 << (k - 1))) >> k
}

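mod_switch_2n rescales the base-2^base2k limb decomposition of each LWE coefficient to Z_{2N}; when base2k exceeds log2n the top limb is divided with rounding by div_round_by_pow2, which adds half the divisor before the arithmetic shift. A quick worked check of that rounding behaviour (values chosen for illustration):

// Rounding division by 2^k as in div_round_by_pow2: add half the
// divisor, then arithmetic-shift right. Worked values for k = 4 (/16).
fn div_round_by_pow2(x: i64, k: usize) -> i64 {
    (x + (1 << (k - 1))) >> k
}

fn main() {
    assert_eq!(div_round_by_pow2(40, 4), 3);   // 40/16 = 2.5 -> rounds to 3
    assert_eq!(div_round_by_pow2(39, 4), 2);   // 39/16 = 2.44 -> 2
    assert_eq!(div_round_by_pow2(-40, 4), -2); // -2.5 rounds toward +inf: -2
    // Plain >> k truncates toward -inf instead, losing up to half a unit:
    assert_eq!(40 >> 4, 2);
}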
View File

@@ -1,223 +0,0 @@
use poulpy_hal::{
    api::{
        ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal,
        VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume,
        VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VmpPMatAlloc, VmpPrepare,
    },
    layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxView, ZnxViewMut},
    source::Source,
};
use std::marker::PhantomData;

use poulpy_core::{
    Distribution,
    layouts::{
        GGSW, GGSWInfos, LWESecret,
        compressed::GGSWCompressed,
        prepared::{GGSWPrepared, GLWESecretPrepared},
    },
};

use crate::tfhe::blind_rotation::{
    BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyCompressed, BlindRotationKeyEncryptSk, BlindRotationKeyInfos,
    BlindRotationKeyPrepared, BlindRotationKeyPreparedAlloc, CGGI,
};

impl BlindRotationKeyAlloc for BlindRotationKey<Vec<u8>, CGGI> {
    fn alloc<A>(infos: &A) -> Self
    where
        A: BlindRotationKeyInfos,
    {
        let mut data: Vec<GGSW<Vec<u8>>> = Vec::with_capacity(infos.n_lwe().into());
        for _ in 0..infos.n_lwe().as_usize() {
            data.push(GGSW::alloc_from_infos(infos));
        }
        Self {
            keys: data,
            dist: Distribution::NONE,
            _phantom: PhantomData,
        }
    }
}

impl BlindRotationKey<Vec<u8>, CGGI> {
    pub fn generate_from_sk_tmp_bytes<B: Backend, A>(module: &Module<B>, infos: &A) -> usize
    where
        A: GGSWInfos,
        Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftBytesOf,
    {
        GGSW::encrypt_sk_tmp_bytes(module, infos)
    }
}

impl<D: DataMut, B: Backend> BlindRotationKeyEncryptSk<B> for BlindRotationKey<D, CGGI>
where
    Module<B>: VecZnxAddScalarInplace
        + VecZnxDftBytesOf
        + VecZnxBigNormalize<B>
        + VecZnxDftApply<B>
        + SvpApplyDftToDftInplace<B>
        + VecZnxIdftApplyConsume<B>
        + VecZnxNormalizeTmpBytes
        + VecZnxFillUniform
        + VecZnxSubInplace
        + VecZnxAddInplace
        + VecZnxNormalizeInplace<B>
        + VecZnxAddNormal
        + VecZnxNormalize<B>
        + VecZnxSub,
    Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
{
    fn encrypt_sk<DataSkGLWE, DataSkLWE>(
        &mut self,
        module: &Module<B>,
        sk_glwe: &GLWESecretPrepared<DataSkGLWE, B>,
        sk_lwe: &LWESecret<DataSkLWE>,
        source_xa: &mut Source,
        source_xe: &mut Source,
        scratch: &mut Scratch<B>,
    ) where
        DataSkGLWE: DataRef,
        DataSkLWE: DataRef,
    {
        #[cfg(debug_assertions)]
        {
            use poulpy_core::layouts::{GLWEInfos, LWEInfos};
            assert_eq!(self.keys.len() as u32, sk_lwe.n());
            assert!(sk_glwe.n() <= module.n() as u32);
            assert_eq!(sk_glwe.rank(), self.rank());
            match sk_lwe.dist() {
                Distribution::BinaryBlock(_)
                | Distribution::BinaryFixed(_)
                | Distribution::BinaryProb(_)
                | Distribution::ZERO => {}
                _ => panic!(
                    "invalid GLWESecret distribution: must be BinaryBlock, BinaryFixed or BinaryProb (or ZERO for debugging)"
                ),
            }
        }

        self.dist = sk_lwe.dist();

        let mut pt: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(sk_glwe.n().into(), 1);
        let sk_ref: ScalarZnx<&[u8]> = sk_lwe.data().to_ref();

        self.keys.iter_mut().enumerate().for_each(|(i, ggsw)| {
            pt.at_mut(0, 0)[0] = sk_ref.at(0, 0)[i];
            ggsw.encrypt_sk(module, &pt, sk_glwe, source_xa, source_xe, scratch);
        });
    }
}

impl<B: Backend> BlindRotationKeyPreparedAlloc<B> for BlindRotationKeyPrepared<Vec<u8>, CGGI, B>
where
    Module<B>: VmpPMatAlloc<B> + VmpPrepare<B>,
{
    fn alloc<A>(module: &Module<B>, infos: &A) -> Self
    where
        A: BlindRotationKeyInfos,
    {
        let mut data: Vec<GGSWPrepared<Vec<u8>, B>> = Vec::with_capacity(infos.n_lwe().into());
        (0..infos.n_lwe().as_usize()).for_each(|_| data.push(GGSWPrepared::alloc_from_infos(module, infos)));
        Self {
            data,
            dist: Distribution::NONE,
            x_pow_a: None,
            _phantom: PhantomData,
        }
    }
}

impl BlindRotationKeyCompressed<Vec<u8>, CGGI> {
    pub fn alloc<A>(infos: &A) -> Self
    where
        A: BlindRotationKeyInfos,
    {
        let mut data: Vec<GGSWCompressed<Vec<u8>>> = Vec::with_capacity(infos.n_lwe().into());
        (0..infos.n_lwe().as_usize()).for_each(|_| data.push(GGSWCompressed::alloc_from_infos(infos)));
        Self {
            keys: data,
            dist: Distribution::NONE,
            _phantom: PhantomData,
        }
    }

    pub fn generate_from_sk_tmp_bytes<B: Backend, A>(module: &Module<B>, infos: &A) -> usize
    where
        A: GGSWInfos,
        Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftBytesOf,
    {
        GGSWCompressed::encrypt_sk_tmp_bytes(module, infos)
    }
}

impl<D: DataMut> BlindRotationKeyCompressed<D, CGGI> {
    #[allow(clippy::too_many_arguments)]
    pub fn encrypt_sk<DataSkGLWE, DataSkLWE, B: Backend>(
        &mut self,
        module: &Module<B>,
        sk_glwe: &GLWESecretPrepared<DataSkGLWE, B>,
        sk_lwe: &LWESecret<DataSkLWE>,
        seed_xa: [u8; 32],
        source_xe: &mut Source,
        scratch: &mut Scratch<B>,
    ) where
        DataSkGLWE: DataRef,
        DataSkLWE: DataRef,
        Module<B>: VecZnxAddScalarInplace
            + VecZnxDftBytesOf
            + VecZnxBigNormalize<B>
            + VecZnxDftApply<B>
            + SvpApplyDftToDftInplace<B>
            + VecZnxIdftApplyConsume<B>
            + VecZnxNormalizeTmpBytes
            + VecZnxFillUniform
            + VecZnxSubInplace
            + VecZnxAddInplace
            + VecZnxNormalizeInplace<B>
            + VecZnxAddNormal
            + VecZnxNormalize<B>
            + VecZnxSub,
        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
    {
        #[cfg(debug_assertions)]
        {
            use poulpy_core::layouts::{GLWEInfos, LWEInfos};
            assert_eq!(self.n_lwe(), sk_lwe.n());
            assert!(sk_glwe.n() <= module.n() as u32);
            assert_eq!(sk_glwe.rank(), self.rank());
            match sk_lwe.dist() {
                Distribution::BinaryBlock(_)
                | Distribution::BinaryFixed(_)
                | Distribution::BinaryProb(_)
                | Distribution::ZERO => {}
                _ => panic!(
                    "invalid GLWESecret distribution: must be BinaryBlock, BinaryFixed or BinaryProb (or ZERO for debugging)"
                ),
            }
        }

        self.dist = sk_lwe.dist();

        let mut pt: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(sk_glwe.n().into(), 1);
        let sk_ref: ScalarZnx<&[u8]> = sk_lwe.data().to_ref();
        let mut source_xa: Source = Source::new(seed_xa);

        self.keys.iter_mut().enumerate().for_each(|(i, ggsw)| {
            pt.at_mut(0, 0)[0] = sk_ref.at(0, 0)[i];
            ggsw.encrypt_sk(
                module,
                &pt,
                sk_glwe,
                source_xa.new_seed(),
                source_xe,
                scratch,
            );
        });
    }
}

View File

@@ -0,0 +1,60 @@
use poulpy_hal::{
    layouts::{Backend, DataMut, Scratch},
    source::Source,
};

use poulpy_core::{
    GetDistribution, ScratchTakeCore,
    layouts::{GGSWInfos, GLWEInfos, GLWESecretPreparedToRef, LWEInfos, LWESecretToRef},
};

use crate::tfhe::blind_rotation::{BlindRotationAlgo, BlindRotationKey};

pub trait BlindRotationKeyEncryptSk<B: Backend, BRA: BlindRotationAlgo> {
    fn blind_rotation_key_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
    where
        A: GGSWInfos;

    #[allow(clippy::too_many_arguments)]
    fn blind_rotation_key_encrypt_sk<D, S0, S1>(
        &self,
        res: &mut BlindRotationKey<D, BRA>,
        sk_glwe: &S0,
        sk_lwe: &S1,
        source_xa: &mut Source,
        source_xe: &mut Source,
        scratch: &mut Scratch<B>,
    ) where
        D: DataMut,
        S0: GLWESecretPreparedToRef<B> + GLWEInfos,
        S1: LWESecretToRef + LWEInfos + GetDistribution;
}

impl<D: DataMut, BRA: BlindRotationAlgo> BlindRotationKey<D, BRA> {
    pub fn encrypt_sk<M, S0, S1, BE: Backend>(
        &mut self,
        module: &M,
        sk_glwe: &S0,
        sk_lwe: &S1,
        source_xa: &mut Source,
        source_xe: &mut Source,
        scratch: &mut Scratch<BE>,
    ) where
        S0: GLWESecretPreparedToRef<BE> + GLWEInfos,
        S1: LWESecretToRef + LWEInfos + GetDistribution,
        Scratch<BE>: ScratchTakeCore<BE>,
        M: BlindRotationKeyEncryptSk<BE, BRA>,
    {
        module.blind_rotation_key_encrypt_sk(self, sk_glwe, sk_lwe, source_xa, source_xe, scratch);
    }
}

impl<BRA: BlindRotationAlgo> BlindRotationKey<Vec<u8>, BRA> {
    pub fn encrypt_sk_tmp_bytes<A, M, BE: Backend>(module: &M, infos: &A) -> usize
    where
        A: GGSWInfos,
        M: BlindRotationKeyEncryptSk<BE, BRA>,
    {
        module.blind_rotation_key_encrypt_sk_tmp_bytes(infos)
    }
}

View File

@@ -0,0 +1,30 @@
use poulpy_core::{
    GetDistribution,
    layouts::{GGSWInfos, GLWEInfos, GLWESecretPreparedToRef, LWEInfos, LWESecretToRef},
};
use poulpy_hal::{
    layouts::{Backend, DataMut, Scratch},
    source::Source,
};

use crate::tfhe::blind_rotation::{BlindRotationAlgo, BlindRotationKeyCompressed};

pub trait BlindRotationKeyCompressedEncryptSk<B: Backend, BRA: BlindRotationAlgo> {
    fn blind_rotation_key_compressed_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
    where
        A: GGSWInfos;

    #[allow(clippy::too_many_arguments)]
    fn blind_rotation_key_compressed_encrypt_sk<D, S0, S1>(
        &self,
        res: &mut BlindRotationKeyCompressed<D, BRA>,
        sk_glwe: &S0,
        sk_lwe: &S1,
        seed_xa: [u8; 32],
        source_xe: &mut Source,
        scratch: &mut Scratch<B>,
    ) where
        D: DataMut,
        S0: GLWESecretPreparedToRef<B> + GLWEInfos,
        S1: LWESecretToRef + LWEInfos + GetDistribution;
}

View File

@@ -0,0 +1,5 @@
mod key;
mod key_compressed;
pub use key::*;
pub use key_compressed::*;

View File

@@ -1,130 +0,0 @@
use poulpy_hal::{
api::{SvpPPolAlloc, SvpPrepare, VmpPMatAlloc, VmpPrepare},
layouts::{Backend, Data, DataMut, DataRef, Module, ScalarZnx, Scratch, SvpPPol},
};
use std::marker::PhantomData;
use poulpy_core::{
Distribution,
layouts::{
Base2K, Degree, Dnum, Dsize, GGSWInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision,
prepared::{GGSWPrepared, Prepare, PrepareAlloc},
},
};
use crate::tfhe::blind_rotation::{BlindRotationAlgo, BlindRotationKey, BlindRotationKeyInfos, utils::set_xai_plus_y};
pub trait BlindRotationKeyPreparedAlloc<B: Backend> {
fn alloc<A>(module: &Module<B>, infos: &A) -> Self
where
A: BlindRotationKeyInfos;
}
#[derive(PartialEq, Eq)]
pub struct BlindRotationKeyPrepared<D: Data, BRT: BlindRotationAlgo, B: Backend> {
pub(crate) data: Vec<GGSWPrepared<D, B>>,
pub(crate) dist: Distribution,
pub(crate) x_pow_a: Option<Vec<SvpPPol<Vec<u8>, B>>>,
pub(crate) _phantom: PhantomData<BRT>,
}
impl<D: Data, BRT: BlindRotationAlgo, B: Backend> BlindRotationKeyInfos for BlindRotationKeyPrepared<D, BRT, B> {
fn n_glwe(&self) -> Degree {
self.n()
}
fn n_lwe(&self) -> Degree {
Degree(self.data.len() as u32)
}
}
impl<D: Data, BRT: BlindRotationAlgo, B: Backend> LWEInfos for BlindRotationKeyPrepared<D, BRT, B> {
fn base2k(&self) -> Base2K {
self.data[0].base2k()
}
fn k(&self) -> TorusPrecision {
self.data[0].k()
}
fn n(&self) -> Degree {
self.data[0].n()
}
fn size(&self) -> usize {
self.data[0].size()
}
}
impl<D: Data, BRT: BlindRotationAlgo, B: Backend> GLWEInfos for BlindRotationKeyPrepared<D, BRT, B> {
fn rank(&self) -> Rank {
self.data[0].rank()
}
}
impl<D: Data, BRT: BlindRotationAlgo, B: Backend> GGSWInfos for BlindRotationKeyPrepared<D, BRT, B> {
fn dsize(&self) -> poulpy_core::layouts::Dsize {
Dsize(1)
}
fn dnum(&self) -> Dnum {
self.data[0].dnum()
}
}
impl<D: Data, BRT: BlindRotationAlgo, B: Backend> BlindRotationKeyPrepared<D, BRT, B> {
pub fn block_size(&self) -> usize {
match self.dist {
Distribution::BinaryBlock(value) => value,
_ => 1,
}
}
}
impl<D: DataRef, BRA: BlindRotationAlgo, B: Backend> PrepareAlloc<B, BlindRotationKeyPrepared<Vec<u8>, BRA, B>>
for BlindRotationKey<D, BRA>
where
BlindRotationKeyPrepared<Vec<u8>, BRA, B>: BlindRotationKeyPreparedAlloc<B>,
BlindRotationKeyPrepared<Vec<u8>, BRA, B>: Prepare<B, BlindRotationKey<D, BRA>>,
{
fn prepare_alloc(&self, module: &Module<B>, scratch: &mut Scratch<B>) -> BlindRotationKeyPrepared<Vec<u8>, BRA, B> {
let mut brk: BlindRotationKeyPrepared<Vec<u8>, BRA, B> = BlindRotationKeyPrepared::alloc(module, self);
brk.prepare(module, self, scratch);
brk
}
}
impl<DM: DataMut, DR: DataRef, BRA: BlindRotationAlgo, B: Backend> Prepare<B, BlindRotationKey<DR, BRA>>
for BlindRotationKeyPrepared<DM, BRA, B>
where
Module<B>: VmpPMatAlloc<B> + VmpPrepare<B> + SvpPPolAlloc<B> + SvpPrepare<B>,
{
fn prepare(&mut self, module: &Module<B>, other: &BlindRotationKey<DR, BRA>, scratch: &mut Scratch<B>) {
#[cfg(debug_assertions)]
{
assert_eq!(self.data.len(), other.keys.len());
}
let n: usize = other.n().as_usize();
self.data
.iter_mut()
.zip(other.keys.iter())
.for_each(|(ggsw_prepared, other)| {
ggsw_prepared.prepare(module, other, scratch);
});
self.dist = other.dist;
if let Distribution::BinaryBlock(_) = other.dist {
let mut x_pow_a: Vec<SvpPPol<Vec<u8>, B>> = Vec::with_capacity(n << 1);
let mut buf: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, 1);
(0..n << 1).for_each(|i| {
let mut res: SvpPPol<Vec<u8>, B> = module.svp_ppol_alloc(1);
set_xai_plus_y(module, i, 0, &mut res, &mut buf);
x_pow_a.push(res);
});
self.x_pow_a = Some(x_pow_a);
}
}
}

View File

@@ -1,5 +1,5 @@
 use poulpy_hal::{
-    layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, Scratch, WriterTo},
+    layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo},
     source::Source,
 };
@@ -7,10 +7,7 @@ use std::{fmt, marker::PhantomData};
 use poulpy_core::{
     Distribution,
-    layouts::{
-        Base2K, Degree, Dnum, Dsize, GGSW, GGSWInfos, GLWEInfos, LWEInfos, LWESecret, Rank, TorusPrecision,
-        prepared::GLWESecretPrepared,
-    },
+    layouts::{Base2K, Degree, Dnum, Dsize, GGSW, GGSWInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision},
 };
 use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
@@ -81,21 +78,6 @@ pub trait BlindRotationKeyAlloc {
         A: BlindRotationKeyInfos;
 }

-pub trait BlindRotationKeyEncryptSk<B: Backend> {
-    #[allow(clippy::too_many_arguments)]
-    fn encrypt_sk<DataSkGLWE, DataSkLWE>(
-        &mut self,
-        module: &Module<B>,
-        sk_glwe: &GLWESecretPrepared<DataSkGLWE, B>,
-        sk_lwe: &LWESecret<DataSkLWE>,
-        source_xa: &mut Source,
-        source_xe: &mut Source,
-        scratch: &mut Scratch<B>,
-    ) where
-        DataSkGLWE: DataRef,
-        DataSkLWE: DataRef;
-}
-
 #[derive(Clone)]
 pub struct BlindRotationKey<D: Data, BRT: BlindRotationAlgo> {
     pub(crate) keys: Vec<GGSW<D>>,
@@ -103,6 +85,24 @@ pub struct BlindRotationKey<D: Data, BRT: BlindRotationAlgo> {
     pub(crate) _phantom: PhantomData<BRT>,
 }

+pub trait BlindRotationKeyFactory<BRA: BlindRotationAlgo> {
+    fn blind_rotation_key_alloc<A>(infos: &A) -> BlindRotationKey<Vec<u8>, BRA>
+    where
+        A: BlindRotationKeyInfos;
+}
+
+impl<BRA: BlindRotationAlgo> BlindRotationKey<Vec<u8>, BRA>
+where
+    Self: BlindRotationKeyFactory<BRA>,
+{
+    pub fn alloc<A>(infos: &A) -> BlindRotationKey<Vec<u8>, BRA>
+    where
+        A: BlindRotationKeyInfos,
+    {
+        Self::blind_rotation_key_alloc(infos)
+    }
+}
+
 impl<D: DataRef, BRT: BlindRotationAlgo> fmt::Debug for BlindRotationKey<D, BRT> {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         write!(f, "{self}")

View File

@@ -20,6 +20,24 @@ pub struct BlindRotationKeyCompressed<D: Data, BRT: BlindRotationAlgo> {
     pub(crate) _phantom: PhantomData<BRT>,
 }
+pub trait BlindRotationKeyCompressedFactory<BRA: BlindRotationAlgo> {
+    fn blind_rotation_key_compressed_alloc<A>(infos: &A) -> BlindRotationKeyCompressed<Vec<u8>, BRA>
+    where
+        A: BlindRotationKeyInfos;
+}
+impl<BRA: BlindRotationAlgo> BlindRotationKeyCompressed<Vec<u8>, BRA>
+where
+    Self: BlindRotationKeyCompressedFactory<BRA>,
+{
+    pub fn alloc<A>(infos: &A) -> BlindRotationKeyCompressed<Vec<u8>, BRA>
+    where
+        A: BlindRotationKeyInfos,
+    {
+        Self::blind_rotation_key_compressed_alloc(infos)
+    }
+}
 impl<D: DataRef, BRT: BlindRotationAlgo> fmt::Debug for BlindRotationKeyCompressed<D, BRT> {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         write!(f, "{self}")

View File

@@ -0,0 +1,108 @@
use poulpy_hal::layouts::{Backend, Data, DataMut, DataRef, Scratch, SvpPPol};
use std::marker::PhantomData;
use poulpy_core::{
Distribution, ScratchTakeCore,
layouts::{Base2K, Degree, Dnum, Dsize, GGSWInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision, prepared::GGSWPrepared},
};
use crate::tfhe::blind_rotation::{BlindRotationAlgo, BlindRotationKey, BlindRotationKeyInfos};
pub trait BlindRotationKeyPreparedFactory<BE: Backend, BRA: BlindRotationAlgo> {
fn blind_rotation_key_prepared_alloc<A>(&self, infos: &A) -> BlindRotationKeyPrepared<Vec<u8>, BRA, BE>
where
A: BlindRotationKeyInfos;
fn blind_rotation_key_prepare<DM, DR>(
&self,
res: &mut BlindRotationKeyPrepared<DM, BRA, BE>,
other: &BlindRotationKey<DR, BRA>,
scratch: &mut Scratch<BE>,
) where
DM: DataMut,
DR: DataRef,
Scratch<BE>: ScratchTakeCore<BE>;
}
impl<BE: Backend, BRA: BlindRotationAlgo> BlindRotationKeyPrepared<Vec<u8>, BRA, BE> {
pub fn alloc<A, M>(module: &M, infos: &A) -> Self
where
A: BlindRotationKeyInfos,
M: BlindRotationKeyPreparedFactory<BE, BRA>,
{
module.blind_rotation_key_prepared_alloc(infos)
}
}
impl<D: DataMut, BRA: BlindRotationAlgo, BE: Backend> BlindRotationKeyPrepared<D, BRA, BE>
where
Scratch<BE>: ScratchTakeCore<BE>,
{
pub fn prepare<DR: DataRef, M>(&mut self, module: &M, other: &BlindRotationKey<DR, BRA>, scratch: &mut Scratch<BE>)
where
M: BlindRotationKeyPreparedFactory<BE, BRA>,
{
module.blind_rotation_key_prepare(self, other, scratch);
}
}
#[derive(PartialEq, Eq)]
pub struct BlindRotationKeyPrepared<D: Data, BRT: BlindRotationAlgo, B: Backend> {
pub(crate) data: Vec<GGSWPrepared<D, B>>,
pub(crate) dist: Distribution,
pub(crate) x_pow_a: Option<Vec<SvpPPol<Vec<u8>, B>>>,
pub(crate) _phantom: PhantomData<BRT>,
}
impl<D: Data, BRT: BlindRotationAlgo, B: Backend> BlindRotationKeyInfos for BlindRotationKeyPrepared<D, BRT, B> {
fn n_glwe(&self) -> Degree {
self.n()
}
fn n_lwe(&self) -> Degree {
Degree(self.data.len() as u32)
}
}
impl<D: Data, BRT: BlindRotationAlgo, B: Backend> LWEInfos for BlindRotationKeyPrepared<D, BRT, B> {
fn base2k(&self) -> Base2K {
self.data[0].base2k()
}
fn k(&self) -> TorusPrecision {
self.data[0].k()
}
fn n(&self) -> Degree {
self.data[0].n()
}
fn size(&self) -> usize {
self.data[0].size()
}
}
impl<D: Data, BRT: BlindRotationAlgo, B: Backend> GLWEInfos for BlindRotationKeyPrepared<D, BRT, B> {
fn rank(&self) -> Rank {
self.data[0].rank()
}
}
impl<D: Data, BRT: BlindRotationAlgo, B: Backend> GGSWInfos for BlindRotationKeyPrepared<D, BRT, B> {
fn dsize(&self) -> poulpy_core::layouts::Dsize {
Dsize(1)
}
fn dnum(&self) -> Dnum {
self.data[0].dnum()
}
}
impl<D: Data, BRT: BlindRotationAlgo, B: Backend> BlindRotationKeyPrepared<D, BRT, B> {
pub fn block_size(&self) -> usize {
match self.dist {
Distribution::BinaryBlock(value) => value,
_ => 1,
}
}
}
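
block_size above encodes the convention used throughout the blind rotation: a Distribution::BinaryBlock(b) LWE secret groups its coefficients into blocks of b, and every other distribution degenerates to blocks of size 1. A self-contained model of that convention, using the parameters exercised by the tests below (n_lwe = 224, b = 7); names are illustrative:

// Stand-in for poulpy's Distribution, reduced to what block_size needs.
enum Dist {
    BinaryBlock(usize),
    Ternary,
}

fn block_size(dist: &Dist) -> usize {
    match dist {
        Dist::BinaryBlock(b) => *b,
        _ => 1,
    }
}

fn main() {
    assert_eq!(block_size(&Dist::BinaryBlock(7)), 7);
    assert_eq!(block_size(&Dist::Ternary), 1);
    // With n_lwe = 224 and b = 7, the rotation proceeds in 224 / 7 = 32 blocks.
    assert_eq!(224 / block_size(&Dist::BinaryBlock(7)), 32);
}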

View File

@@ -0,0 +1,6 @@
mod key;
mod key_compressed;
mod key_prepared;
pub use key::*;
pub use key_compressed::*;
pub use key_prepared::*;

View File

@@ -1,3 +1,4 @@
+use poulpy_core::layouts::{Base2K, Degree, TorusPrecision};
 use poulpy_hal::{
     api::{
         ScratchOwnedAlloc, ScratchOwnedBorrow, TakeSlice, VecZnxCopy, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes,
@@ -13,32 +14,97 @@ pub enum LookUpTableRotationDirection {
     Right,
 }
-pub struct LookUpTable {
+pub struct LookUpTableLayout {
+    pub n: Degree,
+    pub extension_factor: usize,
+    pub k: TorusPrecision,
+    pub base2k: Base2K,
+}
+pub trait LookupTableInfos {
+    fn n(&self) -> Degree;
+    fn extension_factor(&self) -> usize;
+    fn k(&self) -> TorusPrecision;
+    fn base2k(&self) -> Base2K;
+    fn size(&self) -> usize;
+}
+impl LookupTableInfos for LookUpTableLayout {
+    fn base2k(&self) -> Base2K {
+        self.base2k
+    }
+    fn extension_factor(&self) -> usize {
+        self.extension_factor
+    }
+    fn k(&self) -> TorusPrecision {
+        self.k
+    }
+    fn size(&self) -> usize {
+        self.k().as_usize().div_ceil(self.base2k().as_usize())
+    }
+    fn n(&self) -> Degree {
+        self.n
+    }
+}
+pub struct LookupTable {
     pub(crate) data: Vec<VecZnx<Vec<u8>>>,
     pub(crate) rot_dir: LookUpTableRotationDirection,
-    pub(crate) base2k: usize,
-    pub(crate) k: usize,
+    pub(crate) base2k: Base2K,
+    pub(crate) k: TorusPrecision,
     pub(crate) drift: usize,
 }
-impl LookUpTable {
-    pub fn alloc<B: Backend>(module: &Module<B>, base2k: usize, k: usize, extension_factor: usize) -> Self {
+impl LookupTableInfos for LookupTable {
+    fn base2k(&self) -> Base2K {
+        self.base2k
+    }
+    fn extension_factor(&self) -> usize {
+        self.data.len()
+    }
+    fn k(&self) -> TorusPrecision {
+        self.k
+    }
+    fn n(&self) -> Degree {
+        self.data[0].n().into()
+    }
+    fn size(&self) -> usize {
+        self.data[0].size()
+    }
+}
+pub trait LookupTableFactory {
+    fn lookup_table_set(&self, res: &mut LookupTable, f: &[i64], k: usize);
+    fn lookup_table_rotate(&self, k: i64, res: &mut LookupTable);
+}
+impl LookupTable {
+    pub fn alloc<A>(infos: &A) -> Self
+    where
+        A: LookupTableInfos,
+    {
         #[cfg(debug_assertions)]
         {
             assert!(
-                extension_factor & (extension_factor - 1) == 0,
-                "extension_factor must be a power of two but is: {extension_factor}"
+                infos.extension_factor() & (infos.extension_factor() - 1) == 0,
+                "extension_factor must be a power of two but is: {}",
+                infos.extension_factor()
             );
         }
-        let size: usize = k.div_ceil(base2k);
-        let mut data: Vec<VecZnx<Vec<u8>>> = Vec::with_capacity(extension_factor);
-        (0..extension_factor).for_each(|_| {
-            data.push(VecZnx::alloc(module.n(), 1, size));
-        });
         Self {
-            data,
-            base2k,
-            k,
+            data: (0..infos.extension_factor())
+                .map(|_| VecZnx::alloc(infos.n().into(), 1, infos.size()))
+                .collect(),
+            base2k: infos.base2k(),
+            k: infos.k(),
             drift: 0,
             rot_dir: LookUpTableRotationDirection::Left,
         }
@@ -68,115 +134,18 @@ impl LookUpTable {
         self.rot_dir = rot_dir
     }
-    pub fn set<B: Backend>(&mut self, module: &Module<B>, f: &[i64], k: usize)
+    pub fn set<M>(&mut self, module: &M, f: &[i64], k: usize)
     where
-        Module<B>: VecZnxRotateInplace<B>
-            + VecZnxNormalizeInplace<B>
-            + VecZnxNormalizeTmpBytes
-            + VecZnxSwitchRing
-            + VecZnxCopy
-            + VecZnxRotateInplaceTmpBytes,
-        ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
-        Scratch<B>: TakeSlice,
+        M: LookupTableFactory,
     {
-        assert!(f.len() <= module.n());
-        let base2k: usize = self.base2k;
-        let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes() | (self.domain_size() << 3));
-        // Get the number minimum limb to store the message modulus
-        let limbs: usize = k.div_ceil(base2k);
-        #[cfg(debug_assertions)]
-        {
-            assert!(f.len() <= module.n());
-            assert!(
-                (max_bit_size(f) + (k % base2k) as u32) < i64::BITS,
-                "overflow: max(|f|) << (k%base2k) > i64::BITS"
-            );
-            assert!(limbs <= self.data[0].size());
-        }
-        // Scaling factor
-        let mut scale = 1;
-        if !k.is_multiple_of(base2k) {
-            scale <<= base2k - (k % base2k);
-        }
-        // #elements in lookup table
-        let f_len: usize = f.len();
-        // If LUT size > TakeScalarZnx
-        let domain_size: usize = self.domain_size();
-        let size: usize = self.k.div_ceil(self.base2k);
-        // Equivalent to AUTO([f(0), -f(n-1), -f(n-2), ..., -f(1)], -1)
-        let mut lut_full: VecZnx<Vec<u8>> = VecZnx::alloc(domain_size, 1, size);
-        let lut_at: &mut [i64] = lut_full.at_mut(0, limbs - 1);
-        let step: usize = domain_size.div_round(f_len);
-        f.iter().enumerate().for_each(|(i, fi)| {
-            let start: usize = i * step;
-            let end: usize = start + step;
-            lut_at[start..end].fill(fi * scale);
-        });
-        let drift: usize = step >> 1;
-        // Rotates half the step to the left
-        if self.extension_factor() > 1 {
-            let (tmp, _) = scratch.borrow().take_slice(lut_full.n());
-            for i in 0..self.extension_factor() {
-                module.vec_znx_switch_ring(&mut self.data[i], 0, &lut_full, 0);
-                if i < self.extension_factor() {
-                    vec_znx_rotate_inplace::<_, ZnxRef>(-1, &mut lut_full, 0, tmp);
-                }
-            }
-        } else {
-            module.vec_znx_copy(&mut self.data[0], 0, &lut_full, 0);
-        }
-        for a in self.data.iter_mut() {
-            module.vec_znx_normalize_inplace(self.base2k, a, 0, scratch.borrow());
-        }
-        self.rotate(module, -(drift as i64));
-        self.drift = drift
+        module.lookup_table_set(self, f, k);
     }
-    #[allow(dead_code)]
-    pub(crate) fn rotate<B: Backend>(&mut self, module: &Module<B>, k: i64)
+    pub(crate) fn rotate<M>(&mut self, module: &M, k: i64)
     where
-        Module<B>: VecZnxRotateInplace<B> + VecZnxRotateInplaceTmpBytes,
-        ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
+        M: LookupTableFactory,
     {
-        let extension_factor: usize = self.extension_factor();
-        let two_n: usize = 2 * self.data[0].n();
-        let two_n_ext: usize = two_n * extension_factor;
-        let mut scratch: ScratchOwned<_> = ScratchOwned::alloc(module.vec_znx_rotate_inplace_tmp_bytes());
-        let k_pos: usize = ((k + two_n_ext as i64) % two_n_ext as i64) as usize;
-        let k_hi: usize = k_pos / extension_factor;
-        let k_lo: usize = k_pos % extension_factor;
-        (0..extension_factor - k_lo).for_each(|i| {
-            module.vec_znx_rotate_inplace(k_hi as i64, &mut self.data[i], 0, scratch.borrow());
-        });
-        (extension_factor - k_lo..extension_factor).for_each(|i| {
-            module.vec_znx_rotate_inplace(k_hi as i64 + 1, &mut self.data[i], 0, scratch.borrow());
-        });
-        self.data.rotate_right(k_lo);
+        module.lookup_table_rotate(k, self);
     }
 }
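
set now only delegates to lookup_table_set; the fill logic it forwards to (see the Module impl further down) spreads each of the f_len function values over step = domain_size / f_len consecutive coefficients, pre-scaled by 2^(base2k - k mod base2k) when k is not a multiple of base2k. A worked plain-integer model of that arithmetic (illustrative, not the crate API):

fn main() {
    let (domain_size, base2k, k) = (16usize, 4usize, 5usize);
    let f = [3i64, -1, 0, 2];
    // Scaling factor lifting the values to torus precision k.
    let scale = if k % base2k == 0 { 1 } else { 1i64 << (base2k - (k % base2k)) };
    let step = domain_size / f.len(); // div_round in the real code
    let mut lut = vec![0i64; domain_size];
    for (i, fi) in f.iter().enumerate() {
        lut[i * step..(i + 1) * step].fill(fi * scale);
    }
    assert_eq!(scale, 8); // 2^(4 - 5 % 4)
    assert_eq!(step, 4);
    assert_eq!(&lut[0..4], &[24, 24, 24, 24]);
    let drift = step >> 1; // the table is afterwards rotated by -drift
    assert_eq!(drift, 2);
}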
@@ -204,3 +173,116 @@ fn max_bit_size(vec: &[i64]) -> u32 {
.max() .max()
.unwrap_or(0) .unwrap_or(0)
} }
impl<BE: Backend> LookupTableFactory for Module<BE>
where
Self: VecZnxRotateInplace<BE>
+ VecZnxNormalizeInplace<BE>
+ VecZnxNormalizeTmpBytes
+ VecZnxSwitchRing
+ VecZnxCopy
        + VecZnxRotateInplaceTmpBytes,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: TakeSlice,
{
fn lookup_table_set(&self, res: &mut LookupTable, f: &[i64], k: usize) {
assert!(f.len() <= self.n());
let base2k: usize = res.base2k.into();
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(
self.vec_znx_normalize_tmp_bytes()
.max(res.domain_size() << 3),
);
        // Minimum number of limbs needed to store the message modulus
let limbs: usize = k.div_ceil(base2k);
#[cfg(debug_assertions)]
{
assert!(f.len() <= self.n());
assert!(
(max_bit_size(f) + (k % base2k) as u32) < i64::BITS,
"overflow: max(|f|) << (k%base2k) > i64::BITS"
);
assert!(limbs <= res.data[0].size());
}
// Scaling factor
let mut scale = 1;
if !k.is_multiple_of(base2k) {
scale <<= base2k - (k % base2k);
}
// #elements in lookup table
let f_len: usize = f.len();
// If LUT size > TakeScalarZnx
let domain_size: usize = res.domain_size();
let size: usize = res.k.div_ceil(res.base2k) as usize;
// Equivalent to AUTO([f(0), -f(n-1), -f(n-2), ..., -f(1)], -1)
let mut lut_full: VecZnx<Vec<u8>> = VecZnx::alloc(domain_size, 1, size);
let lut_at: &mut [i64] = lut_full.at_mut(0, limbs - 1);
let step: usize = domain_size.div_round(f_len);
f.iter().enumerate().for_each(|(i, fi)| {
let start: usize = i * step;
let end: usize = start + step;
lut_at[start..end].fill(fi * scale);
});
let drift: usize = step >> 1;
// Rotates half the step to the left
if res.extension_factor() > 1 {
let (tmp, _) = scratch.borrow().take_slice(lut_full.n());
for i in 0..res.extension_factor() {
self.vec_znx_switch_ring(&mut res.data[i], 0, &lut_full, 0);
if i < res.extension_factor() {
vec_znx_rotate_inplace::<_, ZnxRef>(-1, &mut lut_full, 0, tmp);
}
}
} else {
self.vec_znx_copy(&mut res.data[0], 0, &lut_full, 0);
}
for a in res.data.iter_mut() {
self.vec_znx_normalize_inplace(res.base2k.into(), a, 0, scratch.borrow());
}
res.rotate(self, -(drift as i64));
res.drift = drift
}
fn lookup_table_rotate(&self, k: i64, res: &mut LookupTable) {
let extension_factor: usize = res.extension_factor();
let two_n: usize = 2 * res.data[0].n();
let two_n_ext: usize = two_n * extension_factor;
let mut scratch: ScratchOwned<_> = ScratchOwned::alloc(self.vec_znx_rotate_inplace_tmp_bytes());
let k_pos: usize = ((k + two_n_ext as i64) % two_n_ext as i64) as usize;
let k_hi: usize = k_pos / extension_factor;
let k_lo: usize = k_pos % extension_factor;
(0..extension_factor - k_lo).for_each(|i| {
self.vec_znx_rotate_inplace(k_hi as i64, &mut res.data[i], 0, scratch.borrow());
});
(extension_factor - k_lo..extension_factor).for_each(|i| {
self.vec_znx_rotate_inplace(k_hi as i64 + 1, &mut res.data[i], 0, scratch.borrow());
});
res.data.rotate_right(k_lo);
}
}
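
To make the k_hi/k_lo split in lookup_table_rotate concrete: the ext sub-polynomials emulate a single LUT over a ring of degree n * ext, so a rotation by k decomposes into a negacyclic rotation of each sub-polynomial by k_hi = k_pos / ext (one extra step for the last k_lo of them) followed by a cyclic shift of the slot vector by k_lo = k_pos % ext. A plain-integer model, assuming rotation means multiplication by X^k (illustrative, not the crate API):

// Multiply a negacyclic polynomial by X^k: coefficients wrap with a sign flip.
fn negacyclic_rotate(p: &mut Vec<i64>, k: i64) {
    let two_n = 2 * p.len() as i64;
    let k = ((k % two_n) + two_n) % two_n; // reduce to [0, 2n)
    for _ in 0..k {
        let last = p.pop().unwrap();
        p.insert(0, -last); // X^n = -1
    }
}

fn rotate_extended(data: &mut Vec<Vec<i64>>, k: i64) {
    let ext = data.len();
    let two_n_ext = (2 * data[0].len() * ext) as i64;
    let k_pos = (((k % two_n_ext) + two_n_ext) % two_n_ext) as usize;
    let (k_hi, k_lo) = (k_pos / ext, k_pos % ext);
    for p in data.iter_mut().take(ext - k_lo) {
        negacyclic_rotate(p, k_hi as i64);
    }
    for p in data.iter_mut().skip(ext - k_lo) {
        negacyclic_rotate(p, k_hi as i64 + 1);
    }
    data.rotate_right(k_lo); // the k_lo extra-rotated slots move to the front
}

fn main() {
    let mut data = vec![vec![1i64, 0, 0, 0], vec![0, 0, 0, 0]]; // n = 4, ext = 2
    rotate_extended(&mut data, 5); // k_pos = 5, k_hi = 2, k_lo = 1
    assert_eq!(data, vec![vec![0, 0, 0, 0], vec![0, 0, 1, 0]]);
}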

View File

@@ -1,35 +1,11 @@
-mod cggi_algo;
-mod cggi_key;
-mod key;
-mod key_compressed;
-mod key_prepared;
+mod algorithms;
+mod encryption;
+mod layouts;
 mod lut;
 mod utils;
-pub use cggi_algo::*;
-pub use key::*;
-pub use key_compressed::*;
-pub use key_prepared::*;
+pub use algorithms::*;
+pub use encryption::*;
+pub use layouts::*;
 pub use lut::*;
 pub mod tests;
-use poulpy_core::layouts::{GLWE, LWE};
-use poulpy_hal::layouts::{Backend, DataMut, DataRef, Module, Scratch};
-pub trait BlindRotationAlgo {}
-#[derive(Clone)]
-pub struct CGGI {}
-impl BlindRotationAlgo for CGGI {}
-pub trait BlincRotationExecute<B: Backend> {
-    fn execute<DR: DataMut, DI: DataRef>(
-        &self,
-        module: &Module<B>,
-        res: &mut GLWE<DR>,
-        lwe: &LWE<DI>,
-        lut: &LookUpTable,
-        scratch: &mut Scratch<B>,
-    );
-}

View File

@@ -1,88 +1,40 @@
 use poulpy_hal::{
-    api::{
-        ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf,
-        SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace,
-        VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAdd, VecZnxDftAddInplace,
-        VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftSubInplace, VecZnxDftZero, VecZnxFillUniform, VecZnxIdftApply,
-        VecZnxIdftApplyConsume, VecZnxIdftApplyTmpBytes, VecZnxMulXpMinusOneInplace, VecZnxNormalize, VecZnxNormalizeInplace,
-        VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxSub, VecZnxSubInplace,
-        VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, ZnAddNormal,
-        ZnFillUniform, ZnNormalizeInplace,
-    },
-    layouts::{Backend, Module, ScratchOwned, ZnxView},
-    oep::{
-        ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeSliceImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl,
-        TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl,
-    },
+    api::{ScratchOwnedAlloc, ScratchOwnedBorrow},
+    layouts::{Backend, Scratch, ScratchOwned, ZnxView},
     source::Source,
 };
 use crate::tfhe::blind_rotation::{
-    BlincRotationExecute, BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyEncryptSk, BlindRotationKeyLayout,
-    BlindRotationKeyPrepared, CGGI, LookUpTable, cggi_blind_rotate_tmp_bytes, mod_switch_2n,
+    BlindRotationAlgo, BlindRotationExecute, BlindRotationKey, BlindRotationKeyEncryptSk, BlindRotationKeyFactory,
+    BlindRotationKeyLayout, BlindRotationKeyPrepared, BlindRotationKeyPreparedFactory, LookUpTableLayout, LookupTable,
+    LookupTableFactory, mod_switch_2n,
 };
-use poulpy_core::layouts::{
-    GLWE, GLWELayout, GLWEPlaintext, GLWESecret, LWE, LWEInfos, LWELayout, LWEPlaintext, LWESecret, LWEToRef,
-    prepared::{GLWESecretPrepared, PrepareAlloc},
+use poulpy_core::{
+    GLWEDecrypt, LWEEncryptSk, ScratchTakeCore,
+    layouts::{
+        GLWE, GLWELayout, GLWEPlaintext, GLWESecret, GLWESecretPreparedFactory, LWE, LWEInfos, LWELayout, LWEPlaintext,
+        LWESecret, LWEToRef, prepared::GLWESecretPrepared,
+    },
 };
-pub fn test_blind_rotation<B>(module: &Module<B>, n_lwe: usize, block_size: usize, extension_factor: usize)
-where
-    Module<B>: VecZnxBigBytesOf
-        + VecZnxDftBytesOf
-        + SvpPPolBytesOf
-        + VmpApplyDftToDftTmpBytes
-        + VecZnxBigNormalizeTmpBytes
-        + VecZnxIdftApplyTmpBytes
-        + VecZnxIdftApply<B>
-        + VecZnxDftAdd<B>
-        + VecZnxDftAddInplace<B>
-        + VecZnxDftApply<B>
-        + VecZnxDftZero<B>
-        + SvpApplyDftToDft<B>
-        + VecZnxDftSubInplace<B>
-        + VecZnxBigAddSmallInplace<B>
-        + VecZnxRotate
-        + VecZnxAddInplace
-        + VecZnxSubInplace
-        + VecZnxNormalize<B>
-        + VecZnxNormalizeInplace<B>
-        + VecZnxCopy
-        + VecZnxMulXpMinusOneInplace<B>
-        + SvpPrepare<B>
-        + SvpPPolAlloc<B>
-        + SvpApplyDftToDftInplace<B>
-        + VecZnxIdftApplyConsume<B>
-        + VecZnxBigAddInplace<B>
-        + VecZnxBigNormalize<B>
-        + VecZnxNormalizeTmpBytes
-        + VecZnxFillUniform
-        + VecZnxAddNormal
-        + VecZnxAddScalarInplace
-        + VecZnxRotateInplace<B>
-        + VecZnxSwitchRing
-        + VecZnxSub
-        + VmpPMatAlloc<B>
-        + VmpPrepare<B>
-        + VmpApplyDftToDft<B>
-        + VmpApplyDftToDftAdd<B>
-        + ZnFillUniform
-        + ZnAddNormal
-        + VecZnxRotateInplaceTmpBytes
-        + ZnNormalizeInplace<B>,
-    B: Backend
-        + VecZnxDftAllocBytesImpl<B>
-        + VecZnxBigAllocBytesImpl<B>
-        + ScratchOwnedAllocImpl<B>
-        + ScratchOwnedBorrowImpl<B>
-        + TakeVecZnxDftImpl<B>
-        + TakeVecZnxBigImpl<B>
-        + TakeVecZnxDftSliceImpl<B>
-        + ScratchAvailableImpl<B>
-        + TakeVecZnxImpl<B>
-        + TakeVecZnxSliceImpl<B>
-        + TakeSliceImpl<B>,
+pub fn test_blind_rotation<BRA: BlindRotationAlgo, M, BE: Backend>(
+    module: &M,
+    n_lwe: usize,
+    block_size: usize,
+    extension_factor: usize,
+) where
+    M: BlindRotationKeyEncryptSk<BE, BRA>
+        + BlindRotationKeyPreparedFactory<BE, BRA>
+        + BlindRotationExecute<BRA, BE>
+        + GLWESecretPreparedFactory<BE>
+        + LWEEncryptSk<BE>
+        + LookupTableFactory
+        + GLWEDecrypt<BE>,
+    BlindRotationKey<Vec<u8>, BRA>: BlindRotationKeyFactory<BRA>,
+    ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
+    Scratch<BE>: ScratchTakeCore<BE>,
 {
     let n_glwe: usize = module.n();
     let base2k: usize = 19;
@@ -123,18 +75,17 @@ where
         base2k: base2k.into(),
     };
-    let mut scratch: ScratchOwned<B> = ScratchOwned::<B>::alloc(BlindRotationKey::generate_from_sk_tmp_bytes(
-        module, &brk_infos,
-    ));
+    let mut scratch: ScratchOwned<BE> = ScratchOwned::<BE>::alloc(BlindRotationKey::encrypt_sk_tmp_bytes(module, &brk_infos));
     let mut sk_glwe: GLWESecret<Vec<u8>> = GLWESecret::alloc_from_infos(&glwe_infos);
     sk_glwe.fill_ternary_prob(0.5, &mut source_xs);
-    let sk_glwe_dft: GLWESecretPrepared<Vec<u8>, B> = sk_glwe.prepare_alloc(module, scratch.borrow());
+    let mut sk_glwe_dft: GLWESecretPrepared<Vec<u8>, BE> = GLWESecretPrepared::alloc_from_infos(module, &glwe_infos);
+    sk_glwe_dft.prepare(module, &sk_glwe);
     let mut sk_lwe: LWESecret<Vec<u8>> = LWESecret::alloc(n_lwe.into());
     sk_lwe.fill_binary_block(block_size, &mut source_xs);
-    let mut scratch_br: ScratchOwned<B> = ScratchOwned::<B>::alloc(cggi_blind_rotate_tmp_bytes(
+    let mut scratch_br: ScratchOwned<BE> = ScratchOwned::<BE>::alloc(BlindRotationKeyPrepared::execute_tmp_bytes(
         module,
         block_size,
         extension_factor,
@@ -142,7 +93,7 @@ where
         &brk_infos,
     ));
-    let mut brk: BlindRotationKey<Vec<u8>, CGGI> = BlindRotationKey::<Vec<u8>, CGGI>::alloc(&brk_infos);
+    let mut brk: BlindRotationKey<Vec<u8>, BRA> = BlindRotationKey::<Vec<u8>, BRA>::alloc(&brk_infos);
     brk.encrypt_sk(
         module,
@@ -171,12 +122,20 @@ where
         .enumerate()
         .for_each(|(i, x)| *x = f(i as i64));
-    let mut lut: LookUpTable = LookUpTable::alloc(module, base2k, k_lut, extension_factor);
+    let lut_infos = LookUpTableLayout {
+        n: module.n().into(),
+        extension_factor,
+        k: k_lut.into(),
+        base2k: base2k.into(),
+    };
+    let mut lut: LookupTable = LookupTable::alloc(&lut_infos);
     lut.set(module, &f_vec, log_message_modulus + 1);
     let mut res: GLWE<Vec<u8>> = GLWE::alloc_from_infos(&glwe_infos);
-    let brk_prepared: BlindRotationKeyPrepared<Vec<u8>, CGGI, B> = brk.prepare_alloc(module, scratch.borrow());
+    let mut brk_prepared: BlindRotationKeyPrepared<Vec<u8>, BRA, BE> = BlindRotationKeyPrepared::alloc(module, &brk);
+    brk_prepared.prepare(module, &brk, scratch_br.borrow());
     brk_prepared.execute(module, &mut res, &lwe, &lut, scratch_br.borrow());
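
The harness now separates allocation from preparation (BlindRotationKeyPrepared::alloc followed by prepare, and likewise for the GLWE secret) where it previously used one-shot prepare_alloc calls, so a prepared buffer can be sized once and refilled. A self-contained model of the two-phase pattern (names are illustrative, not the crate API):

struct Key(Vec<i64>);
struct PreparedKey(Vec<i64>);

impl PreparedKey {
    fn alloc(len: usize) -> Self {
        PreparedKey(vec![0; len]) // phase 1: reserve space only
    }
    fn prepare(&mut self, src: &Key) {
        // phase 2: fill in place (stands in for the DFT-domain preparation)
        self.0.copy_from_slice(&src.0);
    }
}

fn main() {
    let key = Key(vec![1, 2, 3]);
    let mut prepared = PreparedKey::alloc(3);
    prepared.prepare(&key); // the same buffer can be re-prepared for a new key
    assert_eq!(prepared.0, vec![1, 2, 3]);
}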

View File

@@ -1,25 +1,12 @@
 use std::vec;
-use poulpy_hal::{
-    api::{
-        VecZnxCopy, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes,
-        VecZnxSwitchRing,
-    },
-    layouts::{Backend, Module},
-    oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeSliceImpl},
-};
-use crate::tfhe::blind_rotation::{DivRound, LookUpTable};
+use poulpy_hal::api::ModuleN;
+use crate::tfhe::blind_rotation::{DivRound, LookUpTableLayout, LookupTable, LookupTableFactory};
-pub fn test_lut_standard<B>(module: &Module<B>)
+pub fn test_lut_standard<M>(module: &M)
 where
-    Module<B>: VecZnxRotateInplace<B>
-        + VecZnxNormalizeInplace<B>
-        + VecZnxNormalizeTmpBytes
-        + VecZnxSwitchRing
-        + VecZnxCopy
-        + VecZnxRotateInplaceTmpBytes,
-    B: Backend + ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B> + TakeSliceImpl<B>,
+    M: LookupTableFactory + ModuleN,
 {
     let base2k: usize = 20;
     let k_lut: usize = 40;
@@ -33,7 +20,14 @@ where
         .enumerate()
         .for_each(|(i, x)| *x = (i as i64) - 8);
-    let mut lut: LookUpTable = LookUpTable::alloc(module, base2k, k_lut, extension_factor);
+    let lut_infos: LookUpTableLayout = LookUpTableLayout {
+        n: module.n().into(),
+        extension_factor,
+        k: k_lut.into(),
+        base2k: base2k.into(),
+    };
+    let mut lut: LookupTable = LookupTable::alloc(&lut_infos);
     lut.set(module, &f, log_scale);
     let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64;
@@ -51,15 +45,9 @@ where
     });
 }
-pub fn test_lut_extended<B>(module: &Module<B>)
+pub fn test_lut_extended<M>(module: &M)
 where
-    Module<B>: VecZnxRotateInplace<B>
-        + VecZnxNormalizeInplace<B>
-        + VecZnxNormalizeTmpBytes
-        + VecZnxSwitchRing
-        + VecZnxCopy
-        + VecZnxRotateInplaceTmpBytes,
-    B: Backend + ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B> + TakeSliceImpl<B>,
+    M: LookupTableFactory + ModuleN,
 {
     let base2k: usize = 20;
     let k_lut: usize = 40;
@@ -73,7 +61,14 @@ where
         .enumerate()
         .for_each(|(i, x)| *x = (i as i64) - 8);
-    let mut lut: LookUpTable = LookUpTable::alloc(module, base2k, k_lut, extension_factor);
+    let lut_infos: LookUpTableLayout = LookUpTableLayout {
+        n: module.n().into(),
+        extension_factor,
+        k: k_lut.into(),
+        base2k: base2k.into(),
+    };
+    let mut lut: LookupTable = LookupTable::alloc(&lut_infos);
     lut.set(module, &f, log_scale);
     let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64;

View File

@@ -1,8 +1,6 @@
 use poulpy_hal::test_suite::serialization::test_reader_writer_interface;
-use crate::tfhe::blind_rotation::{
-    BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyCompressed, BlindRotationKeyLayout, CGGI,
-};
+use crate::tfhe::blind_rotation::{BlindRotationKey, BlindRotationKeyCompressed, BlindRotationKeyLayout, CGGI};
 #[test]
 fn test_cggi_blind_rotation_key_serialization() {
@@ -14,7 +12,6 @@ fn test_cggi_blind_rotation_key_serialization() {
         dnum: 2_usize.into(),
         rank: 2_usize.into(),
     };
     let original: BlindRotationKey<Vec<u8>, CGGI> = BlindRotationKey::alloc(&layout);
     test_reader_writer_interface(original);
 }
@@ -29,7 +26,6 @@ fn test_cggi_blind_rotation_key_compressed_serialization() {
         dnum: 2_usize.into(),
         rank: 2_usize.into(),
     };
     let original: BlindRotationKeyCompressed<Vec<u8>, CGGI> = BlindRotationKeyCompressed::alloc(&layout);
     test_reader_writer_interface(original);
 }

View File

@@ -1,37 +0,0 @@
use poulpy_backend::cpu_spqlios::FFT64Spqlios;
use poulpy_hal::{api::ModuleNew, layouts::Module};
use crate::tfhe::blind_rotation::tests::{
generic_blind_rotation::test_blind_rotation,
generic_lut::{test_lut_extended, test_lut_standard},
};
#[test]
fn lut_standard() {
let module: Module<FFT64Spqlios> = Module::<FFT64Spqlios>::new(32);
test_lut_standard(&module);
}
#[test]
fn lut_extended() {
let module: Module<FFT64Spqlios> = Module::<FFT64Spqlios>::new(32);
test_lut_extended(&module);
}
#[test]
fn standard() {
let module: Module<FFT64Spqlios> = Module::<FFT64Spqlios>::new(512);
test_blind_rotation(&module, 224, 1, 1);
}
#[test]
fn block_binary() {
let module: Module<FFT64Spqlios> = Module::<FFT64Spqlios>::new(512);
test_blind_rotation(&module, 224, 7, 1);
}
#[test]
fn block_binary_extended() {
let module: Module<FFT64Spqlios> = Module::<FFT64Spqlios>::new(512);
test_blind_rotation(&module, 224, 7, 2);
}

View File

@@ -0,0 +1,40 @@
use poulpy_backend::cpu_fft64_ref::FFT64Ref;
use poulpy_hal::{api::ModuleNew, layouts::Module};
use crate::tfhe::blind_rotation::{
CGGI,
tests::{
generic_blind_rotation::test_blind_rotation,
generic_lut::{test_lut_extended, test_lut_standard},
},
};
#[test]
fn lut_standard() {
let module: Module<FFT64Ref> = Module::<FFT64Ref>::new(32);
test_lut_standard(&module);
}
#[test]
fn lut_extended() {
let module: Module<FFT64Ref> = Module::<FFT64Ref>::new(32);
test_lut_extended(&module);
}
#[test]
fn standard() {
let module: Module<FFT64Ref> = Module::<FFT64Ref>::new(512);
test_blind_rotation::<CGGI, _, FFT64Ref>(&module, 224, 1, 1);
}
#[test]
fn block_binary() {
let module: Module<FFT64Ref> = Module::<FFT64Ref>::new(512);
test_blind_rotation::<CGGI, _, FFT64Ref>(&module, 224, 7, 1);
}
#[test]
fn block_binary_extended() {
let module: Module<FFT64Ref> = Module::<FFT64Ref>::new(512);
test_blind_rotation::<CGGI, _, FFT64Ref>(&module, 224, 7, 2);
}

View File

@@ -1 +1 @@
-mod cpu_spqlios;
+mod fft64;

View File

@@ -1,14 +1,13 @@
 use poulpy_core::layouts::{
     AutomorphismKey, AutomorphismKeyLayout, GGLWEInfos, GGSWInfos, GLWE, GLWEInfos, GLWESecret, LWEInfos, LWESecret, TensorKey,
     TensorKeyLayout,
-    prepared::{GLWEAutomorphismKeyPrepared, GLWESecretPrepared, PrepareAlloc, TensorKeyPrepared},
+    prepared::{GLWEAutomorphismKeyPrepared, GLWESecretPrepared, TensorKeyPrepared},
 };
 use std::collections::HashMap;
 use poulpy_hal::{
     api::{
-        ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, TakeScalarZnx,
-        TakeSvpPPol, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace,
+        ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace,
         VecZnxAutomorphism, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume,
         VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace,
         VecZnxSwitchRing, VmpPMatAlloc, VmpPrepare,

View File

@@ -1,6 +1,8 @@
 mod circuit;
 mod key;
-pub mod tests;
+//[cfg(tests)]
+//pub mod tests;
 pub use circuit::*;
 pub use key::*;

View File

@@ -1,4 +1,3 @@
 pub mod circuit_bootstrapping;
-#[cfg(test)]
 mod implementation;

View File

@@ -1,3 +1,3 @@
-pub mod bdd_arithmetic;
+// pub mod bdd_arithmetic;
 pub mod blind_rotation;
-pub mod circuit_bootstrapping;
+//pub mod circuit_bootstrapping;