Traits cleaning, CBT example & bug fixes (#72)

* Some cleaning, CBT example, fix mod switch and add LUT correctness test to BR test

* finished trait cleaning

* removed trait aliastoutside of backend
This commit is contained in:
Jean-Philippe Bossuat
2025-08-16 18:23:22 +02:00
committed by GitHub
parent c7219c35e9
commit 3a828740cc
99 changed files with 3267 additions and 1220 deletions

View File

@@ -1,10 +1,11 @@
use backend::hal::{
api::{
ScratchAvailable, SvpApply, SvpPPolAllocBytes, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice,
TakeVecZnxSlice, VecZnxAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalizeTmpBytes, VecZnxCopy,
VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftSubABInplace, VecZnxDftToVecZnxBig,
VecZnxDftToVecZnxBigTmpBytes, VecZnxDftZero, VecZnxMulXpMinusOneInplace, VecZnxNormalize, VecZnxNormalizeInplace,
VecZnxRotate, VecZnxSubABInplace, VmpApplyTmpBytes, ZnxView, ZnxZero,
TakeVecZnxSlice, VecZnxAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize,
VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftFromVecZnx,
VecZnxDftSubABInplace, VecZnxDftToVecZnxBig, VecZnxDftToVecZnxBigConsume, VecZnxDftToVecZnxBigTmpBytes, VecZnxDftZero,
VecZnxMulXpMinusOneInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate,
VecZnxSubABInplace, VmpApply, VmpApplyAdd, VmpApplyTmpBytes, ZnxView, ZnxZero,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch, SvpPPol, VecZnx},
};
@@ -13,36 +14,12 @@ use itertools::izip;
use core::{
Distribution, GLWEOperations, TakeGLWECt,
layouts::{GLWECiphertext, GLWECiphertextToMut, Infos, LWECiphertext, LWECiphertextToRef},
trait_families::GLWEExternalProductFamily,
};
use crate::tfhe::blind_rotation::{
BlincRotationExecute, BlindRotationKeyPrepared, CGGI, LookUpTable, LookUpTableRotationDirection,
};
pub trait CCGIBlindRotationFamily<B: Backend> = VecZnxBigAllocBytes
+ VecZnxDftAllocBytes
+ SvpPPolAllocBytes
+ VmpApplyTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VecZnxDftToVecZnxBigTmpBytes
+ VecZnxDftToVecZnxBig<B>
+ VecZnxDftAdd<B>
+ VecZnxDftAddInplace<B>
+ VecZnxDftFromVecZnx<B>
+ VecZnxDftZero<B>
+ SvpApply<B>
+ VecZnxDftSubABInplace<B>
+ VecZnxBigAddSmallInplace<B>
+ GLWEExternalProductFamily<B>
+ VecZnxRotate
+ VecZnxAddInplace
+ VecZnxSubABInplace
+ VecZnxNormalize<B>
+ VecZnxNormalizeInplace<B>
+ VecZnxCopy
+ VecZnxMulXpMinusOneInplace;
pub fn cggi_blind_rotate_scratch_space<B: Backend>(
module: &Module<B>,
n: usize,
@@ -55,7 +32,12 @@ pub fn cggi_blind_rotate_scratch_space<B: Backend>(
rank: usize,
) -> usize
where
Module<B>: CCGIBlindRotationFamily<B>,
Module<B>: VecZnxDftAllocBytes
+ VmpApplyTmpBytes
+ VecZnxNormalizeTmpBytes
+ VecZnxBigAllocBytes
+ VecZnxDftToVecZnxBigTmpBytes
+ VecZnxBigNormalizeTmpBytes,
{
let brk_size: usize = k_brk.div_ceil(basek);
@@ -89,7 +71,32 @@ where
impl<D: DataRef, B: Backend> BlincRotationExecute<B> for BlindRotationKeyPrepared<D, CGGI, B>
where
Module<B>: CCGIBlindRotationFamily<B>,
Module<B>: VecZnxBigAllocBytes
+ VecZnxDftAllocBytes
+ SvpPPolAllocBytes
+ VmpApplyTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VecZnxDftToVecZnxBigTmpBytes
+ VecZnxDftToVecZnxBig<B>
+ VecZnxDftAdd<B>
+ VecZnxDftAddInplace<B>
+ VecZnxDftFromVecZnx<B>
+ VecZnxDftZero<B>
+ SvpApply<B>
+ VecZnxDftSubABInplace<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxRotate
+ VecZnxAddInplace
+ VecZnxSubABInplace
+ VecZnxNormalize<B>
+ VecZnxNormalizeInplace<B>
+ VecZnxCopy
+ VecZnxMulXpMinusOneInplace
+ VmpApply<B>
+ VmpApplyAdd<B>
+ VecZnxDftToVecZnxBigConsume<B>
+ VecZnxBigNormalize<B>
+ VecZnxNormalizeTmpBytes,
Scratch<B>: TakeVecZnxDftSlice<B> + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeVecZnxSlice + TakeVecZnx + ScratchAvailable,
{
fn execute<DR: DataMut, DI: DataRef>(
@@ -126,7 +133,29 @@ fn execute_block_binary_extended<DataRes, DataIn, DataBrk, B: Backend>(
DataRes: DataMut,
DataIn: DataRef,
DataBrk: DataRef,
Module<B>: CCGIBlindRotationFamily<B>,
Module<B>: VecZnxBigAllocBytes
+ VecZnxDftAllocBytes
+ SvpPPolAllocBytes
+ VmpApplyTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VecZnxDftToVecZnxBigTmpBytes
+ VecZnxDftToVecZnxBig<B>
+ VecZnxDftAdd<B>
+ VecZnxDftAddInplace<B>
+ VecZnxDftFromVecZnx<B>
+ VecZnxDftZero<B>
+ SvpApply<B>
+ VecZnxDftSubABInplace<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxRotate
+ VecZnxAddInplace
+ VecZnxSubABInplace
+ VecZnxNormalize<B>
+ VecZnxNormalizeInplace<B>
+ VecZnxCopy
+ VecZnxMulXpMinusOneInplace
+ VecZnxBigNormalize<B>
+ VmpApply<B>,
Scratch<B>: TakeVecZnxDftSlice<B> + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeVecZnxSlice + ScratchAvailable + TakeVecZnx,
{
let n_glwe: usize = brk.n();
@@ -271,7 +300,29 @@ fn execute_block_binary<DataRes, DataIn, DataBrk, B: Backend>(
DataRes: DataMut,
DataIn: DataRef,
DataBrk: DataRef,
Module<B>: CCGIBlindRotationFamily<B>,
Module<B>: VecZnxBigAllocBytes
+ VecZnxDftAllocBytes
+ SvpPPolAllocBytes
+ VmpApplyTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VecZnxDftToVecZnxBigTmpBytes
+ VecZnxDftToVecZnxBig<B>
+ VecZnxDftAdd<B>
+ VecZnxDftAddInplace<B>
+ VecZnxDftFromVecZnx<B>
+ VecZnxDftZero<B>
+ SvpApply<B>
+ VecZnxDftSubABInplace<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxRotate
+ VecZnxAddInplace
+ VecZnxSubABInplace
+ VecZnxNormalize<B>
+ VecZnxNormalizeInplace<B>
+ VecZnxCopy
+ VecZnxMulXpMinusOneInplace
+ VmpApply<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: TakeVecZnxDftSlice<B> + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeVecZnxSlice + ScratchAvailable + TakeVecZnx,
{
let n_glwe: usize = brk.n();
@@ -363,7 +414,32 @@ fn execute_standard<DataRes, DataIn, DataBrk, B: Backend>(
DataRes: DataMut,
DataIn: DataRef,
DataBrk: DataRef,
Module<B>: CCGIBlindRotationFamily<B>,
Module<B>: VecZnxBigAllocBytes
+ VecZnxDftAllocBytes
+ SvpPPolAllocBytes
+ VmpApplyTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VecZnxDftToVecZnxBigTmpBytes
+ VecZnxDftToVecZnxBig<B>
+ VecZnxDftAdd<B>
+ VecZnxDftAddInplace<B>
+ VecZnxDftFromVecZnx<B>
+ VecZnxDftZero<B>
+ SvpApply<B>
+ VecZnxDftSubABInplace<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxRotate
+ VecZnxAddInplace
+ VecZnxSubABInplace
+ VecZnxNormalize<B>
+ VecZnxNormalizeInplace<B>
+ VecZnxCopy
+ VecZnxMulXpMinusOneInplace
+ VmpApply<B>
+ VmpApplyAdd<B>
+ VecZnxDftToVecZnxBigConsume<B>
+ VecZnxBigNormalize<B>
+ VecZnxNormalizeTmpBytes,
Scratch<B>: TakeVecZnxDftSlice<B> + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeVecZnxSlice + ScratchAvailable + TakeVecZnx,
{
#[cfg(debug_assertions)]
@@ -454,7 +530,7 @@ pub(crate) fn mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]
}
if basek > log2n {
let diff: usize = basek - log2n;
let diff: usize = basek - (log2n - 1); // additional -1 because we map to [-N/2, N/2) instead of [0, N)
res.iter_mut().for_each(|x| {
*x = div_round_by_pow2(x, diff);
})