This commit is contained in:
Pro7ech
2025-10-22 16:43:46 +02:00
parent 5755aea58c
commit cedf7b9c59
26 changed files with 713 additions and 723 deletions

View File

@@ -1,26 +1,21 @@
#[cfg(test)]
use crate::tfhe::bdd_arithmetic::FheUintBlocksPrepDebug;
use crate::tfhe::bdd_arithmetic::FheUintBlocksPreparedDebug;
use crate::tfhe::{
bdd_arithmetic::{FheUintBlocks, FheUintBlocksPrep, UnsignedInteger},
blind_rotation::{BlindRotationAlgo, BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyEncryptSk},
bdd_arithmetic::{FheUintBlocks, FheUintBlocksPrepared, UnsignedInteger},
blind_rotation::{BlindRotationAlgo, BlindRotationKey, BlindRotationKeyFactory},
circuit_bootstrapping::{
CircuitBootstrappingKey, CircuitBootstrappingKeyEncryptSk, CircuitBootstrappingKeyLayout,
CircuitBootstrappingKeyPrepared, CirtuitBootstrappingExecute,
CircuitBootstrappingKeyPrepared, CircuitBootstrappingKeyPreparedFactory, CirtuitBootstrappingExecute,
},
};
use poulpy_core::{
GLWEToLWESwitchingKeyEncryptSk, GetDistribution, LWEFromGLWE, ScratchTakeCore,
layouts::{
prepared::GLWEToLWESwitchingKeyPrepared, GLWESecret, GLWEToLWEKeyLayout, GLWEToLWESwitchingKey, LWESecret
}, ScratchTakeCore,
GGSWInfos, GGSWPreparedFactory, GLWEInfos, GLWESecretToRef, GLWEToLWEKeyLayout, GLWEToLWESwitchingKey,
GLWEToLWESwitchingKeyPreparedFactory, LWE, LWEInfos, LWESecretToRef, prepared::GLWEToLWESwitchingKeyPrepared,
},
};
use poulpy_hal::{
api::{
ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace,
VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize,
VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing,
VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPrepare,
},
layouts::{Backend, Data, DataMut, DataRef, Module, Scratch},
source::Source,
};
@@ -46,193 +41,256 @@ impl BDDKeyInfos for BDDKeyLayout {
}
}
pub struct BDDKey<CBT, LWE, BRA>
pub struct BDDKey<D, BRA>
where
CBT: Data,
LWE: Data,
D: Data,
BRA: BlindRotationAlgo,
{
cbt: CircuitBootstrappingKey<CBT, BRA>,
ks: GLWEToLWESwitchingKey<LWE>,
cbt: CircuitBootstrappingKey<D, BRA>,
ks: GLWEToLWESwitchingKey<D>,
}
impl<BRA: BlindRotationAlgo> BDDKey<Vec<u8>, Vec<u8>, BRA> {
pub fn encrypt_sk<DLwe, DGlwe, A, BE: Backend>(
module: &Module<BE>,
sk_lwe: &LWESecret<DLwe>,
sk_glwe: &GLWESecret<DGlwe>,
infos: &A,
impl<BRA: BlindRotationAlgo> BDDKey<Vec<u8>, BRA>
where
BlindRotationKey<Vec<u8>, BRA>: BlindRotationKeyFactory<BRA>,
{
pub fn alloc_from_infos<A>(infos: &A) -> Self
where
A: BDDKeyInfos,
{
Self {
cbt: CircuitBootstrappingKey::alloc_from_infos(&infos.cbt_infos()),
ks: GLWEToLWESwitchingKey::alloc_from_infos(&infos.ks_infos()),
}
}
}
pub trait BDDKeyEncryptSk<BRA: BlindRotationAlgo, BE: Backend> {
fn bdd_key_encrypt_sk<D, S0, S1>(
&self,
res: &mut BDDKey<D, BRA>,
sk_lwe: &S0,
sk_glwe: &S1,
source_xa: &mut Source,
source_xe: &mut Source,
scratch: &mut Scratch<BE>,
) -> Self
where
A: BDDKeyInfos,
DLwe: DataRef,
DGlwe: DataRef,
BlindRotationKey<Vec<u8>, BRA>: BlindRotationKeyAlloc + BlindRotationKeyEncryptSk<BE>,
Module<BE>: SvpApplyDftToDft<BE>
+ VecZnxIdftApplyTmpA<BE>
+ VecZnxAddScalarInplace
+ VecZnxDftBytesOf
+ VecZnxBigNormalize<BE>
+ VecZnxDftApply<BE>
+ SvpApplyDftToDftInplace<BE>
+ VecZnxIdftApplyConsume<BE>
+ VecZnxNormalizeTmpBytes
+ VecZnxFillUniform
+ VecZnxSubInplace
+ VecZnxAddInplace
+ VecZnxNormalizeInplace<BE>
+ VecZnxAddNormal
+ VecZnxNormalize<BE>
+ VecZnxSub
+ SvpPrepare<BE>
+ VecZnxSwitchRing
+ SvpPPolBytesOf
+ SvpPPolAlloc<BE>
+ VecZnxAutomorphism
+ VecZnxAutomorphismInplace<BE>,
) where
D: DataMut,
S0: LWESecretToRef + GetDistribution + LWEInfos,
S1: GLWESecretToRef + GetDistribution + GLWEInfos;
}
impl<BE: Backend, BRA: BlindRotationAlgo> BDDKeyEncryptSk<BRA, BE> for Module<BE>
where
Self: CircuitBootstrappingKeyEncryptSk<BRA, BE> + GLWEToLWESwitchingKeyEncryptSk<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
fn bdd_key_encrypt_sk<D, S0, S1>(
&self,
res: &mut BDDKey<D, BRA>,
sk_lwe: &S0,
sk_glwe: &S1,
source_xa: &mut Source,
source_xe: &mut Source,
scratch: &mut Scratch<BE>,
) where
D: DataMut,
S0: LWESecretToRef + GetDistribution + LWEInfos,
S1: GLWESecretToRef + GetDistribution + GLWEInfos,
{
res.ks
.encrypt_sk(self, sk_lwe, sk_glwe, source_xa, source_xe, scratch);
res.cbt
.encrypt_sk(self, sk_lwe, sk_glwe, source_xa, source_xe, scratch);
}
}
impl<D: DataMut, BRA: BlindRotationAlgo> BDDKey<D, BRA> {
pub fn encrypt_sk<S0, S1, M, BE: Backend>(
&mut self,
module: &M,
sk_lwe: &S0,
sk_glwe: &S1,
source_xa: &mut Source,
source_xe: &mut Source,
scratch: &mut Scratch<BE>,
) where
S0: LWESecretToRef + GetDistribution + LWEInfos,
S1: GLWESecretToRef + GetDistribution + GLWEInfos,
M: BDDKeyEncryptSk<BRA, BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
let mut ks: GLWEToLWESwitchingKey<Vec<u8>> = GLWEToLWESwitchingKey::alloc(&infos.ks_infos());
ks.encrypt_sk(module, sk_lwe, sk_glwe, source_xa, source_xe, scratch);
Self {
cbt: CircuitBootstrappingKey::encrypt_sk(
module,
sk_lwe,
sk_glwe,
&infos.cbt_infos(),
source_xa,
source_xe,
scratch,
),
ks,
}
module.bdd_key_encrypt_sk(self, sk_lwe, sk_glwe, source_xa, source_xe, scratch);
}
}
pub struct BDDKeyPrepared<CBT, LWE, BRA, BE>
pub struct BDDKeyPrepared<D, BRA, BE>
where
CBT: Data,
LWE: Data,
D: Data,
BRA: BlindRotationAlgo,
BE: Backend,
{
cbt: CircuitBootstrappingKeyPrepared<CBT, BRA, BE>,
ks: GLWEToLWESwitchingKeyPrepared<LWE, BE>,
pub(crate) cbt: CircuitBootstrappingKeyPrepared<D, BRA, BE>,
pub(crate) ks: GLWEToLWESwitchingKeyPrepared<D, BE>,
}
impl<CBT: DataMut, LWE: DataMut, BRA: BlindRotationAlgo, BE: Backend> PrepareAlloc<BE, BDDKeyPrepared<CBT, LWE, BRA, BE>>
for BDDKey<CBT, LWE, BRA>
pub trait BDDKeyPreparedFactory<BRA: BlindRotationAlgo, BE: Backend>
where
CircuitBootstrappingKey<CBT, BRA>: PrepareAlloc<BE, CircuitBootstrappingKeyPrepared<CBT, BRA, BE>>,
GLWEToLWESwitchingKey<LWE>: PrepareAlloc<BE, GLWEToLWESwitchingKeyPrepared<LWE, BE>>,
Self: Sized + CircuitBootstrappingKeyPreparedFactory<BRA, BE> + GLWEToLWESwitchingKeyPreparedFactory<BE>,
{
fn prepare_alloc(&self, module: &Module<BE>, scratch: &mut Scratch<BE>) -> BDDKeyPrepared<CBT, LWE, BRA, BE> {
fn alloc_bdd_key_from_infos<A>(&self, infos: &A) -> BDDKeyPrepared<Vec<u8>, BRA, BE>
where
A: BDDKeyInfos,
{
BDDKeyPrepared {
cbt: self.cbt.prepare_alloc(module, scratch),
ks: self.ks.prepare_alloc(module, scratch),
cbt: CircuitBootstrappingKeyPrepared::alloc_from_infos(self, &infos.cbt_infos()),
ks: GLWEToLWESwitchingKeyPrepared::alloc_from_infos(self, &infos.ks_infos()),
}
}
fn prepare_bdd_key_tmp_bytes<A>(&self, infos: &A) -> usize
where
A: BDDKeyInfos,
{
self.circuit_bootstrapping_key_prepare_tmp_bytes(&infos.cbt_infos())
.max(self.prepare_glwe_to_lwe_switching_key_tmp_bytes(&infos.ks_infos()))
}
fn prepare_bdd_key<DM, DR>(&self, res: &mut BDDKeyPrepared<DM, BRA, BE>, other: &BDDKey<DR, BRA>, scratch: &mut Scratch<BE>)
where
DM: DataMut,
DR: DataRef,
Scratch<BE>: ScratchTakeCore<BE>,
{
res.cbt.prepare(self, &other.cbt, scratch);
res.ks.prepare(self, &other.ks, scratch);
}
}
impl<BRA: BlindRotationAlgo, BE: Backend> BDDKeyPreparedFactory<BRA, BE> for Module<BE> where
Self: Sized + CircuitBootstrappingKeyPreparedFactory<BRA, BE> + GLWEToLWESwitchingKeyPreparedFactory<BE>
{
}
impl<BRA: BlindRotationAlgo, BE: Backend> BDDKeyPrepared<Vec<u8>, BRA, BE> {
pub fn alloc_from_infos<M, A>(module: &M, infos: &A) -> Self
where
M: BDDKeyPreparedFactory<BRA, BE>,
A: BDDKeyInfos,
{
module.alloc_bdd_key_from_infos(infos)
}
}
impl<D: DataMut, BRA: BlindRotationAlgo, BE: Backend> BDDKeyPrepared<D, BRA, BE> {
pub fn prepare<DR, M>(&mut self, module: &M, other: &BDDKey<DR, BRA>, scratch: &mut Scratch<BE>)
where
DR: DataRef,
M: BDDKeyPreparedFactory<BRA, BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
module.prepare_bdd_key(self, other, scratch);
}
}
pub trait FheUintBlocksPrepare<BRA: BlindRotationAlgo, T: UnsignedInteger, BE: Backend> {
fn fhe_uint_blocks_prepare_tmp_bytes<R, A>(
&self,
block_size: usize,
extension_factor: usize,
res_infos: &R,
infos: &A,
) -> usize
where
R: GGSWInfos,
A: BDDKeyInfos;
fn fhe_uint_blocks_prepare<DM, DR0, DR1>(
&self,
res: &mut FheUintBlocksPrepared<DM, T, BE>,
bits: &FheUintBlocks<DR0, T>,
key: &BDDKeyPrepared<DR1, BRA, BE>,
scratch: &mut Scratch<BE>,
) where
DM: DataMut,
DR0: DataRef,
DR1: DataRef;
}
impl<BRA: BlindRotationAlgo, BE: Backend, T: UnsignedInteger> FheUintBlocksPrepare<BRA, T, BE> for Module<BE>
where
Self: LWEFromGLWE<BE> + CirtuitBootstrappingExecute<BRA, BE> + GGSWPreparedFactory<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
fn fhe_uint_blocks_prepare_tmp_bytes<R, A>(
&self,
block_size: usize,
extension_factor: usize,
res_infos: &R,
bdd_infos: &A,
) -> usize
where
R: GGSWInfos,
A: BDDKeyInfos,
{
self.circuit_bootstrapping_execute_tmp_bytes(
block_size,
extension_factor,
res_infos,
&bdd_infos.cbt_infos(),
)
}
fn fhe_uint_blocks_prepare<DM, DR0, DR1>(
&self,
res: &mut FheUintBlocksPrepared<DM, T, BE>,
bits: &FheUintBlocks<DR0, T>,
key: &BDDKeyPrepared<DR1, BRA, BE>,
scratch: &mut Scratch<BE>,
) where
DM: DataMut,
DR0: DataRef,
DR1: DataRef,
{
assert_eq!(res.blocks.len(), bits.blocks.len());
let mut lwe: LWE<Vec<u8>> = LWE::alloc_from_infos(&bits.blocks[0]); //TODO: add TakeLWE
let (mut tmp_ggsw, scratch_1) = scratch.take_ggsw(res);
for (dst, src) in res.blocks.iter_mut().zip(bits.blocks.iter()) {
lwe.from_glwe(self, src, &key.ks, scratch_1);
key.cbt
.execute_to_constant(self, &mut tmp_ggsw, &lwe, 1, 1, scratch_1);
dst.prepare(self, &tmp_ggsw, scratch_1);
}
}
}
pub trait FheUintPrepare<BE: Backend, OUT, IN> {
fn prepare(&self, module: &Module<BE>, out: &mut OUT, bits: &IN, scratch: &mut Scratch<BE>);
}
impl<CBT, OUT, IN, LWE, BRA, BE, T> FheUintPrepare<BE, FheUintBlocksPrep<OUT, BE, T>, FheUintBlocks<IN, T>>
for BDDKeyPrepared<CBT, LWE, BRA, BE>
where
T: UnsignedInteger,
CBT: DataRef,
OUT: DataMut,
IN: DataRef,
LWE: DataRef,
BRA: BlindRotationAlgo,
BE: Backend,
Module<BE>: VmpPrepare<BE>
+ VecZnxRotate
+ VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<BE>
+ VmpApplyDftToDftAdd<BE>
+ VecZnxDftApply<BE>
+ VecZnxIdftApplyConsume<BE>
+ VecZnxBigAddSmallInplace<BE>
+ VecZnxBigNormalize<BE>
+ VecZnxNormalize<BE>
+ VecZnxNormalizeTmpBytes,
Scratch<BE>: ScratchAvailable + TakeVecZnxDft<BE> + TakeGLWE + TakeVecZnx + TakeGGSW,
CircuitBootstrappingKeyPrepared<CBT, BRA, BE>: CirtuitBootstrappingExecute<BE>,
{
fn prepare(
&self,
module: &Module<BE>,
out: &mut FheUintBlocksPrep<OUT, BE, T>,
bits: &FheUintBlocks<IN, T>,
impl<D: DataMut, T: UnsignedInteger, BE: Backend> FheUintBlocksPrepared<D, T, BE> {
pub fn prepare<BRA, M, O, K>(
&mut self,
module: &M,
other: &FheUintBlocks<O, T>,
key: &BDDKeyPrepared<K, BRA, BE>,
scratch: &mut Scratch<BE>,
) {
#[cfg(debug_assertions)]
{
assert_eq!(out.blocks.len(), bits.blocks.len());
}
let mut lwe: LWE<Vec<u8>> = LWE::alloc(&bits.blocks[0]); //TODO: add TakeLWE
let (mut tmp_ggsw, scratch_1) = scratch.take_ggsw(out);
for (dst, src) in out.blocks.iter_mut().zip(bits.blocks.iter()) {
lwe.from_glwe(module, src, &self.ks, scratch_1);
self.cbt
.execute_to_constant(module, &mut tmp_ggsw, &lwe, 1, 1, scratch_1);
dst.prepare(module, &tmp_ggsw, scratch_1);
}
) where
BRA: BlindRotationAlgo,
O: DataRef,
K: DataRef,
M: FheUintBlocksPrepare<BRA, T, BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
module.fhe_uint_blocks_prepare(self, other, key, scratch);
}
}
#[cfg(test)]
impl<CBT, OUT, IN, LWE, BRA, BE, T> FheUintPrepare<BE, FheUintBlocksPrepDebug<OUT, T>, FheUintBlocks<IN, T>>
for BDDKeyPrepared<CBT, LWE, BRA, BE>
where
T: UnsignedInteger,
CBT: DataRef,
OUT: DataMut,
IN: DataRef,
LWE: DataRef,
BRA: BlindRotationAlgo,
BE: Backend,
Module<BE>: VmpPrepare<BE>
+ VecZnxRotate
+ VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<BE>
+ VmpApplyDftToDftAdd<BE>
+ VecZnxDftApply<BE>
+ VecZnxIdftApplyConsume<BE>
+ VecZnxBigAddSmallInplace<BE>
+ VecZnxBigNormalize<BE>
+ VecZnxNormalize<BE>
+ VecZnxNormalizeTmpBytes,
Scratch<BE>: ScratchTakeCore<BE>,
CircuitBootstrappingKeyPrepared<CBT, BRA, BE>: CirtuitBootstrappingExecute<BE>,
{
fn prepare(
pub(crate) trait FheUintBlockDebugPrepare<BRA: BlindRotationAlgo, T: UnsignedInteger, BE: Backend> {
fn fhe_uint_block_debug_prepare<DM, DR0, DR1>(
&self,
module: &Module<BE>,
out: &mut FheUintBlocksPrepDebug<OUT, T>,
bits: &FheUintBlocks<IN, T>,
res: &mut FheUintBlocksPreparedDebug<DM, T>,
bits: &FheUintBlocks<DR0, T>,
key: &BDDKeyPrepared<DR1, BRA, BE>,
scratch: &mut Scratch<BE>,
) {
#[cfg(debug_assertions)]
{
assert_eq!(out.blocks.len(), bits.blocks.len());
}
let mut lwe: LWE<Vec<u8>> = LWE::alloc(&bits.blocks[0]); //TODO: add TakeLWE
for (dst, src) in out.blocks.iter_mut().zip(bits.blocks.iter()) {
lwe.from_glwe(module, src, &self.ks, scratch);
self.cbt
.execute_to_constant(module, dst, &lwe, 1, 1, scratch);
}
}
) where
DM: DataMut,
DR0: DataRef,
DR1: DataRef;
}