Ref. + AVX code & generic tests + benches (#85)

This commit is contained in:
Jean-Philippe Bossuat
2025-09-15 16:16:11 +02:00
committed by GitHub
parent 99b9e3e10e
commit 56dbd29c59
286 changed files with 27797 additions and 7270 deletions

View File

@@ -2,12 +2,13 @@ use std::collections::HashMap;
use poulpy_hal::{
api::{
DFT, IDFTConsume, IDFTTmpA, ScratchAvailable, TakeMatZnx, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice,
TakeVecZnxSlice, VecZnxAddInplace, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes,
VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, VecZnxCopy,
VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftCopy, VecZnxNegateInplace, VecZnxNormalizeInplace,
VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubABInplace,
VecZnxSwithcDegree, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
ScratchAvailable, TakeMatZnx, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, TakeVecZnxSlice,
VecZnxAddInplace, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigAutomorphismInplace,
VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, VecZnxCopy, VecZnxDftAddInplace,
VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNegateInplace,
VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes,
VecZnxRshInplace, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd,
VmpApplyDftToDftTmpBytes,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch, ToOwnedDeep},
oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl},
@@ -26,14 +27,14 @@ use crate::tfhe::{
impl<D: DataRef, BRA: BlindRotationAlgo, B> CirtuitBootstrappingExecute<B> for CircuitBootstrappingKeyPrepared<D, BRA, B>
where
Module<B>: VecZnxRotateInplace
Module<B>: VecZnxRotateInplace<B>
+ VecZnxNormalizeInplace<B>
+ VecZnxNormalizeTmpBytes
+ VecZnxSwithcDegree
+ VecZnxSwitchRing
+ VecZnxBigAutomorphismInplace<B>
+ VecZnxRshInplace
+ VecZnxRshInplace<B>
+ VecZnxDftCopy<B>
+ IDFTTmpA<B>
+ VecZnxIdftApplyTmpA<B>
+ VecZnxSub
+ VecZnxAddInplace
+ VecZnxNegateInplace
@@ -44,12 +45,13 @@ where
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxAutomorphismInplace
+ VecZnxAutomorphismInplace<B>
+ VecZnxBigSubSmallBInplace<B>
+ VecZnxRotateInplaceTmpBytes
+ VecZnxBigAllocBytes
+ VecZnxDftAddInplace<B>
+ VecZnxRotate,
@@ -124,14 +126,14 @@ pub fn circuit_bootstrap_core<DRes, DLwe, DBrk, BRA: BlindRotationAlgo, B>(
DRes: DataMut,
DLwe: DataRef,
DBrk: DataRef,
Module<B>: VecZnxRotateInplace
Module<B>: VecZnxRotateInplace<B>
+ VecZnxNormalizeInplace<B>
+ VecZnxNormalizeTmpBytes
+ VecZnxSwithcDegree
+ VecZnxSwitchRing
+ VecZnxBigAutomorphismInplace<B>
+ VecZnxRshInplace
+ VecZnxRshInplace<B>
+ VecZnxDftCopy<B>
+ IDFTTmpA<B>
+ VecZnxIdftApplyTmpA<B>
+ VecZnxSub
+ VecZnxAddInplace
+ VecZnxNegateInplace
@@ -142,14 +144,15 @@ pub fn circuit_bootstrap_core<DRes, DLwe, DBrk, BRA: BlindRotationAlgo, B>(
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxAutomorphismInplace
+ VecZnxAutomorphismInplace<B>
+ VecZnxBigSubSmallBInplace<B>
+ VecZnxBigAllocBytes
+ VecZnxDftAddInplace<B>
+ VecZnxRotateInplaceTmpBytes
+ VecZnxRotate,
B: Backend + ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B>,
Scratch<B>: TakeVecZnxDftSlice<B>
@@ -199,10 +202,10 @@ pub fn circuit_bootstrap_core<DRes, DLwe, DBrk, BRA: BlindRotationAlgo, B>(
}
// TODO: separate GGSW k from output of blind rotation k
let (mut res_glwe, scratch1) = scratch.take_glwe_ct(n, basek, k, rank);
let (mut tmp_gglwe, scratch2) = scratch1.take_gglwe(n, basek, k, rows, 1, rank.max(1), rank);
let (mut res_glwe, scratch_1) = scratch.take_glwe_ct(n, basek, k, rank);
let (mut tmp_gglwe, scratch_2) = scratch_1.take_gglwe(n, basek, k, rows, 1, rank.max(1), rank);
key.brk.execute(module, &mut res_glwe, lwe, &lut, scratch2);
key.brk.execute(module, &mut res_glwe, lwe, &lut, scratch_2);
let gap: usize = 2 * lut.drift / lut.extension_factor();
@@ -221,19 +224,19 @@ pub fn circuit_bootstrap_core<DRes, DLwe, DBrk, BRA: BlindRotationAlgo, B>(
log_gap_out,
log_domain,
&key.atk,
scratch2,
scratch_2,
);
} else {
tmp_glwe.trace(module, 0, module.log_n(), &res_glwe, &key.atk, scratch2);
tmp_glwe.trace(module, 0, module.log_n(), &res_glwe, &key.atk, scratch_2);
}
if i < rows {
res_glwe.rotate_inplace(module, -(gap as i64));
res_glwe.rotate_inplace(module, -(gap as i64), scratch_2);
}
});
// Expands GGLWE to GGSW using GGLWE(s^2)
res.from_gglwe(module, &tmp_gglwe, &key.tsk, scratch2);
res.from_gglwe(module, &tmp_gglwe, &key.tsk, scratch_2);
}
#[allow(clippy::too_many_arguments)]
@@ -249,14 +252,14 @@ fn post_process<DataRes, DataA, B: Backend>(
) where
DataRes: DataMut,
DataA: DataRef,
Module<B>: VecZnxRotateInplace
Module<B>: VecZnxRotateInplace<B>
+ VecZnxNormalizeInplace<B>
+ VecZnxNormalizeTmpBytes
+ VecZnxSwithcDegree
+ VecZnxSwitchRing
+ VecZnxBigAutomorphismInplace<B>
+ VecZnxRshInplace
+ VecZnxRshInplace<B>
+ VecZnxDftCopy<B>
+ IDFTTmpA<B>
+ VecZnxIdftApplyTmpA<B>
+ VecZnxSub
+ VecZnxAddInplace
+ VecZnxNegateInplace
@@ -267,11 +270,11 @@ fn post_process<DataRes, DataA, B: Backend>(
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxAutomorphismInplace
+ VecZnxAutomorphismInplace<B>
+ VecZnxBigSubSmallBInplace<B>
+ VecZnxRotate,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
@@ -297,7 +300,7 @@ fn post_process<DataRes, DataA, B: Backend>(
let steps: i32 = 1 << log_domain;
(0..steps).for_each(|i| {
if i != 0 {
res.rotate_inplace(module, -(1 << log_gap_in));
res.rotate_inplace(module, -(1 << log_gap_in), scratch);
}
cts.insert(i as usize * (1 << log_gap_out), res.to_owned_deep());
});
@@ -321,14 +324,14 @@ pub fn pack<D: DataMut, B: Backend>(
auto_keys: &HashMap<i64, GGLWEAutomorphismKeyPrepared<Vec<u8>, B>>,
scratch: &mut Scratch<B>,
) where
Module<B>: VecZnxRotateInplace
Module<B>: VecZnxRotateInplace<B>
+ VecZnxNormalizeInplace<B>
+ VecZnxNormalizeTmpBytes
+ VecZnxSwithcDegree
+ VecZnxSwitchRing
+ VecZnxBigAutomorphismInplace<B>
+ VecZnxRshInplace
+ VecZnxRshInplace<B>
+ VecZnxDftCopy<B>
+ IDFTTmpA<B>
+ VecZnxIdftApplyTmpA<B>
+ VecZnxSub
+ VecZnxAddInplace
+ VecZnxNegateInplace
@@ -339,11 +342,11 @@ pub fn pack<D: DataMut, B: Backend>(
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxAutomorphismInplace
+ VecZnxAutomorphismInplace<B>
+ VecZnxBigSubSmallBInplace<B>
+ VecZnxRotate,
Scratch<B>: TakeVecZnx + TakeVecZnxDft<B> + ScratchAvailable,
@@ -400,14 +403,14 @@ fn combine<A: DataMut, D: DataMut, DataAK: DataRef, B: Backend>(
auto_key: &GGLWEAutomorphismKeyPrepared<DataAK, B>,
scratch: &mut Scratch<B>,
) where
Module<B>: VecZnxRotateInplace
Module<B>: VecZnxRotateInplace<B>
+ VecZnxNormalizeInplace<B>
+ VecZnxNormalizeTmpBytes
+ VecZnxSwithcDegree
+ VecZnxSwitchRing
+ VecZnxBigAutomorphismInplace<B>
+ VecZnxRshInplace
+ VecZnxRshInplace<B>
+ VecZnxDftCopy<B>
+ IDFTTmpA<B>
+ VecZnxIdftApplyTmpA<B>
+ VecZnxSub
+ VecZnxAddInplace
+ VecZnxNegateInplace
@@ -418,11 +421,11 @@ fn combine<A: DataMut, D: DataMut, DataAK: DataRef, B: Backend>(
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxAutomorphismInplace
+ VecZnxAutomorphismInplace<B>
+ VecZnxBigSubSmallBInplace<B>
+ VecZnxRotate,
Scratch<B>: TakeVecZnx + TakeVecZnxDft<B> + ScratchAvailable,
@@ -446,15 +449,15 @@ fn combine<A: DataMut, D: DataMut, DataAK: DataRef, B: Backend>(
let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(n, basek, k, rank);
// a = a * X^-t
a.rotate_inplace(module, -t);
a.rotate_inplace(module, -t, scratch_1);
// tmp_b = a * X^-t - b
tmp_b.sub(module, a, b);
tmp_b.rsh(module, 1);
tmp_b.rsh(module, 1, scratch_1);
// a = a * X^-t + b
a.add_inplace(module, b);
a.rsh(module, 1);
a.rsh(module, 1, scratch_1);
tmp_b.normalize_inplace(module, scratch_1);
@@ -468,9 +471,9 @@ fn combine<A: DataMut, D: DataMut, DataAK: DataRef, B: Backend>(
// a = a + b * X^t - phi(a * X^-t - b) * X^t
// = a + b * X^t - phi(a * X^-t - b) * - phi(X^t)
// = a + b * X^t + phi(a - b * X^t)
a.rotate_inplace(module, t);
a.rotate_inplace(module, t, scratch_1);
} else {
a.rsh(module, 1);
a.rsh(module, 1, scratch);
// a = a + phi(a)
a.automorphism_add_inplace(module, auto_key, scratch);
}
@@ -481,7 +484,7 @@ fn combine<A: DataMut, D: DataMut, DataAK: DataRef, B: Backend>(
let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(n, basek, k, rank);
tmp_b.rotate(module, t, b);
tmp_b.rsh(module, 1);
tmp_b.rsh(module, 1, scratch_1);
// a = (b* X^t - phi(b* X^t))
b.automorphism_sub_ba(module, &tmp_b, auto_key, scratch_1);