Add BDD Arithmetic (#98)

* Added some circuit, evaluation + some layouts

* Refactor + memory reduction

* Rows -> Dnum, Digits -> Dsize

* fix #96 + glwe_packing (indirectly CBT)

* clippy
This commit is contained in:
Jean-Philippe Bossuat
2025-10-08 17:52:03 +02:00
committed by GitHub
parent 37e13b965c
commit 6357a05509
119 changed files with 15996 additions and 1659 deletions

View File

@@ -0,0 +1,220 @@
use itertools::Itertools;
use poulpy_core::layouts::prepared::GGSWCiphertextPreparedToRef;
use poulpy_hal::layouts::{Backend, DataMut, DataRef, Module, Scratch};
use crate::tfhe::bdd_arithmetic::{
BitCircuitInfo, Circuit, CircuitExecute, FheUintBlocks, FheUintBlocksPrep, UnsignedInteger, circuits,
};
/// Operations Z x Z -> Z
pub(crate) struct Circuits2WTo1W<C: BitCircuitInfo + 'static, const WORD_SIZE: usize>(pub &'static Circuit<C, WORD_SIZE>);
pub trait EvalBDD2WTo1W<BE: Backend, T: UnsignedInteger> {
fn eval_bdd_2w_to_1w<R, A, B>(
&self,
module: &Module<BE>,
out: &mut FheUintBlocks<R, T>,
a: &FheUintBlocksPrep<A, BE, T>,
b: &FheUintBlocksPrep<B, BE, T>,
scratch: &mut Scratch<BE>,
) where
R: DataMut,
A: DataRef,
B: DataRef;
}
impl<C: BitCircuitInfo + 'static, const WORD_SIZE: usize, BE: Backend, T: UnsignedInteger> EvalBDD2WTo1W<BE, T>
for Circuits2WTo1W<C, WORD_SIZE>
where
Circuit<C, WORD_SIZE>: CircuitExecute<BE, T>,
{
fn eval_bdd_2w_to_1w<R, A, B>(
&self,
module: &Module<BE>,
out: &mut FheUintBlocks<R, T>,
a: &FheUintBlocksPrep<A, BE, T>,
b: &FheUintBlocksPrep<B, BE, T>,
scratch: &mut Scratch<BE>,
) where
R: DataMut,
A: DataRef,
B: DataRef,
{
eval_bdd_2w_to_1w(module, self.0, out, a, b, scratch);
}
}
pub fn eval_bdd_2w_to_1w<R: DataMut, A: DataRef, B: DataRef, T: UnsignedInteger, C: CircuitExecute<BE, T>, BE: Backend>(
module: &Module<BE>,
circuit: &C,
out: &mut FheUintBlocks<R, T>,
a: &FheUintBlocksPrep<A, BE, T>,
b: &FheUintBlocksPrep<B, BE, T>,
scratch: &mut Scratch<BE>,
) {
#[cfg(debug_assertions)]
{
assert_eq!(out.blocks.len(), T::WORD_SIZE);
assert_eq!(b.blocks.len(), T::WORD_SIZE);
assert_eq!(b.blocks.len(), T::WORD_SIZE);
}
// Collects inputs into a single array
let inputs: Vec<&dyn GGSWCiphertextPreparedToRef<BE>> = a
.blocks
.iter()
.map(|x| x as &dyn GGSWCiphertextPreparedToRef<BE>)
.chain(
b.blocks
.iter()
.map(|x| x as &dyn GGSWCiphertextPreparedToRef<BE>),
)
.collect_vec();
// Evaluates out[i] = circuit[i](a, b)
circuit.execute(module, &mut out.blocks, &inputs, scratch);
}
#[macro_export]
macro_rules! define_bdd_2w_to_1w_trait {
($(#[$meta:meta])* $vis:vis $trait_name:ident, $method_name:ident) => {
$(#[$meta])*
$vis trait $trait_name<T: UnsignedInteger, BE: Backend> {
fn $method_name<A, B>(
&mut self,
module: &Module<BE>,
a: &FheUintBlocksPrep<A, BE, T>,
b: &FheUintBlocksPrep<B, BE, T>,
scratch: &mut Scratch<BE>,
) where
A: DataRef,
B: DataRef;
}
};
}
#[macro_export]
macro_rules! impl_bdd_2w_to_1w_trait {
($trait_name:ident, $method_name:ident, $ty:ty, $n:literal, $circuit_ty:ty, $output_circuits:path) => {
impl<D: DataMut, BE: Backend> $trait_name<$ty, BE> for FheUintBlocks<D, $ty>
where
Circuits2WTo1W<$circuit_ty, $n>: EvalBDD2WTo1W<BE, $ty>,
{
fn $method_name<A, B>(
&mut self,
module: &Module<BE>,
a: &FheUintBlocksPrep<A, BE, $ty>,
b: &FheUintBlocksPrep<B, BE, $ty>,
scratch: &mut Scratch<BE>,
) where
A: DataRef,
B: DataRef,
{
const OP: Circuits2WTo1W<$circuit_ty, $n> = Circuits2WTo1W::<$circuit_ty, $n>(&$output_circuits);
OP.eval_bdd_2w_to_1w(module, self, a, b, scratch);
}
}
};
}
define_bdd_2w_to_1w_trait!(pub Add, add);
define_bdd_2w_to_1w_trait!(pub Sub, sub);
define_bdd_2w_to_1w_trait!(pub Sll, sll);
define_bdd_2w_to_1w_trait!(pub Sra, sra);
define_bdd_2w_to_1w_trait!(pub Srl, srl);
define_bdd_2w_to_1w_trait!(pub Slt, slt);
define_bdd_2w_to_1w_trait!(pub Sltu, sltu);
define_bdd_2w_to_1w_trait!(pub Or, or);
define_bdd_2w_to_1w_trait!(pub And, and);
define_bdd_2w_to_1w_trait!(pub Xor, xor);
impl_bdd_2w_to_1w_trait!(
Add,
add,
u32,
32,
circuits::u32::add_codegen::AnyBitCircuit,
circuits::u32::add_codegen::OUTPUT_CIRCUITS
);
impl_bdd_2w_to_1w_trait!(
Sub,
sub,
u32,
32,
circuits::u32::sub_codegen::AnyBitCircuit,
circuits::u32::sub_codegen::OUTPUT_CIRCUITS
);
impl_bdd_2w_to_1w_trait!(
Sll,
sll,
u32,
32,
circuits::u32::sll_codegen::AnyBitCircuit,
circuits::u32::sll_codegen::OUTPUT_CIRCUITS
);
impl_bdd_2w_to_1w_trait!(
Sra,
sra,
u32,
32,
circuits::u32::sra_codegen::AnyBitCircuit,
circuits::u32::sra_codegen::OUTPUT_CIRCUITS
);
impl_bdd_2w_to_1w_trait!(
Srl,
srl,
u32,
32,
circuits::u32::srl_codegen::AnyBitCircuit,
circuits::u32::srl_codegen::OUTPUT_CIRCUITS
);
impl_bdd_2w_to_1w_trait!(
Slt,
slt,
u32,
1,
circuits::u32::slt_codegen::AnyBitCircuit,
circuits::u32::slt_codegen::OUTPUT_CIRCUITS
);
impl_bdd_2w_to_1w_trait!(
Sltu,
sltu,
u32,
1,
circuits::u32::sltu_codegen::AnyBitCircuit,
circuits::u32::sltu_codegen::OUTPUT_CIRCUITS
);
impl_bdd_2w_to_1w_trait!(
And,
and,
u32,
32,
circuits::u32::and_codegen::AnyBitCircuit,
circuits::u32::and_codegen::OUTPUT_CIRCUITS
);
impl_bdd_2w_to_1w_trait!(
Or,
or,
u32,
1,
circuits::u32::or_codegen::AnyBitCircuit,
circuits::u32::or_codegen::OUTPUT_CIRCUITS
);
impl_bdd_2w_to_1w_trait!(
Xor,
xor,
u32,
1,
circuits::u32::xor_codegen::AnyBitCircuit,
circuits::u32::xor_codegen::OUTPUT_CIRCUITS
);

View File

@@ -0,0 +1,215 @@
use std::marker::PhantomData;
use poulpy_core::layouts::{Base2K, GLWECiphertext, GLWEInfos, GLWEPlaintextLayout, LWEInfos, Rank, TorusPrecision};
use poulpy_core::{TakeGLWEPt, layouts::prepared::GLWESecretPrepared};
use poulpy_hal::api::VecZnxBigAllocBytes;
#[cfg(test)]
use poulpy_hal::api::{
ScratchAvailable, TakeVecZnx, VecZnxAddInplace, VecZnxAddNormal, VecZnxFillUniform, VecZnxNormalize, VecZnxSub,
};
#[cfg(test)]
use poulpy_hal::source::Source;
use poulpy_hal::{
api::{
TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxDftAllocBytes,
VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalizeTmpBytes,
},
layouts::{Backend, Data, DataMut, DataRef, Module, Scratch},
};
use poulpy_hal::api::{SvpApplyDftToDftInplace, VecZnxNormalizeInplace, VecZnxSubInplace};
use crate::tfhe::bdd_arithmetic::{FromBits, ToBits, UnsignedInteger};
/// An FHE ciphertext encrypting the bits of an [UnsignedInteger].
pub struct FheUintBlocks<D: Data, T: UnsignedInteger> {
pub(crate) blocks: Vec<GLWECiphertext<D>>,
pub(crate) _base: u8,
pub(crate) _phantom: PhantomData<T>,
}
impl<D: DataRef, T: UnsignedInteger> LWEInfos for FheUintBlocks<D, T> {
fn base2k(&self) -> poulpy_core::layouts::Base2K {
self.blocks[0].base2k()
}
fn k(&self) -> poulpy_core::layouts::TorusPrecision {
self.blocks[0].k()
}
fn n(&self) -> poulpy_core::layouts::Degree {
self.blocks[0].n()
}
}
impl<D: DataRef, T: UnsignedInteger> GLWEInfos for FheUintBlocks<D, T> {
fn rank(&self) -> poulpy_core::layouts::Rank {
self.blocks[0].rank()
}
}
impl<T: UnsignedInteger> FheUintBlocks<Vec<u8>, T> {
#[allow(dead_code)]
pub(crate) fn alloc<A, BE: Backend>(module: &Module<BE>, infos: &A) -> Self
where
A: GLWEInfos,
{
Self::alloc_with(module, infos.base2k(), infos.k(), infos.rank())
}
#[allow(dead_code)]
pub(crate) fn alloc_with<BE: Backend>(module: &Module<BE>, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self {
Self {
blocks: (0..T::WORD_SIZE)
.map(|_| GLWECiphertext::alloc_with(module.n().into(), base2k, k, rank))
.collect(),
_base: 1,
_phantom: PhantomData,
}
}
}
impl<D: DataMut, T: UnsignedInteger + ToBits> FheUintBlocks<D, T> {
#[allow(dead_code)]
#[cfg(test)]
pub(crate) fn encrypt_sk<S, BE: Backend>(
&mut self,
module: &Module<BE>,
value: T,
sk: &GLWESecretPrepared<S, BE>,
source_xa: &mut Source,
source_xe: &mut Source,
scratch: &mut Scratch<BE>,
) where
S: DataRef,
Module<BE>: VecZnxDftAllocBytes
+ VecZnxBigNormalize<BE>
+ VecZnxDftApply<BE>
+ SvpApplyDftToDftInplace<BE>
+ VecZnxIdftApplyConsume<BE>
+ VecZnxNormalizeTmpBytes
+ VecZnxFillUniform
+ VecZnxSubInplace
+ VecZnxAddInplace
+ VecZnxNormalizeInplace<BE>
+ VecZnxAddNormal
+ VecZnxNormalize<BE>
+ VecZnxSub,
Scratch<BE>: TakeVecZnxDft<BE> + ScratchAvailable + TakeVecZnx + TakeGLWEPt<BE>,
{
use poulpy_core::layouts::GLWEPlaintextLayout;
#[cfg(debug_assertions)]
{
assert!(module.n().is_multiple_of(T::WORD_SIZE));
assert_eq!(self.n(), module.n() as u32);
assert_eq!(sk.n(), module.n() as u32);
}
let pt_infos = GLWEPlaintextLayout {
n: self.n(),
base2k: self.base2k(),
k: 1_usize.into(),
};
let (mut pt, scratch_1) = scratch.take_glwe_pt(&pt_infos);
for i in 0..T::WORD_SIZE {
pt.encode_coeff_i64(value.bit(i) as i64, TorusPrecision(1), 0);
self.blocks[i].encrypt_sk(&module, &pt, sk, source_xa, source_xe, scratch_1);
}
}
}
impl<D: DataRef, T: UnsignedInteger + FromBits + ToBits> FheUintBlocks<D, T> {
pub fn decrypt<S: DataRef, BE: Backend>(
&self,
module: &Module<BE>,
sk: &GLWESecretPrepared<S, BE>,
scratch: &mut Scratch<BE>,
) -> T
where
Module<BE>: VecZnxDftApply<BE>
+ SvpApplyDftToDftInplace<BE>
+ VecZnxIdftApplyConsume<BE>
+ VecZnxBigAddInplace<BE>
+ VecZnxBigAddSmallInplace<BE>
+ VecZnxBigNormalize<BE>,
Scratch<BE>: TakeVecZnxDft<BE> + TakeVecZnxBig<BE> + TakeGLWEPt<BE>,
{
#[cfg(debug_assertions)]
{
assert!(module.n().is_multiple_of(T::WORD_SIZE));
assert_eq!(self.n(), module.n() as u32);
assert_eq!(sk.n(), module.n() as u32);
}
let pt_infos = GLWEPlaintextLayout {
n: self.n(),
base2k: self.base2k(),
k: self.k(),
};
let (mut pt, scratch_1) = scratch.take_glwe_pt(&pt_infos);
let mut bits: Vec<u8> = vec![0u8; T::WORD_SIZE];
let base2k: usize = self.base2k().into();
let scale: f64 = 4.0 / ((1 << base2k) as f64);
for (i, bit) in bits.iter_mut().enumerate().take(T::WORD_SIZE) {
self.blocks[i].decrypt(module, &mut pt, sk, scratch_1);
let value: i64 = pt.decode_coeff_i64(base2k.into(), 0);
*bit = ((value as f64) * scale).round() as u8;
}
T::from_bits(&bits)
}
pub fn noise<S: DataRef, BE: Backend>(
&self,
module: &Module<BE>,
sk: &GLWESecretPrepared<S, BE>,
want: T,
scratch: &mut Scratch<BE>,
) -> Vec<f64>
where
Module<BE>: VecZnxDftAllocBytes
+ VecZnxBigAllocBytes
+ VecZnxDftApply<BE>
+ SvpApplyDftToDftInplace<BE>
+ VecZnxIdftApplyConsume<BE>
+ VecZnxBigAddInplace<BE>
+ VecZnxBigAddSmallInplace<BE>
+ VecZnxBigNormalize<BE>
+ VecZnxNormalizeTmpBytes
+ VecZnxSubInplace
+ VecZnxNormalizeInplace<BE>,
Scratch<BE>: TakeGLWEPt<BE> + TakeVecZnxDft<BE> + TakeVecZnxBig<BE>,
{
#[cfg(debug_assertions)]
{
assert!(module.n().is_multiple_of(T::WORD_SIZE));
assert_eq!(self.n(), module.n() as u32);
assert_eq!(sk.n(), module.n() as u32);
}
let pt_infos = GLWEPlaintextLayout {
n: self.n(),
base2k: self.base2k(),
k: 1_usize.into(),
};
let (mut pt_want, scratch_1) = scratch.take_glwe_pt(&pt_infos);
let mut noise: Vec<f64> = vec![0f64; T::WORD_SIZE];
for (i, noise_i) in noise.iter_mut().enumerate().take(T::WORD_SIZE) {
pt_want.encode_coeff_i64(want.bit(i) as i64, TorusPrecision(2), 0);
*noise_i = self.blocks[i].noise(module, sk, &pt_want, scratch_1);
}
noise
}
}

View File

@@ -0,0 +1,282 @@
use std::marker::PhantomData;
use poulpy_core::layouts::{
Base2K, Dnum, Dsize, GGSWInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision, prepared::GGSWCiphertextPrepared,
};
#[cfg(test)]
use poulpy_core::{
TakeGGSW,
layouts::{GGSWCiphertext, prepared::GLWESecretPrepared},
};
use poulpy_hal::{
api::VmpPMatAlloc,
layouts::{Backend, Data, DataMut, DataRef, Module, Scratch},
};
#[cfg(test)]
use poulpy_hal::{
api::{
ScratchAvailable, SvpApplyDftToDftInplace, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal,
VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes,
VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform,
VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub,
VecZnxSubInplace, VmpPrepare,
},
oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl},
source::Source,
};
use crate::tfhe::bdd_arithmetic::{FheUintBlocks, FheUintPrepare, ToBits, UnsignedInteger};
#[cfg(test)]
pub(crate) struct FheUintBlocksPrepDebug<D: Data, T: UnsignedInteger> {
pub(crate) blocks: Vec<GGSWCiphertext<D>>,
pub(crate) _base: u8,
pub(crate) _phantom: PhantomData<T>,
}
#[cfg(test)]
impl<T: UnsignedInteger> FheUintBlocksPrepDebug<Vec<u8>, T> {
#[allow(dead_code)]
pub(crate) fn alloc<A, BE: Backend>(module: &Module<BE>, infos: &A) -> Self
where
A: GGSWInfos,
{
Self::alloc_with(
module,
infos.base2k(),
infos.k(),
infos.dnum(),
infos.dsize(),
infos.rank(),
)
}
#[allow(dead_code)]
pub(crate) fn alloc_with<BE: Backend>(
module: &Module<BE>,
base2k: Base2K,
k: TorusPrecision,
dnum: Dnum,
dsize: Dsize,
rank: Rank,
) -> Self {
Self {
blocks: (0..T::WORD_SIZE)
.map(|_| GGSWCiphertext::alloc_with(module.n().into(), base2k, k, rank, dnum, dsize))
.collect(),
_base: 1,
_phantom: PhantomData,
}
}
}
/// A prepared FHE ciphertext encrypting the bits of an [UnsignedInteger].
pub struct FheUintBlocksPrep<D: Data, B: Backend, T: UnsignedInteger> {
pub(crate) blocks: Vec<GGSWCiphertextPrepared<D, B>>,
pub(crate) _base: u8,
pub(crate) _phantom: PhantomData<T>,
}
impl<T: UnsignedInteger, BE: Backend> FheUintBlocksPrep<Vec<u8>, BE, T>
where
Module<BE>: VmpPMatAlloc<BE>,
{
#[allow(dead_code)]
pub(crate) fn alloc<A>(module: &Module<BE>, infos: &A) -> Self
where
A: GGSWInfos,
{
Self::alloc_with(
module,
infos.base2k(),
infos.k(),
infos.dnum(),
infos.dsize(),
infos.rank(),
)
}
#[allow(dead_code)]
pub(crate) fn alloc_with(module: &Module<BE>, base2k: Base2K, k: TorusPrecision, dnum: Dnum, dsize: Dsize, rank: Rank) -> Self
where
Module<BE>: VmpPMatAlloc<BE>,
{
Self {
blocks: (0..T::WORD_SIZE)
.map(|_| GGSWCiphertextPrepared::alloc_with(module, base2k, k, dnum, dsize, rank))
.collect(),
_base: 1,
_phantom: PhantomData,
}
}
}
impl<D: DataMut, T: UnsignedInteger + ToBits, BE: Backend> FheUintBlocksPrep<D, BE, T> {
#[allow(dead_code)]
#[cfg(test)]
pub(crate) fn encrypt_sk<S>(
&mut self,
module: &Module<BE>,
value: T,
sk: &GLWESecretPrepared<S, BE>,
source_xa: &mut Source,
source_xe: &mut Source,
scratch: &mut Scratch<BE>,
) where
S: DataRef,
Module<BE>: VecZnxAddScalarInplace
+ VecZnxDftAllocBytes
+ VecZnxBigNormalize<BE>
+ VecZnxDftApply<BE>
+ SvpApplyDftToDftInplace<BE>
+ VecZnxIdftApplyConsume<BE>
+ VecZnxNormalizeTmpBytes
+ VecZnxFillUniform
+ VecZnxSubInplace
+ VecZnxAddInplace
+ VecZnxNormalizeInplace<BE>
+ VecZnxAddNormal
+ VecZnxNormalize<BE>
+ VecZnxSub
+ VmpPrepare<BE>,
Scratch<BE>: TakeVecZnxDft<BE> + ScratchAvailable + TakeVecZnx + TakeGGSW + TakeScalarZnx,
{
#[cfg(debug_assertions)]
{
assert!(module.n().is_multiple_of(T::WORD_SIZE));
assert_eq!(self.n(), module.n() as u32);
assert_eq!(sk.n(), module.n() as u32);
}
let (mut tmp_ggsw, scratch_1) = scratch.take_ggsw(self);
let (mut pt, scratch_2) = scratch_1.take_scalar_znx(module.n(), 1);
for i in 0..T::WORD_SIZE {
use poulpy_core::layouts::prepared::Prepare;
use poulpy_hal::layouts::ZnxViewMut;
pt.at_mut(0, 0)[0] = value.bit(i) as i64;
tmp_ggsw.encrypt_sk(&module, &pt, sk, source_xa, source_xe, scratch_2);
self.blocks[i].prepare(module, &tmp_ggsw, scratch_2);
}
}
/// Prepares [FheUintBits] to [FheUintBitsPrep].
pub fn prepare<BIT, KEY>(&mut self, module: &Module<BE>, bits: &FheUintBlocks<BIT, T>, key: &KEY, scratch: &mut Scratch<BE>)
where
BIT: DataRef,
KEY: FheUintPrepare<BE, FheUintBlocksPrep<D, BE, T>, FheUintBlocks<BIT, T>>,
{
key.prepare(module, self, bits, scratch);
}
}
#[cfg(test)]
impl<D: DataMut, T: UnsignedInteger + ToBits> FheUintBlocksPrepDebug<D, T> {
pub(crate) fn prepare<BIT, KEY, BE: Backend>(
&mut self,
module: &Module<BE>,
bits: &FheUintBlocks<BIT, T>,
key: &KEY,
scratch: &mut Scratch<BE>,
) where
BIT: DataRef,
KEY: FheUintPrepare<BE, FheUintBlocksPrepDebug<D, T>, FheUintBlocks<BIT, T>>,
{
key.prepare(module, self, bits, scratch);
}
}
#[cfg(test)]
impl<D: DataRef, T: UnsignedInteger + ToBits> FheUintBlocksPrepDebug<D, T> {
#[allow(dead_code)]
pub(crate) fn noise<S: DataRef, BE: Backend>(&self, module: &Module<BE>, sk: &GLWESecretPrepared<S, BE>, want: T)
where
Module<BE>: VecZnxDftAllocBytes
+ VecZnxBigAllocBytes
+ VecZnxDftApply<BE>
+ SvpApplyDftToDftInplace<BE>
+ VecZnxIdftApplyConsume<BE>
+ VecZnxBigAddInplace<BE>
+ VecZnxBigAddSmallInplace<BE>
+ VecZnxBigNormalize<BE>
+ VecZnxNormalizeTmpBytes
+ VecZnxBigAlloc<BE>
+ VecZnxDftAlloc<BE>
+ VecZnxBigNormalizeTmpBytes
+ VecZnxIdftApplyTmpA<BE>
+ VecZnxAddScalarInplace
+ VecZnxSubInplace,
BE: Backend + TakeVecZnxDftImpl<BE> + TakeVecZnxBigImpl<BE> + ScratchOwnedAllocImpl<BE> + ScratchOwnedBorrowImpl<BE>,
{
for (i, ggsw) in self.blocks.iter().enumerate() {
use poulpy_hal::layouts::{ScalarZnx, ZnxViewMut};
let mut pt_want = ScalarZnx::alloc(self.n().into(), 1);
pt_want.at_mut(0, 0)[0] = want.bit(i) as i64;
ggsw.print_noise(module, sk, &pt_want);
}
}
}
impl<D: DataRef, T: UnsignedInteger, B: Backend> LWEInfos for FheUintBlocksPrep<D, B, T> {
fn base2k(&self) -> poulpy_core::layouts::Base2K {
self.blocks[0].base2k()
}
fn k(&self) -> poulpy_core::layouts::TorusPrecision {
self.blocks[0].k()
}
fn n(&self) -> poulpy_core::layouts::Degree {
self.blocks[0].n()
}
}
impl<D: DataRef, T: UnsignedInteger, B: Backend> GLWEInfos for FheUintBlocksPrep<D, B, T> {
fn rank(&self) -> poulpy_core::layouts::Rank {
self.blocks[0].rank()
}
}
impl<D: DataRef, T: UnsignedInteger, B: Backend> GGSWInfos for FheUintBlocksPrep<D, B, T> {
fn dsize(&self) -> poulpy_core::layouts::Dsize {
self.blocks[0].dsize()
}
fn dnum(&self) -> poulpy_core::layouts::Dnum {
self.blocks[0].dnum()
}
}
#[cfg(test)]
impl<D: DataRef, T: UnsignedInteger> LWEInfos for FheUintBlocksPrepDebug<D, T> {
fn base2k(&self) -> poulpy_core::layouts::Base2K {
self.blocks[0].base2k()
}
fn k(&self) -> poulpy_core::layouts::TorusPrecision {
self.blocks[0].k()
}
fn n(&self) -> poulpy_core::layouts::Degree {
self.blocks[0].n()
}
}
#[cfg(test)]
impl<D: DataRef, T: UnsignedInteger> GLWEInfos for FheUintBlocksPrepDebug<D, T> {
fn rank(&self) -> poulpy_core::layouts::Rank {
self.blocks[0].rank()
}
}
#[cfg(test)]
impl<D: DataRef, T: UnsignedInteger> GGSWInfos for FheUintBlocksPrepDebug<D, T> {
fn dsize(&self) -> poulpy_core::layouts::Dsize {
self.blocks[0].dsize()
}
fn dnum(&self) -> poulpy_core::layouts::Dnum {
self.blocks[0].dnum()
}
}

View File

@@ -0,0 +1,7 @@
mod block;
mod block_prepared;
mod word;
pub use block::*;
pub use block_prepared::*;
pub use word::*;

View File

@@ -0,0 +1,198 @@
use itertools::Itertools;
use poulpy_core::{
GLWEOperations, TakeGLWECtSlice, TakeGLWEPt, glwe_packing,
layouts::{
GLWECiphertext, GLWEInfos, GLWEPlaintextLayout, LWEInfos, TorusPrecision,
prepared::{GGLWEAutomorphismKeyPrepared, GLWESecretPrepared},
},
};
use poulpy_hal::{
api::{
ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal,
VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace,
VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy,
VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA,
VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace,
VecZnxRshInplace, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd,
VmpApplyDftToDftTmpBytes,
},
layouts::{Backend, Data, DataMut, DataRef, Module, Scratch},
source::Source,
};
use std::{collections::HashMap, marker::PhantomData};
use crate::tfhe::bdd_arithmetic::{FromBits, ToBits, UnsignedInteger};
/// A FHE ciphertext encrypting a [UnsignedInteger].
pub struct FheUintWord<D: Data, T: UnsignedInteger>(pub(crate) GLWECiphertext<D>, pub(crate) PhantomData<T>);
impl<D: DataMut, T: UnsignedInteger> FheUintWord<D, T> {
#[allow(dead_code)]
fn post_process<ATK, BE: Backend>(
&mut self,
module: &Module<BE>,
mut tmp_res: Vec<GLWECiphertext<&mut [u8]>>,
auto_keys: &HashMap<i64, GGLWEAutomorphismKeyPrepared<ATK, BE>>,
scratch: &mut Scratch<BE>,
) where
ATK: DataRef,
Module<BE>: VecZnxSub
+ VecZnxCopy
+ VecZnxNegateInplace
+ VecZnxDftAllocBytes
+ VecZnxAddInplace
+ VmpApplyDftToDftTmpBytes
+ VecZnxNormalizeTmpBytes
+ VecZnxDftApply<BE>
+ VmpApplyDftToDft<BE>
+ VmpApplyDftToDftAdd<BE>
+ VecZnxIdftApplyConsume<BE>
+ VecZnxBigNormalize<BE>
+ VecZnxNormalize<BE>
+ VecZnxRotateInplace<BE>
+ VecZnxNormalizeInplace<BE>
+ VecZnxSwitchRing
+ VecZnxBigAutomorphismInplace<BE>
+ VecZnxRshInplace<BE>
+ VecZnxDftCopy<BE>
+ VecZnxIdftApplyTmpA<BE>
+ VecZnxSubInplace
+ VecZnxBigNormalizeTmpBytes
+ VecZnxBigAddSmallInplace<BE>
+ VecZnxAutomorphismInplace<BE>
+ VecZnxBigSubSmallNegateInplace<BE>
+ VecZnxRotate,
Scratch<BE>: TakeVecZnxDft<BE> + ScratchAvailable + TakeVecZnx + TakeGLWECtSlice,
{
// Repacks the GLWE ciphertexts bits
let gap: usize = module.n() / T::WORD_SIZE;
let log_gap: usize = (usize::BITS - (gap - 1).leading_zeros()) as usize;
let mut cts: HashMap<usize, &mut GLWECiphertext<&mut [u8]>> = HashMap::new();
for (i, ct) in tmp_res.iter_mut().enumerate().take(T::WORD_SIZE) {
cts.insert(i * gap, ct);
}
glwe_packing(module, &mut cts, log_gap, auto_keys, scratch);
// And copies the repacked ciphertext on the receiver.
self.0.copy(module, cts.remove(&0).unwrap())
}
}
impl<D: DataRef, T: UnsignedInteger> LWEInfos for FheUintWord<D, T> {
fn base2k(&self) -> poulpy_core::layouts::Base2K {
self.0.base2k()
}
fn k(&self) -> poulpy_core::layouts::TorusPrecision {
self.0.k()
}
fn n(&self) -> poulpy_core::layouts::Degree {
self.0.n()
}
}
impl<D: DataRef, T: UnsignedInteger> GLWEInfos for FheUintWord<D, T> {
fn rank(&self) -> poulpy_core::layouts::Rank {
self.0.rank()
}
}
impl<D: DataMut, T: UnsignedInteger + ToBits> FheUintWord<D, T> {
pub fn encrypt_sk<S: DataRef, BE: Backend>(
&mut self,
module: &Module<BE>,
data: T,
sk: &GLWESecretPrepared<S, BE>,
source_xa: &mut Source,
source_xe: &mut Source,
scratch: &mut Scratch<BE>,
) where
Module<BE>: VecZnxAddScalarInplace
+ VecZnxDftAllocBytes
+ VecZnxBigNormalize<BE>
+ VecZnxDftApply<BE>
+ SvpApplyDftToDftInplace<BE>
+ VecZnxIdftApplyConsume<BE>
+ VecZnxNormalizeTmpBytes
+ VecZnxFillUniform
+ VecZnxSubInplace
+ VecZnxAddInplace
+ VecZnxNormalizeInplace<BE>
+ VecZnxAddNormal
+ VecZnxNormalize<BE>
+ VecZnxSub,
Scratch<BE>: TakeVecZnxDft<BE> + ScratchAvailable + TakeVecZnx + TakeGLWEPt<BE>,
{
#[cfg(debug_assertions)]
{
assert!(module.n().is_multiple_of(T::WORD_SIZE));
assert_eq!(self.n(), module.n() as u32);
assert_eq!(sk.n(), module.n() as u32);
}
let gap: usize = module.n() / T::WORD_SIZE;
let mut data_bits: Vec<i64> = vec![0i64; module.n()];
for i in 0..T::WORD_SIZE {
data_bits[i * gap] = data.bit(i) as i64
}
let pt_infos = GLWEPlaintextLayout {
n: self.n(),
base2k: self.base2k(),
k: 1_usize.into(),
};
let (mut pt, scratch_1) = scratch.take_glwe_pt(&pt_infos);
pt.encode_vec_i64(&data_bits, TorusPrecision(1));
self.0
.encrypt_sk(module, &pt, sk, source_xa, source_xe, scratch_1);
}
}
impl<D: DataRef, T: UnsignedInteger + FromBits> FheUintWord<D, T> {
pub fn decrypt<S: DataRef, BE: Backend>(
&self,
module: &Module<BE>,
sk: &GLWESecretPrepared<S, BE>,
scratch: &mut Scratch<BE>,
) -> T
where
Module<BE>: VecZnxDftApply<BE>
+ SvpApplyDftToDftInplace<BE>
+ VecZnxIdftApplyConsume<BE>
+ VecZnxBigAddInplace<BE>
+ VecZnxBigAddSmallInplace<BE>
+ VecZnxBigNormalize<BE>,
Scratch<BE>: TakeVecZnxDft<BE> + TakeVecZnxBig<BE> + TakeGLWEPt<BE>,
{
#[cfg(debug_assertions)]
{
assert!(module.n().is_multiple_of(T::WORD_SIZE));
assert_eq!(self.n(), module.n() as u32);
assert_eq!(sk.n(), module.n() as u32);
}
let gap: usize = module.n() / T::WORD_SIZE;
let pt_infos = GLWEPlaintextLayout {
n: self.n(),
base2k: self.base2k(),
k: 1_usize.into(),
};
let (mut pt, scratch_1) = scratch.take_glwe_pt(&pt_infos);
self.0.decrypt(module, &mut pt, sk, scratch_1);
let mut data: Vec<i64> = vec![0i64; module.n()];
pt.decode_vec_i64(&mut data, TorusPrecision(1));
let bits: Vec<u8> = data.iter().step_by(gap).map(|c| *c as u8).collect_vec();
T::from_bits(&bits)
}
}

View File

@@ -0,0 +1 @@
pub mod u32;

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,465 @@
use crate::tfhe::bdd_arithmetic::{BitCircuit, BitCircuitInfo, Circuit, GetBitCircuitInfo, Node};
pub(crate) enum AnyBitCircuit {
B0(BitCircuit<3, 2>),
B1(BitCircuit<3, 2>),
B2(BitCircuit<3, 2>),
B3(BitCircuit<3, 2>),
B4(BitCircuit<3, 2>),
B5(BitCircuit<3, 2>),
B6(BitCircuit<3, 2>),
B7(BitCircuit<3, 2>),
B8(BitCircuit<3, 2>),
B9(BitCircuit<3, 2>),
B10(BitCircuit<3, 2>),
B11(BitCircuit<3, 2>),
B12(BitCircuit<3, 2>),
B13(BitCircuit<3, 2>),
B14(BitCircuit<3, 2>),
B15(BitCircuit<3, 2>),
B16(BitCircuit<3, 2>),
B17(BitCircuit<3, 2>),
B18(BitCircuit<3, 2>),
B19(BitCircuit<3, 2>),
B20(BitCircuit<3, 2>),
B21(BitCircuit<3, 2>),
B22(BitCircuit<3, 2>),
B23(BitCircuit<3, 2>),
B24(BitCircuit<3, 2>),
B25(BitCircuit<3, 2>),
B26(BitCircuit<3, 2>),
B27(BitCircuit<3, 2>),
B28(BitCircuit<3, 2>),
B29(BitCircuit<3, 2>),
B30(BitCircuit<3, 2>),
B31(BitCircuit<3, 2>),
}
impl BitCircuitInfo for AnyBitCircuit {
fn info(&self) -> (&[Node], &[usize], usize) {
match self {
AnyBitCircuit::B0(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
AnyBitCircuit::B1(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
AnyBitCircuit::B2(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
AnyBitCircuit::B3(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
AnyBitCircuit::B4(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
AnyBitCircuit::B5(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
AnyBitCircuit::B6(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
AnyBitCircuit::B7(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
AnyBitCircuit::B8(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
AnyBitCircuit::B9(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
AnyBitCircuit::B10(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
AnyBitCircuit::B11(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
AnyBitCircuit::B12(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
AnyBitCircuit::B13(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
AnyBitCircuit::B14(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
AnyBitCircuit::B15(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
AnyBitCircuit::B16(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
AnyBitCircuit::B17(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
AnyBitCircuit::B18(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
AnyBitCircuit::B19(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
AnyBitCircuit::B20(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
AnyBitCircuit::B21(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
AnyBitCircuit::B22(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
AnyBitCircuit::B23(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
AnyBitCircuit::B24(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
AnyBitCircuit::B25(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
AnyBitCircuit::B26(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
AnyBitCircuit::B27(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
AnyBitCircuit::B28(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
AnyBitCircuit::B29(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
AnyBitCircuit::B30(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
AnyBitCircuit::B31(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
}
}
}
impl GetBitCircuitInfo<u32> for Circuit<AnyBitCircuit, 32usize> {
fn input_size(&self) -> usize {
2 * u32::BITS as usize
}
fn output_size(&self) -> usize {
u32::BITS as usize
}
fn get_circuit(&self, bit: usize) -> (&[Node], &[usize], usize) {
self.0[bit].info()
}
}
pub(crate) static OUTPUT_CIRCUITS: Circuit<AnyBitCircuit, 32usize> = Circuit([
AnyBitCircuit::B0(BitCircuit::new(
[Node::new(0, 0, 0), Node::new(32, 1, 0), Node::new(0, 1, 0)],
[0, 2],
2,
)),
AnyBitCircuit::B1(BitCircuit::new(
[Node::new(1, 0, 0), Node::new(33, 1, 0), Node::new(1, 1, 0)],
[0, 2],
2,
)),
AnyBitCircuit::B2(BitCircuit::new(
[Node::new(2, 0, 0), Node::new(34, 1, 0), Node::new(2, 1, 0)],
[0, 2],
2,
)),
AnyBitCircuit::B3(BitCircuit::new(
[Node::new(3, 0, 0), Node::new(35, 1, 0), Node::new(3, 1, 0)],
[0, 2],
2,
)),
AnyBitCircuit::B4(BitCircuit::new(
[Node::new(4, 0, 0), Node::new(36, 1, 0), Node::new(4, 1, 0)],
[0, 2],
2,
)),
AnyBitCircuit::B5(BitCircuit::new(
[Node::new(5, 0, 0), Node::new(37, 1, 0), Node::new(5, 1, 0)],
[0, 2],
2,
)),
AnyBitCircuit::B6(BitCircuit::new(
[Node::new(6, 0, 0), Node::new(38, 1, 0), Node::new(6, 1, 0)],
[0, 2],
2,
)),
AnyBitCircuit::B7(BitCircuit::new(
[Node::new(7, 0, 0), Node::new(39, 1, 0), Node::new(7, 1, 0)],
[0, 2],
2,
)),
AnyBitCircuit::B8(BitCircuit::new(
[Node::new(8, 0, 0), Node::new(40, 1, 0), Node::new(8, 1, 0)],
[0, 2],
2,
)),
AnyBitCircuit::B9(BitCircuit::new(
[Node::new(9, 0, 0), Node::new(41, 1, 0), Node::new(9, 1, 0)],
[0, 2],
2,
)),
AnyBitCircuit::B10(BitCircuit::new(
[
Node::new(10, 0, 0),
Node::new(42, 1, 0),
Node::new(10, 1, 0),
],
[0, 2],
2,
)),
AnyBitCircuit::B11(BitCircuit::new(
[
Node::new(11, 0, 0),
Node::new(43, 1, 0),
Node::new(11, 1, 0),
],
[0, 2],
2,
)),
AnyBitCircuit::B12(BitCircuit::new(
[
Node::new(12, 0, 0),
Node::new(44, 1, 0),
Node::new(12, 1, 0),
],
[0, 2],
2,
)),
AnyBitCircuit::B13(BitCircuit::new(
[
Node::new(13, 0, 0),
Node::new(45, 1, 0),
Node::new(13, 1, 0),
],
[0, 2],
2,
)),
AnyBitCircuit::B14(BitCircuit::new(
[
Node::new(14, 0, 0),
Node::new(46, 1, 0),
Node::new(14, 1, 0),
],
[0, 2],
2,
)),
AnyBitCircuit::B15(BitCircuit::new(
[
Node::new(15, 0, 0),
Node::new(47, 1, 0),
Node::new(15, 1, 0),
],
[0, 2],
2,
)),
AnyBitCircuit::B16(BitCircuit::new(
[
Node::new(16, 0, 0),
Node::new(48, 1, 0),
Node::new(16, 1, 0),
],
[0, 2],
2,
)),
AnyBitCircuit::B17(BitCircuit::new(
[
Node::new(17, 0, 0),
Node::new(49, 1, 0),
Node::new(17, 1, 0),
],
[0, 2],
2,
)),
AnyBitCircuit::B18(BitCircuit::new(
[
Node::new(18, 0, 0),
Node::new(50, 1, 0),
Node::new(18, 1, 0),
],
[0, 2],
2,
)),
AnyBitCircuit::B19(BitCircuit::new(
[
Node::new(19, 0, 0),
Node::new(51, 1, 0),
Node::new(19, 1, 0),
],
[0, 2],
2,
)),
AnyBitCircuit::B20(BitCircuit::new(
[
Node::new(20, 0, 0),
Node::new(52, 1, 0),
Node::new(20, 1, 0),
],
[0, 2],
2,
)),
AnyBitCircuit::B21(BitCircuit::new(
[
Node::new(21, 0, 0),
Node::new(53, 1, 0),
Node::new(21, 1, 0),
],
[0, 2],
2,
)),
AnyBitCircuit::B22(BitCircuit::new(
[
Node::new(22, 0, 0),
Node::new(54, 1, 0),
Node::new(22, 1, 0),
],
[0, 2],
2,
)),
AnyBitCircuit::B23(BitCircuit::new(
[
Node::new(23, 0, 0),
Node::new(55, 1, 0),
Node::new(23, 1, 0),
],
[0, 2],
2,
)),
AnyBitCircuit::B24(BitCircuit::new(
[
Node::new(24, 0, 0),
Node::new(56, 1, 0),
Node::new(24, 1, 0),
],
[0, 2],
2,
)),
AnyBitCircuit::B25(BitCircuit::new(
[
Node::new(25, 0, 0),
Node::new(57, 1, 0),
Node::new(25, 1, 0),
],
[0, 2],
2,
)),
AnyBitCircuit::B26(BitCircuit::new(
[
Node::new(26, 0, 0),
Node::new(58, 1, 0),
Node::new(26, 1, 0),
],
[0, 2],
2,
)),
AnyBitCircuit::B27(BitCircuit::new(
[
Node::new(27, 0, 0),
Node::new(59, 1, 0),
Node::new(27, 1, 0),
],
[0, 2],
2,
)),
AnyBitCircuit::B28(BitCircuit::new(
[
Node::new(28, 0, 0),
Node::new(60, 1, 0),
Node::new(28, 1, 0),
],
[0, 2],
2,
)),
AnyBitCircuit::B29(BitCircuit::new(
[
Node::new(29, 0, 0),
Node::new(61, 1, 0),
Node::new(29, 1, 0),
],
[0, 2],
2,
)),
AnyBitCircuit::B30(BitCircuit::new(
[
Node::new(30, 0, 0),
Node::new(62, 1, 0),
Node::new(30, 1, 0),
],
[0, 2],
2,
)),
AnyBitCircuit::B31(BitCircuit::new(
[
Node::new(31, 0, 0),
Node::new(63, 1, 0),
Node::new(31, 1, 0),
],
[0, 2],
2,
)),
]);

View File

@@ -0,0 +1,10 @@
pub(crate) mod add_codegen;
pub(crate) mod and_codegen;
pub(crate) mod or_codegen;
pub(crate) mod sll_codegen;
pub(crate) mod slt_codegen;
pub(crate) mod sltu_codegen;
pub(crate) mod sra_codegen;
pub(crate) mod srl_codegen;
pub(crate) mod sub_codegen;
pub(crate) mod xor_codegen;

View File

@@ -0,0 +1,34 @@
use crate::tfhe::bdd_arithmetic::{BitCircuit, BitCircuitInfo, Circuit, GetBitCircuitInfo, Node};
pub(crate) enum AnyBitCircuit {
B0(BitCircuit<3, 2>),
}
impl BitCircuitInfo for AnyBitCircuit {
fn info(&self) -> (&[Node], &[usize], usize) {
match self {
AnyBitCircuit::B0(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
}
}
}
impl GetBitCircuitInfo<u32> for Circuit<AnyBitCircuit, 1usize> {
fn input_size(&self) -> usize {
2 * u32::BITS as usize
}
fn output_size(&self) -> usize {
u32::BITS as usize
}
fn get_circuit(&self, _bit: usize) -> (&[Node], &[usize], usize) {
self.0[0].info()
}
}
pub(crate) static OUTPUT_CIRCUITS: Circuit<AnyBitCircuit, 1usize> = Circuit([AnyBitCircuit::B0(BitCircuit::new(
[Node::new(0, 0, 0), Node::new(1, 1, 0), Node::new(0, 1, 1)],
[0, 2],
2,
))]);

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,257 @@
use crate::tfhe::bdd_arithmetic::{BitCircuit, BitCircuitInfo, Circuit, GetBitCircuitInfo, Node};
pub(crate) enum AnyBitCircuit {
B0(BitCircuit<219, 64>),
}
impl BitCircuitInfo for AnyBitCircuit {
fn info(&self) -> (&[Node], &[usize], usize) {
match self {
AnyBitCircuit::B0(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
}
}
}
impl GetBitCircuitInfo<u32> for Circuit<AnyBitCircuit, 1usize> {
fn input_size(&self) -> usize {
2 * u32::BITS as usize
}
fn output_size(&self) -> usize {
1
}
fn get_circuit(&self, bit: usize) -> (&[Node], &[usize], usize) {
self.0[bit].info()
}
}
pub(crate) static OUTPUT_CIRCUITS: Circuit<AnyBitCircuit, 1usize> = Circuit([AnyBitCircuit::B0(BitCircuit::new(
[
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(32, 1, 0),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(0, 0, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(33, 1, 2),
Node::new(33, 2, 0),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(1, 3, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(34, 1, 2),
Node::new(34, 2, 0),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(2, 3, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(35, 2, 0),
Node::new(35, 1, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(3, 2, 3),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(36, 1, 2),
Node::new(36, 2, 0),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(4, 3, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(37, 1, 2),
Node::new(37, 2, 0),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(5, 3, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(38, 2, 0),
Node::new(38, 1, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(6, 2, 3),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(39, 2, 0),
Node::new(39, 1, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(7, 2, 3),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(40, 1, 2),
Node::new(40, 2, 0),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(8, 3, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(41, 1, 2),
Node::new(41, 2, 0),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(9, 3, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(42, 2, 0),
Node::new(42, 1, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(10, 2, 3),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(43, 2, 0),
Node::new(43, 1, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(11, 2, 3),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(44, 1, 2),
Node::new(44, 2, 0),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(12, 3, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(45, 1, 2),
Node::new(45, 2, 0),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(13, 3, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(46, 2, 0),
Node::new(46, 1, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(14, 2, 3),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(47, 1, 2),
Node::new(47, 2, 0),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(15, 3, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(48, 2, 0),
Node::new(48, 1, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(16, 2, 3),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(49, 1, 2),
Node::new(49, 2, 0),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(17, 3, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(50, 2, 0),
Node::new(50, 1, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(18, 2, 3),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(51, 2, 0),
Node::new(51, 1, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(19, 2, 3),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(52, 1, 2),
Node::new(52, 2, 0),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(20, 3, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(53, 2, 0),
Node::new(53, 1, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(21, 2, 3),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(54, 1, 2),
Node::new(54, 2, 0),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(22, 3, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(55, 1, 2),
Node::new(55, 2, 0),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(23, 3, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(56, 2, 0),
Node::new(56, 1, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(24, 2, 3),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(57, 1, 2),
Node::new(57, 2, 0),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(25, 3, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(58, 2, 0),
Node::new(58, 1, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(26, 2, 3),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(59, 2, 0),
Node::new(59, 1, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(27, 2, 3),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(60, 2, 0),
Node::new(60, 1, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(28, 2, 3),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(61, 1, 2),
Node::new(61, 2, 0),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(29, 3, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(62, 1, 2),
Node::new(62, 2, 0),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(30, 3, 2),
Node::new(63, 2, 1),
Node::new(63, 0, 2),
Node::new(31, 0, 1),
],
[
0, 3, 6, 10, 13, 17, 20, 24, 27, 31, 34, 38, 41, 45, 48, 52, 55, 59, 62, 66, 69, 73, 76, 80, 83, 87, 90, 94, 97, 101,
104, 108, 111, 115, 118, 122, 125, 129, 132, 136, 139, 143, 146, 150, 153, 157, 160, 164, 167, 171, 174, 178, 181, 185,
188, 192, 195, 199, 202, 206, 209, 213, 216, 218,
],
4,
))]);

View File

@@ -0,0 +1,257 @@
use crate::tfhe::bdd_arithmetic::{BitCircuit, BitCircuitInfo, Circuit, GetBitCircuitInfo, Node};
pub(crate) enum AnyBitCircuit {
B0(BitCircuit<219, 64>),
}
impl BitCircuitInfo for AnyBitCircuit {
fn info(&self) -> (&[Node], &[usize], usize) {
match self {
AnyBitCircuit::B0(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
}
}
}
impl GetBitCircuitInfo<u32> for Circuit<AnyBitCircuit, 1usize> {
fn input_size(&self) -> usize {
2 * u32::BITS as usize
}
fn output_size(&self) -> usize {
1
}
fn get_circuit(&self, bit: usize) -> (&[Node], &[usize], usize) {
self.0[bit].info()
}
}
pub(crate) static OUTPUT_CIRCUITS: Circuit<AnyBitCircuit, 1usize> = Circuit([AnyBitCircuit::B0(BitCircuit::new(
[
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(32, 1, 0),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(0, 0, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(33, 2, 0),
Node::new(33, 1, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(1, 2, 3),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(34, 2, 0),
Node::new(34, 1, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(2, 2, 3),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(35, 1, 2),
Node::new(35, 2, 0),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(3, 3, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(36, 2, 0),
Node::new(36, 1, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(4, 2, 3),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(37, 2, 0),
Node::new(37, 1, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(5, 2, 3),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(38, 1, 2),
Node::new(38, 2, 0),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(6, 3, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(39, 2, 0),
Node::new(39, 1, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(7, 2, 3),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(40, 1, 2),
Node::new(40, 2, 0),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(8, 3, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(41, 1, 2),
Node::new(41, 2, 0),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(9, 3, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(42, 1, 2),
Node::new(42, 2, 0),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(10, 3, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(43, 2, 0),
Node::new(43, 1, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(11, 2, 3),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(44, 1, 2),
Node::new(44, 2, 0),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(12, 3, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(45, 2, 0),
Node::new(45, 1, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(13, 2, 3),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(46, 1, 2),
Node::new(46, 2, 0),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(14, 3, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(47, 2, 0),
Node::new(47, 1, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(15, 2, 3),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(48, 2, 0),
Node::new(48, 1, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(16, 2, 3),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(49, 1, 2),
Node::new(49, 2, 0),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(17, 3, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(50, 1, 2),
Node::new(50, 2, 0),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(18, 3, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(51, 1, 2),
Node::new(51, 2, 0),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(19, 3, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(52, 2, 0),
Node::new(52, 1, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(20, 2, 3),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(53, 2, 0),
Node::new(53, 1, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(21, 2, 3),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(54, 2, 0),
Node::new(54, 1, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(22, 2, 3),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(55, 1, 2),
Node::new(55, 2, 0),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(23, 3, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(56, 2, 0),
Node::new(56, 1, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(24, 2, 3),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(57, 1, 2),
Node::new(57, 2, 0),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(25, 3, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(58, 2, 0),
Node::new(58, 1, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(26, 2, 3),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(59, 2, 0),
Node::new(59, 1, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(27, 2, 3),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(60, 2, 0),
Node::new(60, 1, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(28, 2, 3),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(61, 2, 0),
Node::new(61, 1, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(29, 2, 3),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(62, 2, 0),
Node::new(62, 1, 2),
Node::new(0, 0, 0),
Node::new(0, 1, 1),
Node::new(30, 2, 3),
Node::new(63, 2, 0),
Node::new(63, 1, 2),
Node::new(31, 0, 1),
],
[
0, 3, 6, 10, 13, 17, 20, 24, 27, 31, 34, 38, 41, 45, 48, 52, 55, 59, 62, 66, 69, 73, 76, 80, 83, 87, 90, 94, 97, 101,
104, 108, 111, 115, 118, 122, 125, 129, 132, 136, 139, 143, 146, 150, 153, 157, 160, 164, 167, 171, 174, 178, 181, 185,
188, 192, 195, 199, 202, 206, 209, 213, 216, 218,
],
4,
))]);

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,34 @@
use crate::tfhe::bdd_arithmetic::{BitCircuit, BitCircuitInfo, Circuit, GetBitCircuitInfo, Node};
pub(crate) enum AnyBitCircuit {
B0(BitCircuit<3, 2>),
}
impl BitCircuitInfo for AnyBitCircuit {
fn info(&self) -> (&[Node], &[usize], usize) {
match self {
AnyBitCircuit::B0(bit_circuit) => (
bit_circuit.nodes.as_ref(),
bit_circuit.levels.as_ref(),
bit_circuit.max_inter_state,
),
}
}
}
impl GetBitCircuitInfo<u32> for Circuit<AnyBitCircuit, 1usize> {
fn input_size(&self) -> usize {
2 * u32::BITS as usize
}
fn output_size(&self) -> usize {
u32::BITS as usize
}
fn get_circuit(&self, _bit: usize) -> (&[Node], &[usize], usize) {
self.0[0].info()
}
}
pub(crate) static OUTPUT_CIRCUITS: Circuit<AnyBitCircuit, 1usize> = Circuit([AnyBitCircuit::B0(BitCircuit::new(
[Node::new(1, 1, 0), Node::new(1, 0, 1), Node::new(0, 1, 0)],
[0, 2],
2,
))]);

View File

@@ -0,0 +1,198 @@
use itertools::Itertools;
use poulpy_core::{
GLWEExternalProductInplace, GLWEOperations, TakeGLWECtSlice,
layouts::{
GLWECiphertext, GLWECiphertextToMut, LWEInfos,
prepared::{GGSWCiphertextPrepared, GGSWCiphertextPreparedToRef},
},
};
use poulpy_hal::{
api::{VecZnxAddInplace, VecZnxCopy, VecZnxNegateInplace, VecZnxSub},
layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero},
};
use crate::tfhe::bdd_arithmetic::UnsignedInteger;
pub trait BitCircuitInfo {
fn info(&self) -> (&[Node], &[usize], usize);
}
pub trait GetBitCircuitInfo<T: UnsignedInteger> {
fn input_size(&self) -> usize;
fn output_size(&self) -> usize;
fn get_circuit(&self, bit: usize) -> (&[Node], &[usize], usize);
}
pub(crate) struct BitCircuit<const N: usize, const K: usize> {
pub(crate) nodes: [Node; N],
pub(crate) levels: [usize; K],
pub(crate) max_inter_state: usize,
}
pub struct Circuit<C: BitCircuitInfo, const N: usize>(pub [C; N]);
pub trait CircuitExecute<BE: Backend, T: UnsignedInteger>
where
Self: GetBitCircuitInfo<T>,
{
fn execute<O>(
&self,
module: &Module<BE>,
out: &mut [GLWECiphertext<O>],
inputs: &[&dyn GGSWCiphertextPreparedToRef<BE>],
scratch: &mut Scratch<BE>,
) where
O: DataMut;
}
impl<C: BitCircuitInfo, const N: usize, T: UnsignedInteger, BE: Backend> CircuitExecute<BE, T> for Circuit<C, N>
where
Self: GetBitCircuitInfo<T>,
Module<BE>: Cmux<BE> + VecZnxCopy,
Scratch<BE>: TakeGLWECtSlice,
{
fn execute<O>(
&self,
module: &Module<BE>,
out: &mut [GLWECiphertext<O>],
inputs: &[&dyn GGSWCiphertextPreparedToRef<BE>],
scratch: &mut Scratch<BE>,
) where
O: DataMut,
{
#[cfg(debug_assertions)]
{
assert_eq!(inputs.len(), self.input_size());
assert!(out.len() >= self.output_size());
}
for (i, out_i) in out.iter_mut().enumerate().take(self.output_size()) {
let (nodes, levels, max_inter_state) = self.get_circuit(i);
let (mut level, scratch_1) = scratch.take_glwe_ct_slice(max_inter_state * 2, out_i);
level.iter_mut().for_each(|ct| ct.data_mut().zero());
// TODO: implement API on GLWE
level[1]
.data_mut()
.encode_coeff_i64(out_i.base2k().into(), 0, 2, 0, 1);
let mut level_ref = level.iter_mut().collect_vec();
let (mut prev_level, mut next_level) = level_ref.split_at_mut(max_inter_state);
for i in 0..levels.len() - 1 {
let start: usize = levels[i];
let end: usize = levels[i + 1];
let nodes_lvl: &[Node] = &nodes[start..end];
for (j, node) in nodes_lvl.iter().enumerate() {
if node.low_index == node.high_index {
next_level[j].copy(module, prev_level[node.low_index]);
} else {
module.cmux(
next_level[j],
prev_level[node.high_index],
prev_level[node.low_index],
&inputs[node.input_index].to_ref(),
scratch_1,
);
}
}
(prev_level, next_level) = (next_level, prev_level);
}
// handle last output
// there's always only 1 node at last level
let node: &Node = nodes.last().unwrap();
module.cmux(
out_i,
prev_level[node.high_index],
prev_level[node.low_index],
&inputs[node.input_index].to_ref(),
scratch_1,
);
}
for out_i in out.iter_mut().skip(self.output_size()) {
out_i.data_mut().zero();
}
}
}
impl<const N: usize, const K: usize> BitCircuit<N, K> {
pub(crate) const fn new(nodes: [Node; N], levels: [usize; K], max_inter_state: usize) -> Self {
Self {
nodes,
levels,
max_inter_state,
}
}
}
impl<const N: usize, const K: usize> BitCircuitInfo for BitCircuit<N, K> {
fn info(&self) -> (&[Node], &[usize], usize) {
(
self.nodes.as_ref(),
self.levels.as_ref(),
self.max_inter_state,
)
}
}
#[derive(Debug)]
pub struct Node {
input_index: usize,
high_index: usize,
low_index: usize,
}
impl Node {
pub(crate) const fn new(input_index: usize, high_index: usize, low_index: usize) -> Self {
Self {
input_index,
high_index,
low_index,
}
}
}
pub trait Cmux<BE: Backend> {
fn cmux<O, T, F, S>(
&self,
out: &mut GLWECiphertext<O>,
t: &GLWECiphertext<T>,
f: &GLWECiphertext<F>,
s: &GGSWCiphertextPrepared<S, BE>,
scratch: &mut Scratch<BE>,
) where
O: DataMut,
T: DataRef,
F: DataRef,
S: DataRef;
}
impl<BE: Backend> Cmux<BE> for Module<BE>
where
Module<BE>: GLWEExternalProductInplace<BE> + VecZnxSub + VecZnxCopy + VecZnxNegateInplace + VecZnxAddInplace,
{
fn cmux<O, T, F, S>(
&self,
out: &mut GLWECiphertext<O>,
t: &GLWECiphertext<T>,
f: &GLWECiphertext<F>,
s: &GGSWCiphertextPrepared<S, BE>,
scratch: &mut Scratch<BE>,
) where
O: DataMut,
T: DataRef,
F: DataRef,
S: DataRef,
{
// let mut out: GLWECiphertext<&mut [u8]> = out.to_mut();
out.sub(self, t, f);
out.external_product_inplace(self, s, scratch);
out.to_mut().add_inplace(self, f);
}
}

View File

@@ -0,0 +1,241 @@
#[cfg(test)]
use crate::tfhe::bdd_arithmetic::FheUintBlocksPrepDebug;
use crate::tfhe::{
bdd_arithmetic::{FheUintBlocks, FheUintBlocksPrep, UnsignedInteger},
blind_rotation::{BlindRotationAlgo, BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyEncryptSk},
circuit_bootstrapping::{
CircuitBootstrappingKey, CircuitBootstrappingKeyEncryptSk, CircuitBootstrappingKeyLayout,
CircuitBootstrappingKeyPrepared, CirtuitBootstrappingExecute,
},
};
use poulpy_core::{
TakeGGSW, TakeGLWECt,
layouts::{
GLWESecret, GLWEToLWEKey, GLWEToLWEKeyLayout, LWECiphertext, LWESecret,
prepared::{GLWEToLWESwitchingKeyPrepared, Prepare, PrepareAlloc},
},
};
use poulpy_hal::{
api::{
ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx,
TakeSvpPPol, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace,
VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize,
VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing,
VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPrepare,
},
layouts::{Backend, Data, DataMut, DataRef, Module, Scratch},
source::Source,
};
pub trait BDDKeyInfos {
fn cbt_infos(&self) -> CircuitBootstrappingKeyLayout;
fn ks_infos(&self) -> GLWEToLWEKeyLayout;
}
#[derive(Debug, Clone, Copy)]
pub struct BDDKeyLayout {
pub cbt: CircuitBootstrappingKeyLayout,
pub ks: GLWEToLWEKeyLayout,
}
impl BDDKeyInfos for BDDKeyLayout {
fn cbt_infos(&self) -> CircuitBootstrappingKeyLayout {
self.cbt
}
fn ks_infos(&self) -> GLWEToLWEKeyLayout {
self.ks
}
}
pub struct BDDKey<CBT, LWE, BRA>
where
CBT: Data,
LWE: Data,
BRA: BlindRotationAlgo,
{
cbt: CircuitBootstrappingKey<CBT, BRA>,
ks: GLWEToLWEKey<LWE>,
}
impl<BRA: BlindRotationAlgo> BDDKey<Vec<u8>, Vec<u8>, BRA> {
pub fn encrypt_sk<DLwe, DGlwe, A, BE: Backend>(
module: &Module<BE>,
sk_lwe: &LWESecret<DLwe>,
sk_glwe: &GLWESecret<DGlwe>,
infos: &A,
source_xa: &mut Source,
source_xe: &mut Source,
scratch: &mut Scratch<BE>,
) -> Self
where
A: BDDKeyInfos,
DLwe: DataRef,
DGlwe: DataRef,
BlindRotationKey<Vec<u8>, BRA>: BlindRotationKeyAlloc + BlindRotationKeyEncryptSk<BE>,
Module<BE>: SvpApplyDftToDft<BE>
+ VecZnxIdftApplyTmpA<BE>
+ VecZnxAddScalarInplace
+ VecZnxDftAllocBytes
+ VecZnxBigNormalize<BE>
+ VecZnxDftApply<BE>
+ SvpApplyDftToDftInplace<BE>
+ VecZnxIdftApplyConsume<BE>
+ VecZnxNormalizeTmpBytes
+ VecZnxFillUniform
+ VecZnxSubInplace
+ VecZnxAddInplace
+ VecZnxNormalizeInplace<BE>
+ VecZnxAddNormal
+ VecZnxNormalize<BE>
+ VecZnxSub
+ SvpPrepare<BE>
+ VecZnxSwitchRing
+ SvpPPolAllocBytes
+ SvpPPolAlloc<BE>
+ VecZnxAutomorphism
+ VecZnxAutomorphismInplace<BE>,
Scratch<BE>: TakeVecZnxDft<BE> + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeSvpPPol<BE> + TakeVecZnxBig<BE>,
{
let mut ks: GLWEToLWEKey<Vec<u8>> = GLWEToLWEKey::alloc(&infos.ks_infos());
ks.encrypt_sk(module, sk_lwe, sk_glwe, source_xa, source_xe, scratch);
Self {
cbt: CircuitBootstrappingKey::encrypt_sk(
module,
sk_lwe,
sk_glwe,
&infos.cbt_infos(),
source_xa,
source_xe,
scratch,
),
ks,
}
}
}
pub struct BDDKeyPrepared<CBT, LWE, BRA, BE>
where
CBT: Data,
LWE: Data,
BRA: BlindRotationAlgo,
BE: Backend,
{
cbt: CircuitBootstrappingKeyPrepared<CBT, BRA, BE>,
ks: GLWEToLWESwitchingKeyPrepared<LWE, BE>,
}
impl<CBT: DataMut, LWE: DataMut, BRA: BlindRotationAlgo, BE: Backend> PrepareAlloc<BE, BDDKeyPrepared<CBT, LWE, BRA, BE>>
for BDDKey<CBT, LWE, BRA>
where
CircuitBootstrappingKey<CBT, BRA>: PrepareAlloc<BE, CircuitBootstrappingKeyPrepared<CBT, BRA, BE>>,
GLWEToLWEKey<LWE>: PrepareAlloc<BE, GLWEToLWESwitchingKeyPrepared<LWE, BE>>,
{
fn prepare_alloc(&self, module: &Module<BE>, scratch: &mut Scratch<BE>) -> BDDKeyPrepared<CBT, LWE, BRA, BE> {
BDDKeyPrepared {
cbt: self.cbt.prepare_alloc(module, scratch),
ks: self.ks.prepare_alloc(module, scratch),
}
}
}
pub trait FheUintPrepare<BE: Backend, OUT, IN> {
fn prepare(&self, module: &Module<BE>, out: &mut OUT, bits: &IN, scratch: &mut Scratch<BE>);
}
impl<CBT, OUT, IN, LWE, BRA, BE, T> FheUintPrepare<BE, FheUintBlocksPrep<OUT, BE, T>, FheUintBlocks<IN, T>>
for BDDKeyPrepared<CBT, LWE, BRA, BE>
where
T: UnsignedInteger,
CBT: DataRef,
OUT: DataMut,
IN: DataRef,
LWE: DataRef,
BRA: BlindRotationAlgo,
BE: Backend,
Module<BE>: VmpPrepare<BE>
+ VecZnxRotate
+ VecZnxDftAllocBytes
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<BE>
+ VmpApplyDftToDftAdd<BE>
+ VecZnxDftApply<BE>
+ VecZnxIdftApplyConsume<BE>
+ VecZnxBigAddSmallInplace<BE>
+ VecZnxBigNormalize<BE>
+ VecZnxNormalize<BE>
+ VecZnxNormalizeTmpBytes,
Scratch<BE>: ScratchAvailable + TakeVecZnxDft<BE> + TakeGLWECt + TakeVecZnx + TakeGGSW,
CircuitBootstrappingKeyPrepared<CBT, BRA, BE>: CirtuitBootstrappingExecute<BE>,
{
fn prepare(
&self,
module: &Module<BE>,
out: &mut FheUintBlocksPrep<OUT, BE, T>,
bits: &FheUintBlocks<IN, T>,
scratch: &mut Scratch<BE>,
) {
#[cfg(debug_assertions)]
{
assert_eq!(out.blocks.len(), bits.blocks.len());
}
let mut lwe: LWECiphertext<Vec<u8>> = LWECiphertext::alloc(&bits.blocks[0]); //TODO: add TakeLWE
let (mut tmp_ggsw, scratch_1) = scratch.take_ggsw(out);
for (dst, src) in out.blocks.iter_mut().zip(bits.blocks.iter()) {
lwe.from_glwe(module, src, &self.ks, scratch_1);
self.cbt
.execute_to_constant(module, &mut tmp_ggsw, &lwe, 1, 1, scratch_1);
dst.prepare(module, &tmp_ggsw, scratch_1);
}
}
}
#[cfg(test)]
impl<CBT, OUT, IN, LWE, BRA, BE, T> FheUintPrepare<BE, FheUintBlocksPrepDebug<OUT, T>, FheUintBlocks<IN, T>>
for BDDKeyPrepared<CBT, LWE, BRA, BE>
where
T: UnsignedInteger,
CBT: DataRef,
OUT: DataMut,
IN: DataRef,
LWE: DataRef,
BRA: BlindRotationAlgo,
BE: Backend,
Module<BE>: VmpPrepare<BE>
+ VecZnxRotate
+ VecZnxDftAllocBytes
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<BE>
+ VmpApplyDftToDftAdd<BE>
+ VecZnxDftApply<BE>
+ VecZnxIdftApplyConsume<BE>
+ VecZnxBigAddSmallInplace<BE>
+ VecZnxBigNormalize<BE>
+ VecZnxNormalize<BE>
+ VecZnxNormalizeTmpBytes,
Scratch<BE>: ScratchAvailable + TakeVecZnxDft<BE> + TakeGLWECt + TakeVecZnx + TakeGGSW,
CircuitBootstrappingKeyPrepared<CBT, BRA, BE>: CirtuitBootstrappingExecute<BE>,
{
fn prepare(
&self,
module: &Module<BE>,
out: &mut FheUintBlocksPrepDebug<OUT, T>,
bits: &FheUintBlocks<IN, T>,
scratch: &mut Scratch<BE>,
) {
#[cfg(debug_assertions)]
{
assert_eq!(out.blocks.len(), bits.blocks.len());
}
let mut lwe: LWECiphertext<Vec<u8>> = LWECiphertext::alloc(&bits.blocks[0]); //TODO: add TakeLWE
for (dst, src) in out.blocks.iter_mut().zip(bits.blocks.iter()) {
lwe.from_glwe(module, src, &self.ks, scratch);
self.cbt
.execute_to_constant(module, dst, &lwe, 1, 1, scratch);
}
}
}

View File

@@ -0,0 +1,86 @@
mod bdd_2w_to_1w;
mod ciphertexts;
mod circuits;
mod eval;
mod key;
mod parameters;
pub use bdd_2w_to_1w::*;
pub use ciphertexts::*;
pub(crate) use circuits::*;
pub(crate) use eval::*;
pub use key::*;
#[cfg(test)]
pub(crate) use parameters::*;
#[cfg(test)]
mod test;
pub trait UnsignedInteger: Copy + 'static {
const WORD_SIZE: usize;
}
impl UnsignedInteger for u8 {
const WORD_SIZE: usize = 8;
}
impl UnsignedInteger for u16 {
const WORD_SIZE: usize = 16;
}
impl UnsignedInteger for u32 {
const WORD_SIZE: usize = 32;
}
impl UnsignedInteger for u64 {
const WORD_SIZE: usize = 64;
}
impl UnsignedInteger for u128 {
const WORD_SIZE: usize = 128;
}
pub trait ToBits {
fn bit(&self, i: usize) -> u8;
}
macro_rules! impl_tobits {
($($t:ty),*) => {
$(
impl ToBits for $t {
fn bit(&self, i: usize) -> u8 {
if i >= (std::mem::size_of::<$t>() * 8) {
panic!("bit index {} out of range for {}", i, stringify!($t));
}
((self >> i) & 1) as u8
}
}
)*
};
}
impl_tobits!(u8, u16, u32, u64, u128);
pub trait FromBits: Sized {
fn from_bits(bits: &[u8]) -> Self;
}
macro_rules! impl_from_bits {
($($t:ty),*) => {
$(
impl FromBits for $t {
fn from_bits(bits: &[u8]) -> Self {
let mut value: $t = 0;
let max_bits = std::mem::size_of::<$t>() * 8;
let n = bits.len().min(max_bits);
for (i, &bit) in bits.iter().take(n).enumerate() {
if bit != 0 {
value |= 1 << i;
}
}
value
}
}
)*
};
}
impl_from_bits!(u8, u16, u32, u64, u128);

View File

@@ -0,0 +1,80 @@
#[cfg(test)]
use poulpy_core::layouts::{
Base2K, Degree, Dnum, Dsize, GGLWEAutomorphismKeyLayout, GGLWETensorKeyLayout, GGSWCiphertextLayout, GLWECiphertextLayout,
GLWEToLWEKeyLayout, Rank, TorusPrecision,
};
#[cfg(test)]
use crate::tfhe::{
bdd_arithmetic::BDDKeyLayout, blind_rotation::BlindRotationKeyLayout, circuit_bootstrapping::CircuitBootstrappingKeyLayout,
};
#[cfg(test)]
pub(crate) const TEST_N_GLWE: u32 = 512;
#[cfg(test)]
pub(crate) const TEST_N_LWE: u32 = 77;
#[cfg(test)]
pub(crate) const TEST_BASE2K: u32 = 13;
#[cfg(test)]
pub(crate) const TEST_K_GLWE: u32 = 26;
#[cfg(test)]
pub(crate) const TEST_K_GGSW: u32 = 39;
#[cfg(test)]
pub(crate) const TEST_BLOCK_SIZE: u32 = 7;
#[cfg(test)]
pub(crate) const TEST_RANK: u32 = 2;
#[cfg(test)]
pub(crate) static TEST_GLWE_INFOS: GLWECiphertextLayout = GLWECiphertextLayout {
n: Degree(TEST_N_GLWE),
base2k: Base2K(TEST_BASE2K),
k: TorusPrecision(TEST_K_GLWE),
rank: Rank(TEST_RANK),
};
#[cfg(test)]
pub(crate) static TEST_GGSW_INFOS: GGSWCiphertextLayout = GGSWCiphertextLayout {
n: Degree(TEST_N_GLWE),
base2k: Base2K(TEST_BASE2K),
k: TorusPrecision(TEST_K_GGSW),
rank: Rank(TEST_RANK),
dnum: Dnum(2),
dsize: Dsize(1),
};
#[cfg(test)]
pub(crate) static TEST_BDD_KEY_LAYOUT: BDDKeyLayout = BDDKeyLayout {
cbt: CircuitBootstrappingKeyLayout {
layout_brk: BlindRotationKeyLayout {
n_glwe: Degree(TEST_N_GLWE),
n_lwe: Degree(TEST_N_LWE),
base2k: Base2K(TEST_BASE2K),
k: TorusPrecision(52),
dnum: Dnum(3),
rank: Rank(TEST_RANK),
},
layout_atk: GGLWEAutomorphismKeyLayout {
n: Degree(TEST_N_GLWE),
base2k: Base2K(TEST_BASE2K),
k: TorusPrecision(52),
rank: Rank(TEST_RANK),
dnum: Dnum(3),
dsize: Dsize(1),
},
layout_tsk: GGLWETensorKeyLayout {
n: Degree(TEST_N_GLWE),
base2k: Base2K(TEST_BASE2K),
k: TorusPrecision(52),
rank: Rank(TEST_RANK),
dnum: Dnum(3),
dsize: Dsize(1),
},
},
ks: GLWEToLWEKeyLayout {
n: Degree(TEST_N_GLWE),
base2k: Base2K(TEST_BASE2K),
k: TorusPrecision(39),
rank_in: Rank(TEST_RANK),
dnum: Dnum(2),
},
};

View File

@@ -0,0 +1,224 @@
use std::time::Instant;
use poulpy_backend::FFT64Ref;
use poulpy_core::{
TakeGGSW, TakeGLWEPt,
layouts::{
GGSWCiphertextLayout, GLWECiphertextLayout, GLWESecret, LWEInfos, LWESecret,
prepared::{GLWESecretPrepared, PrepareAlloc},
},
};
use poulpy_hal::{
api::{
ModuleNew, ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace,
SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeSlice, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft,
VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace,
VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigAutomorphismInplace,
VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy, VecZnxDftAddInplace,
VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, VecZnxFillUniform, VecZnxIdftApplyConsume,
VecZnxIdftApplyTmpA, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate,
VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing,
VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, ZnAddNormal, ZnFillUniform,
ZnNormalizeInplace,
},
layouts::{Backend, Module, Scratch, ScratchOwned},
oep::{
ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeMatZnxImpl, TakeScalarZnxImpl, TakeSvpPPolImpl,
TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl,
},
source::Source,
};
use rand::RngCore;
use crate::tfhe::{
bdd_arithmetic::{
Add, BDDKey, BDDKeyLayout, BDDKeyPrepared, FheUintBlocks, FheUintBlocksPrep, FheUintBlocksPrepDebug, Sub,
TEST_BDD_KEY_LAYOUT, TEST_BLOCK_SIZE, TEST_GGSW_INFOS, TEST_GLWE_INFOS, TEST_N_LWE,
},
blind_rotation::{
BlincRotationExecute, BlindRotationAlgo, BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyEncryptSk,
BlindRotationKeyPrepared, CGGI,
},
};
#[test]
fn test_bdd_2w_to_1w_fft64_ref() {
test_bdd_2w_to_1w::<FFT64Ref, CGGI>()
}
fn test_bdd_2w_to_1w<BE: Backend, BRA: BlindRotationAlgo>()
where
Module<BE>: ModuleNew<BE> + SvpPPolAlloc<BE> + SvpPrepare<BE> + VmpPMatAlloc<BE>,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Module<BE>: VecZnxAddScalarInplace
+ VecZnxDftAllocBytes
+ VecZnxBigNormalize<BE>
+ VecZnxDftApply<BE>
+ SvpApplyDftToDftInplace<BE>
+ VecZnxIdftApplyConsume<BE>
+ VecZnxNormalizeTmpBytes
+ VecZnxFillUniform
+ VecZnxSubInplace
+ VecZnxAddInplace
+ VecZnxNormalizeInplace<BE>
+ VecZnxAddNormal
+ VecZnxNormalize<BE>
+ VecZnxSub
+ VmpPrepare<BE>,
Scratch<BE>: TakeVecZnxDft<BE> + ScratchAvailable + TakeVecZnx + TakeGGSW + TakeScalarZnx + TakeSlice,
Module<BE>: VecZnxCopy + VecZnxNegateInplace + VmpApplyDftToDftTmpBytes + VmpApplyDftToDft<BE> + VmpApplyDftToDftAdd<BE>,
Module<BE>: VecZnxBigAddInplace<BE> + VecZnxBigAddSmallInplace<BE> + VecZnxBigNormalize<BE>,
Scratch<BE>: TakeVecZnxDft<BE> + TakeVecZnxBig<BE> + TakeGLWEPt<BE>,
Module<BE>: VecZnxAutomorphism
+ VecZnxSwitchRing
+ VecZnxBigAllocBytes
+ VecZnxIdftApplyTmpA<BE>
+ SvpApplyDftToDft<BE>
+ VecZnxBigAlloc<BE>
+ VecZnxDftAlloc<BE>
+ VecZnxBigNormalizeTmpBytes
+ SvpPPolAllocBytes
+ VecZnxRotateInplace<BE>
+ VecZnxBigAutomorphismInplace<BE>
+ VecZnxRshInplace<BE>
+ VecZnxDftCopy<BE>
+ VecZnxAutomorphismInplace<BE>
+ VecZnxBigSubSmallNegateInplace<BE>
+ VecZnxRotateInplaceTmpBytes
+ VecZnxBigAllocBytes
+ VecZnxDftAddInplace<BE>
+ VecZnxRotate
+ ZnFillUniform
+ ZnAddNormal
+ ZnNormalizeInplace<BE>,
BE: Backend
+ ScratchOwnedAllocImpl<BE>
+ ScratchOwnedBorrowImpl<BE>
+ TakeVecZnxDftImpl<BE>
+ ScratchAvailableImpl<BE>
+ TakeVecZnxImpl<BE>
+ TakeScalarZnxImpl<BE>
+ TakeSvpPPolImpl<BE>
+ TakeVecZnxBigImpl<BE>
+ TakeVecZnxDftSliceImpl<BE>
+ TakeMatZnxImpl<BE>
+ TakeVecZnxSliceImpl<BE>,
BlindRotationKey<Vec<u8>, BRA>: PrepareAlloc<BE, BlindRotationKeyPrepared<Vec<u8>, BRA, BE>>,
BlindRotationKeyPrepared<Vec<u8>, BRA, BE>: BlincRotationExecute<BE>,
BlindRotationKey<Vec<u8>, BRA>: BlindRotationKeyAlloc + BlindRotationKeyEncryptSk<BE>,
{
let glwe_infos: GLWECiphertextLayout = TEST_GLWE_INFOS;
let ggsw_infos: GGSWCiphertextLayout = TEST_GGSW_INFOS;
let n_glwe: usize = glwe_infos.n().into();
let module: Module<BE> = Module::<BE>::new(n_glwe as u64);
let mut source: Source = Source::new([6u8; 32]);
let mut source_xs: Source = Source::new([1u8; 32]);
let mut source_xa: Source = Source::new([2u8; 32]);
let mut source_xe: Source = Source::new([3u8; 32]);
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(1 << 22);
let mut sk_glwe: GLWESecret<Vec<u8>> = GLWESecret::alloc(&glwe_infos);
sk_glwe.fill_ternary_prob(0.5, &mut source_xs);
let sk_glwe_prep: GLWESecretPrepared<Vec<u8>, BE> = sk_glwe.prepare_alloc(&module, scratch.borrow());
let a: u32 = source.next_u32();
let b: u32 = source.next_u32();
println!("a: {a}");
println!("b: {b}");
let mut a_enc_prep: FheUintBlocksPrep<Vec<u8>, BE, u32> = FheUintBlocksPrep::<Vec<u8>, BE, u32>::alloc(&module, &ggsw_infos);
let mut b_enc_prep: FheUintBlocksPrep<Vec<u8>, BE, u32> = FheUintBlocksPrep::<Vec<u8>, BE, u32>::alloc(&module, &ggsw_infos);
let mut c_enc: FheUintBlocks<Vec<u8>, u32> = FheUintBlocks::<Vec<u8>, u32>::alloc(&module, &glwe_infos);
let mut c_enc_prep_debug: FheUintBlocksPrepDebug<Vec<u8>, u32> =
FheUintBlocksPrepDebug::<Vec<u8>, u32>::alloc(&module, &ggsw_infos);
let mut c_enc_prep: FheUintBlocksPrep<Vec<u8>, BE, u32> = FheUintBlocksPrep::<Vec<u8>, BE, u32>::alloc(&module, &ggsw_infos);
a_enc_prep.encrypt_sk(
&module,
a,
&sk_glwe_prep,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
);
b_enc_prep.encrypt_sk(
&module,
b,
&sk_glwe_prep,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
);
let start: Instant = Instant::now();
c_enc.sub(&module, &a_enc_prep, &b_enc_prep, scratch.borrow());
let duration: std::time::Duration = start.elapsed();
println!("add: {} ms", duration.as_millis());
println!(
"have: {}",
c_enc.decrypt(&module, &sk_glwe_prep, scratch.borrow())
);
println!("want: {}", a.wrapping_sub(b));
println!(
"noise: {:?}",
c_enc.noise(&module, &sk_glwe_prep, a.wrapping_sub(b), scratch.borrow())
);
let n_lwe: u32 = TEST_N_LWE;
let block_size: u32 = TEST_BLOCK_SIZE;
let mut sk_lwe: LWESecret<Vec<u8>> = LWESecret::alloc(n_lwe.into());
sk_lwe.fill_binary_block(block_size as usize, &mut source_xs);
let bdd_key_infos: BDDKeyLayout = TEST_BDD_KEY_LAYOUT;
let now: Instant = Instant::now();
let bdd_key: BDDKey<Vec<u8>, Vec<u8>, BRA> = BDDKey::encrypt_sk(
&module,
&sk_lwe,
&sk_glwe,
&bdd_key_infos,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
);
let bdd_key_prepared: BDDKeyPrepared<Vec<u8>, Vec<u8>, BRA, BE> = bdd_key.prepare_alloc(&module, scratch.borrow());
println!("BDD-KGEN: {} ms", now.elapsed().as_millis());
let now: Instant = Instant::now();
c_enc_prep_debug.prepare(&module, &c_enc, &bdd_key_prepared, scratch.borrow());
println!("CBT: {} ms", now.elapsed().as_millis());
c_enc_prep_debug.noise(&module, &sk_glwe_prep, a.wrapping_sub(b));
let now: Instant = Instant::now();
c_enc_prep.prepare(&module, &c_enc, &bdd_key_prepared, scratch.borrow());
println!("CBT: {} ms", now.elapsed().as_millis());
let start: Instant = Instant::now();
c_enc.add(&module, &c_enc_prep, &b_enc_prep, scratch.borrow());
let duration: std::time::Duration = start.elapsed();
println!("add: {} ms", duration.as_millis());
println!(
"have: {}",
c_enc.decrypt(&module, &sk_glwe_prep, scratch.borrow())
);
println!("want: {}", b.wrapping_add(a.wrapping_sub(b)));
println!(
"noise: {:?}",
c_enc.noise(
&module,
&sk_glwe_prep,
b.wrapping_add(a.wrapping_sub(b)),
scratch.borrow()
)
);
}

View File

@@ -42,13 +42,13 @@ where
if block_size > 1 {
let cols: usize = (brk_infos.rank() + 1).into();
let rows: usize = brk_infos.rows().into();
let acc_dft: usize = module.vec_znx_dft_alloc_bytes(cols, rows) * extension_factor;
let dnum: usize = brk_infos.dnum().into();
let acc_dft: usize = module.vec_znx_dft_alloc_bytes(cols, dnum) * extension_factor;
let acc_big: usize = module.vec_znx_big_alloc_bytes(1, brk_size);
let vmp_res: usize = module.vec_znx_dft_alloc_bytes(cols, brk_size) * extension_factor;
let vmp_xai: usize = module.vec_znx_dft_alloc_bytes(1, brk_size);
let acc_dft_add: usize = vmp_res;
let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes(brk_size, rows, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2)
let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes(brk_size, dnum, dnum, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2)
let acc: usize = if extension_factor > 1 {
VecZnx::alloc_bytes(module.n(), cols, glwe_infos.size()) * extension_factor
} else {
@@ -158,11 +158,11 @@ fn execute_block_binary_extended<DataRes, DataIn, DataBrk, B: Backend>(
let n_glwe: usize = brk.n_glwe().into();
let extension_factor: usize = lut.extension_factor();
let base2k: usize = res.base2k().into();
let rows: usize = brk.rows().into();
let dnum: usize = brk.dnum().into();
let cols: usize = (res.rank() + 1).into();
let (mut acc, scratch_1) = scratch.take_vec_znx_slice(extension_factor, n_glwe, cols, res.size());
let (mut acc_dft, scratch_2) = scratch_1.take_vec_znx_dft_slice(extension_factor, n_glwe, cols, rows);
let (mut acc_dft, scratch_2) = scratch_1.take_vec_znx_dft_slice(extension_factor, n_glwe, cols, dnum);
let (mut vmp_res, scratch_3) = scratch_2.take_vec_znx_dft_slice(extension_factor, n_glwe, cols, brk.size());
let (mut acc_add_dft, scratch_4) = scratch_3.take_vec_znx_dft_slice(extension_factor, n_glwe, cols, brk.size());
let (mut vmp_xai, scratch_5) = scratch_4.take_vec_znx_dft(n_glwe, 1, brk.size());
@@ -328,7 +328,7 @@ fn execute_block_binary<DataRes, DataIn, DataBrk, B: Backend>(
let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref();
let two_n: usize = n_glwe << 1;
let base2k: usize = brk.base2k().into();
let rows: usize = brk.rows().into();
let dnum: usize = brk.dnum().into();
let cols: usize = (out_mut.rank() + 1).into();
@@ -351,7 +351,7 @@ fn execute_block_binary<DataRes, DataIn, DataBrk, B: Backend>(
// ACC + [sum DFT(X^ai -1) * (DFT(ACC) x BRKi)]
let (mut acc_dft, scratch_1) = scratch.take_vec_znx_dft(n_glwe, cols, rows);
let (mut acc_dft, scratch_1) = scratch.take_vec_znx_dft(n_glwe, cols, dnum);
let (mut vmp_res, scratch_2) = scratch_1.take_vec_znx_dft(n_glwe, cols, brk.size());
let (mut acc_add_dft, scratch_3) = scratch_2.take_vec_znx_dft(n_glwe, cols, brk.size());
let (mut vmp_xai, scratch_4) = scratch_3.take_vec_znx_dft(n_glwe, 1, brk.size());

View File

@@ -8,7 +8,7 @@ use std::{fmt, marker::PhantomData};
use poulpy_core::{
Distribution,
layouts::{
Base2K, Degree, Digits, GGSWCiphertext, GGSWInfos, GLWEInfos, LWEInfos, LWESecret, Rank, Rows, TorusPrecision,
Base2K, Degree, Dnum, Dsize, GGSWCiphertext, GGSWInfos, GLWEInfos, LWEInfos, LWESecret, Rank, TorusPrecision,
prepared::GLWESecretPrepared,
},
};
@@ -23,7 +23,7 @@ pub struct BlindRotationKeyLayout {
pub n_lwe: Degree,
pub base2k: Base2K,
pub k: TorusPrecision,
pub rows: Rows,
pub dnum: Dnum,
pub rank: Rank,
}
@@ -38,12 +38,12 @@ impl BlindRotationKeyInfos for BlindRotationKeyLayout {
}
impl GGSWInfos for BlindRotationKeyLayout {
fn digits(&self) -> Digits {
Digits(1)
fn dsize(&self) -> Dsize {
Dsize(1)
}
fn rows(&self) -> Rows {
self.rows
fn dnum(&self) -> Dnum {
self.dnum
}
}
@@ -221,11 +221,11 @@ impl<D: DataRef, BRT: BlindRotationAlgo> GLWEInfos for BlindRotationKey<D, BRT>
}
}
impl<D: DataRef, BRT: BlindRotationAlgo> GGSWInfos for BlindRotationKey<D, BRT> {
fn digits(&self) -> poulpy_core::layouts::Digits {
Digits(1)
fn dsize(&self) -> poulpy_core::layouts::Dsize {
Dsize(1)
}
fn rows(&self) -> Rows {
self.keys[0].rows()
fn dnum(&self) -> Dnum {
self.keys[0].dnum()
}
}

View File

@@ -8,7 +8,7 @@ use std::{fmt, marker::PhantomData};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use poulpy_core::{
Distribution,
layouts::{Base2K, Degree, Digits, GGSWInfos, GLWEInfos, LWEInfos, TorusPrecision, compressed::GGSWCiphertextCompressed},
layouts::{Base2K, Degree, Dsize, GGSWInfos, GLWEInfos, LWEInfos, TorusPrecision, compressed::GGSWCiphertextCompressed},
};
use crate::tfhe::blind_rotation::{BlindRotationAlgo, BlindRotationKeyInfos};
@@ -128,12 +128,12 @@ impl<D: DataRef, BRA: BlindRotationAlgo> GLWEInfos for BlindRotationKeyCompresse
}
impl<D: DataRef, BRA: BlindRotationAlgo> GGSWInfos for BlindRotationKeyCompressed<D, BRA> {
fn rows(&self) -> poulpy_core::layouts::Rows {
self.keys[0].rows()
fn dnum(&self) -> poulpy_core::layouts::Dnum {
self.keys[0].dnum()
}
fn digits(&self) -> poulpy_core::layouts::Digits {
Digits(1)
fn dsize(&self) -> poulpy_core::layouts::Dsize {
Dsize(1)
}
}

View File

@@ -8,7 +8,7 @@ use std::marker::PhantomData;
use poulpy_core::{
Distribution,
layouts::{
Base2K, Degree, Digits, GGSWInfos, GLWEInfos, LWEInfos, Rank, Rows, TorusPrecision,
Base2K, Degree, Dnum, Dsize, GGSWInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision,
prepared::{GGSWCiphertextPrepared, Prepare, PrepareAlloc},
},
};
@@ -63,12 +63,12 @@ impl<D: Data, BRT: BlindRotationAlgo, B: Backend> GLWEInfos for BlindRotationKey
}
}
impl<D: Data, BRT: BlindRotationAlgo, B: Backend> GGSWInfos for BlindRotationKeyPrepared<D, BRT, B> {
fn digits(&self) -> poulpy_core::layouts::Digits {
Digits(1)
fn dsize(&self) -> poulpy_core::layouts::Dsize {
Dsize(1)
}
fn rows(&self) -> Rows {
self.data[0].rows()
fn dnum(&self) -> Dnum {
self.data[0].dnum()
}
}

View File

@@ -1,9 +1,10 @@
use poulpy_hal::{
api::{
ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxCopy, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotateInplace,
VecZnxRotateInplaceTmpBytes, VecZnxSwitchRing,
ScratchOwnedAlloc, ScratchOwnedBorrow, TakeSlice, VecZnxCopy, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes,
VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxSwitchRing,
},
layouts::{Backend, Module, ScratchOwned, VecZnx, ZnxInfos, ZnxViewMut},
layouts::{Backend, Module, Scratch, ScratchOwned, VecZnx, ZnxInfos, ZnxViewMut},
reference::{vec_znx::vec_znx_rotate_inplace, znx::ZnxRef},
};
#[derive(Debug, Clone, Copy)]
@@ -76,12 +77,13 @@ impl LookUpTable {
+ VecZnxCopy
+ VecZnxRotateInplaceTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
Scratch<B>: TakeSlice,
{
assert!(f.len() <= module.n());
let base2k: usize = self.base2k;
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes());
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes() | (self.domain_size() << 3));
// Get the number minimum limb to store the message modulus
let limbs: usize = k.div_ceil(base2k);
@@ -128,19 +130,21 @@ impl LookUpTable {
// Rotates half the step to the left
if self.extension_factor() > 1 {
(0..self.extension_factor()).for_each(|i| {
let (tmp, _) = scratch.borrow().take_slice(lut_full.n());
for i in 0..self.extension_factor() {
module.vec_znx_switch_ring(&mut self.data[i], 0, &lut_full, 0);
if i < self.extension_factor() {
module.vec_znx_rotate_inplace(-1, &mut lut_full, 0, scratch.borrow());
vec_znx_rotate_inplace::<_, ZnxRef>(-1, &mut lut_full, 0, tmp);
}
});
}
} else {
module.vec_znx_copy(&mut self.data[0], 0, &lut_full, 0);
}
self.data.iter_mut().for_each(|a| {
for a in self.data.iter_mut() {
module.vec_znx_normalize_inplace(self.base2k, a, 0, scratch.borrow());
});
}
self.rotate(module, -(drift as i64));

View File

@@ -11,7 +11,7 @@ use poulpy_hal::{
},
layouts::{Backend, Module, ScratchOwned, ZnxView},
oep::{
ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl,
ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeSliceImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl,
TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl,
},
source::Source,
@@ -82,7 +82,8 @@ where
+ TakeVecZnxDftSliceImpl<B>
+ ScratchAvailableImpl<B>
+ TakeVecZnxImpl<B>
+ TakeVecZnxSliceImpl<B>,
+ TakeVecZnxSliceImpl<B>
+ TakeSliceImpl<B>,
{
let n_glwe: usize = module.n();
let base2k: usize = 19;
@@ -106,7 +107,7 @@ where
n_lwe: n_lwe.into(),
base2k: base2k.into(),
k: k_brk.into(),
rows: rows_brk.into(),
dnum: rows_brk.into(),
rank: rank.into(),
};

View File

@@ -6,7 +6,7 @@ use poulpy_hal::{
VecZnxSwitchRing,
},
layouts::{Backend, Module},
oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl},
oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeSliceImpl},
};
use crate::tfhe::blind_rotation::{DivRound, LookUpTable};
@@ -19,7 +19,7 @@ where
+ VecZnxSwitchRing
+ VecZnxCopy
+ VecZnxRotateInplaceTmpBytes,
B: Backend + ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B>,
B: Backend + ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B> + TakeSliceImpl<B>,
{
let base2k: usize = 20;
let k_lut: usize = 40;
@@ -59,7 +59,7 @@ where
+ VecZnxSwitchRing
+ VecZnxCopy
+ VecZnxRotateInplaceTmpBytes,
B: Backend + ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B>,
B: Backend + ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B> + TakeSliceImpl<B>,
{
let base2k: usize = 20;
let k_lut: usize = 40;

View File

@@ -11,7 +11,7 @@ fn test_cggi_blind_rotation_key_serialization() {
n_lwe: 64_usize.into(),
base2k: 12_usize.into(),
k: 54_usize.into(),
rows: 2_usize.into(),
dnum: 2_usize.into(),
rank: 2_usize.into(),
};
@@ -26,7 +26,7 @@ fn test_cggi_blind_rotation_key_compressed_serialization() {
n_lwe: 64_usize.into(),
base2k: 12_usize.into(),
k: 54_usize.into(),
rows: 2_usize.into(),
dnum: 2_usize.into(),
rank: 2_usize.into(),
};

View File

@@ -2,7 +2,7 @@ use std::collections::HashMap;
use poulpy_hal::{
api::{
ScratchAvailable, TakeMatZnx, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, TakeVecZnxSlice,
ScratchAvailable, TakeMatZnx, TakeSlice, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, TakeVecZnxSlice,
VecZnxAddInplace, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigAutomorphismInplace,
VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy, VecZnxDftAddInplace,
VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNegateInplace,
@@ -16,9 +16,10 @@ use poulpy_hal::{
use poulpy_core::{
GLWEOperations, TakeGGLWE, TakeGLWECt,
layouts::{Digits, GGLWECiphertextLayout, GGSWInfos, GLWEInfos, LWEInfos},
layouts::{Dsize, GGLWECiphertextLayout, GGSWInfos, GLWEInfos, LWEInfos},
};
use poulpy_core::glwe_packing;
use poulpy_core::layouts::{GGSWCiphertext, GLWECiphertext, LWECiphertext, prepared::GGLWEAutomorphismKeyPrepared};
use crate::tfhe::{
@@ -66,7 +67,8 @@ where
+ TakeVecZnxDft<B>
+ TakeMatZnx
+ ScratchAvailable
+ TakeVecZnxSlice,
+ TakeVecZnxSlice
+ TakeSlice,
BlindRotationKeyPrepared<D, BRA, B>: BlincRotationExecute<B>,
{
fn execute_to_constant<DM: DataMut, DR: DataRef>(
@@ -166,7 +168,8 @@ pub fn circuit_bootstrap_core<DRes, DLwe, DBrk, BRA: BlindRotationAlgo, B>(
+ TakeVecZnx
+ ScratchAvailable
+ TakeVecZnxSlice
+ TakeMatZnx,
+ TakeMatZnx
+ TakeSlice,
BlindRotationKeyPrepared<DBrk, BRA, B>: BlincRotationExecute<B>,
{
#[cfg(debug_assertions)]
@@ -180,29 +183,29 @@ pub fn circuit_bootstrap_core<DRes, DLwe, DBrk, BRA: BlindRotationAlgo, B>(
let n: usize = res.n().into();
let base2k: usize = res.base2k().into();
let rows: usize = res.rows().into();
let dnum: usize = res.dnum().into();
let rank: usize = res.rank().into();
let k: usize = res.k().into();
let alpha: usize = rows.next_power_of_two();
let alpha: usize = dnum.next_power_of_two();
let mut f: Vec<i64> = vec![0i64; (1 << log_domain) * alpha];
if to_exponent {
(0..rows).for_each(|i| {
f[i] = 1 << (base2k * (rows - 1 - i));
(0..dnum).for_each(|i| {
f[i] = 1 << (base2k * (dnum - 1 - i));
});
} else {
(0..1 << log_domain).for_each(|j| {
(0..rows).for_each(|i| {
f[j * alpha + i] = j as i64 * (1 << (base2k * (rows - 1 - i)));
(0..dnum).for_each(|i| {
f[j * alpha + i] = j as i64 * (1 << (base2k * (dnum - 1 - i)));
});
});
}
// Lut precision, basically must be able to hold the decomposition power basis of the GGSW
let mut lut: LookUpTable = LookUpTable::alloc(module, base2k, base2k * rows, extension_factor);
lut.set(module, &f, base2k * rows);
let mut lut: LookUpTable = LookUpTable::alloc(module, base2k, base2k * dnum, extension_factor);
lut.set(module, &f, base2k * dnum);
if to_exponent {
lut.set_rotation_direction(LookUpTableRotationDirection::Right);
@@ -215,8 +218,8 @@ pub fn circuit_bootstrap_core<DRes, DLwe, DBrk, BRA: BlindRotationAlgo, B>(
n: n.into(),
base2k: base2k.into(),
k: k.into(),
rows: rows.into(),
digits: Digits(1),
dnum: dnum.into(),
dsize: Dsize(1),
rank_in: rank.max(1).into(),
rank_out: rank.into(),
};
@@ -229,7 +232,7 @@ pub fn circuit_bootstrap_core<DRes, DLwe, DBrk, BRA: BlindRotationAlgo, B>(
let log_gap_in: usize = (usize::BITS - (gap * alpha - 1).leading_zeros()) as _;
(0..rows).for_each(|i| {
(0..dnum).for_each(|i| {
let mut tmp_glwe: GLWECiphertext<&mut [u8]> = tmp_gglwe.at_mut(i, 0);
if to_exponent {
@@ -248,7 +251,7 @@ pub fn circuit_bootstrap_core<DRes, DLwe, DBrk, BRA: BlindRotationAlgo, B>(
tmp_glwe.trace(module, 0, module.log_n(), &res_glwe, &key.atk, scratch_2);
}
if i < rows {
if i < dnum {
res_glwe.rotate_inplace(module, -(gap as i64), scratch_2);
}
});
@@ -300,7 +303,7 @@ fn post_process<DataRes, DataA, B: Backend>(
{
let log_n: usize = module.log_n();
let mut cts: HashMap<usize, GLWECiphertext<Vec<u8>>> = HashMap::new();
let mut cts: HashMap<usize, &mut GLWECiphertext<Vec<u8>>> = HashMap::new();
// First partial trace, vanishes all coefficients which are not multiples of gap_in
// [1, 1, 1, 1, 0, 0, 0, ..., 0, 0, -1, -1, -1, -1] -> [1, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0]
@@ -316,177 +319,31 @@ fn post_process<DataRes, DataA, B: Backend>(
// TODO: optimize with packing and final partial trace
// If gap_out < gap_in, then we need to repack, i.e. reduce the cap between coefficients.
if log_gap_in != log_gap_out {
let steps: i32 = 1 << log_domain;
(0..steps).for_each(|i| {
let steps: usize = 1 << log_domain;
// TODO: from Scratch
let mut cts_vec: Vec<GLWECiphertext<Vec<u8>>> = Vec::new();
for i in 0..steps {
if i != 0 {
res.rotate_inplace(module, -(1 << log_gap_in), scratch);
}
cts.insert(i as usize * (1 << log_gap_out), res.to_owned_deep());
});
pack(module, &mut cts, log_gap_out, auto_keys, scratch);
let packed: GLWECiphertext<Vec<u8>> = cts.remove(&0).unwrap();
cts_vec.push(res.to_owned_deep());
}
for (i, ct) in cts_vec.iter_mut().enumerate().take(steps) {
cts.insert(i * (1 << log_gap_out), ct);
}
glwe_packing(module, &mut cts, log_gap_out, auto_keys, scratch);
let packed: &mut GLWECiphertext<Vec<u8>> = cts.remove(&0).unwrap();
res.trace(
module,
log_n - log_gap_out,
log_n,
&packed,
packed,
auto_keys,
scratch,
);
}
}
pub fn pack<D: DataMut, B: Backend>(
module: &Module<B>,
cts: &mut HashMap<usize, GLWECiphertext<D>>,
log_gap_out: usize,
auto_keys: &HashMap<i64, GGLWEAutomorphismKeyPrepared<Vec<u8>, B>>,
scratch: &mut Scratch<B>,
) where
Module<B>: VecZnxRotateInplace<B>
+ VecZnxNormalizeInplace<B>
+ VecZnxNormalizeTmpBytes
+ VecZnxSwitchRing
+ VecZnxBigAutomorphismInplace<B>
+ VecZnxRshInplace<B>
+ VecZnxDftCopy<B>
+ VecZnxIdftApplyTmpA<B>
+ VecZnxSub
+ VecZnxAddInplace
+ VecZnxNegateInplace
+ VecZnxCopy
+ VecZnxSubInplace
+ VecZnxDftAllocBytes
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxAutomorphismInplace<B>
+ VecZnxBigSubSmallNegateInplace<B>
+ VecZnxRotate
+ VecZnxNormalize<B>,
Scratch<B>: TakeVecZnx + TakeVecZnxDft<B> + ScratchAvailable,
{
let log_n: usize = module.log_n();
(0..log_n - log_gap_out).for_each(|i| {
let t: usize = 16.min(1 << (log_n - 1 - i));
let auto_key: &GGLWEAutomorphismKeyPrepared<Vec<u8>, B> = if i == 0 {
auto_keys.get(&-1).unwrap()
} else {
auto_keys.get(&module.galois_element(1 << (i - 1))).unwrap()
};
(0..t).for_each(|j| {
let mut a: Option<GLWECiphertext<D>> = cts.remove(&j);
let mut b: Option<GLWECiphertext<D>> = cts.remove(&(j + t));
combine(module, a.as_mut(), b.as_mut(), i, auto_key, scratch);
if let Some(a) = a {
cts.insert(j, a);
} else if let Some(b) = b {
cts.insert(j, b);
}
});
});
}
#[allow(clippy::too_many_arguments)]
fn combine<A: DataMut, D: DataMut, DataAK: DataRef, B: Backend>(
module: &Module<B>,
a: Option<&mut GLWECiphertext<A>>,
b: Option<&mut GLWECiphertext<D>>,
i: usize,
auto_key: &GGLWEAutomorphismKeyPrepared<DataAK, B>,
scratch: &mut Scratch<B>,
) where
Module<B>: VecZnxRotateInplace<B>
+ VecZnxNormalizeInplace<B>
+ VecZnxNormalizeTmpBytes
+ VecZnxSwitchRing
+ VecZnxBigAutomorphismInplace<B>
+ VecZnxRshInplace<B>
+ VecZnxDftCopy<B>
+ VecZnxIdftApplyTmpA<B>
+ VecZnxSub
+ VecZnxAddInplace
+ VecZnxNegateInplace
+ VecZnxCopy
+ VecZnxSubInplace
+ VecZnxDftAllocBytes
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxAutomorphismInplace<B>
+ VecZnxBigSubSmallNegateInplace<B>
+ VecZnxRotate
+ VecZnxNormalize<B>,
Scratch<B>: TakeVecZnx + TakeVecZnxDft<B> + ScratchAvailable,
{
// Goal is to evaluate: a = a + b*X^t + phi(a - b*X^t))
// We also use the identity: AUTO(a * X^t, g) = -X^t * AUTO(a, g)
// where t = 2^(log_n - i - 1) and g = 5^{2^(i - 1)}
// Different cases for wether a and/or b are zero.
//
// Implicite RSH without modulus switch, introduces extra I(X) * Q/2 on decryption.
// Necessary so that the scaling of the plaintext remains constant.
// It however is ok to do so here because coefficients are eventually
// either mapped to garbage or twice their value which vanishes I(X)
// since 2*(I(X) * Q/2) = I(X) * Q = 0 mod Q.
if let Some(a) = a {
let t: i64 = 1 << (a.n().log2() - i - 1);
if let Some(b) = b {
let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(a);
// a = a * X^-t
a.rotate_inplace(module, -t, scratch_1);
// tmp_b = a * X^-t - b
tmp_b.sub(module, a, b);
tmp_b.rsh(module, 1, scratch_1);
// a = a * X^-t + b
a.add_inplace(module, b);
a.rsh(module, 1, scratch_1);
tmp_b.normalize_inplace(module, scratch_1);
// tmp_b = phi(a * X^-t - b)
tmp_b.automorphism_inplace(module, auto_key, scratch_1);
// a = a * X^-t + b - phi(a * X^-t - b)
a.sub_inplace_ab(module, &tmp_b);
a.normalize_inplace(module, scratch_1);
// a = a + b * X^t - phi(a * X^-t - b) * X^t
// = a + b * X^t - phi(a * X^-t - b) * - phi(X^t)
// = a + b * X^t + phi(a - b * X^t)
a.rotate_inplace(module, t, scratch_1);
} else {
a.rsh(module, 1, scratch);
// a = a + phi(a)
a.automorphism_add_inplace(module, auto_key, scratch);
}
} else if let Some(b) = b {
let t: i64 = 1 << (b.n().log2() - i - 1);
let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(b);
tmp_b.rotate(module, t, b);
tmp_b.rsh(module, 1, scratch_1);
// a = (b* X^t - phi(b* X^t))
b.automorphism_sub_negate(module, &tmp_b, auto_key, scratch_1);
}
}

View File

@@ -1,5 +1,5 @@
use poulpy_core::layouts::{
GGLWEAutomorphismKey, GGLWEAutomorphismKeyLayout, GGLWELayoutInfos, GGLWETensorKey, GGLWETensorKeyLayout, GGSWInfos,
GGLWEAutomorphismKey, GGLWEAutomorphismKeyLayout, GGLWEInfos, GGLWETensorKey, GGLWETensorKeyLayout, GGSWInfos,
GLWECiphertext, GLWEInfos, GLWESecret, LWEInfos, LWESecret,
prepared::{GGLWEAutomorphismKeyPrepared, GGLWETensorKeyPrepared, GLWESecretPrepared, PrepareAlloc},
};
@@ -23,11 +23,12 @@ use crate::tfhe::blind_rotation::{
};
pub trait CircuitBootstrappingKeyInfos {
fn layout_brk(&self) -> BlindRotationKeyLayout;
fn layout_atk(&self) -> GGLWEAutomorphismKeyLayout;
fn layout_tsk(&self) -> GGLWETensorKeyLayout;
fn brk_infos(&self) -> BlindRotationKeyLayout;
fn atk_infos(&self) -> GGLWEAutomorphismKeyLayout;
fn tsk_infos(&self) -> GGLWETensorKeyLayout;
}
#[derive(Debug, Clone, Copy)]
pub struct CircuitBootstrappingKeyLayout {
pub layout_brk: BlindRotationKeyLayout,
pub layout_atk: GGLWEAutomorphismKeyLayout,
@@ -35,15 +36,15 @@ pub struct CircuitBootstrappingKeyLayout {
}
impl CircuitBootstrappingKeyInfos for CircuitBootstrappingKeyLayout {
fn layout_atk(&self) -> GGLWEAutomorphismKeyLayout {
fn atk_infos(&self) -> GGLWEAutomorphismKeyLayout {
self.layout_atk
}
fn layout_brk(&self) -> BlindRotationKeyLayout {
fn brk_infos(&self) -> BlindRotationKeyLayout {
self.layout_brk
}
fn layout_tsk(&self) -> GGLWETensorKeyLayout {
fn tsk_infos(&self) -> GGLWETensorKeyLayout {
self.layout_tsk
}
}
@@ -110,16 +111,15 @@ where
INFOS: CircuitBootstrappingKeyInfos,
DLwe: DataRef,
DGlwe: DataRef,
Module<B>:,
{
assert_eq!(sk_lwe.n(), cbt_infos.layout_brk().n_lwe());
assert_eq!(sk_glwe.n(), cbt_infos.layout_brk().n_glwe());
assert_eq!(sk_glwe.n(), cbt_infos.layout_atk().n());
assert_eq!(sk_glwe.n(), cbt_infos.layout_tsk().n());
assert_eq!(sk_lwe.n(), cbt_infos.brk_infos().n_lwe());
assert_eq!(sk_glwe.n(), cbt_infos.brk_infos().n_glwe());
assert_eq!(sk_glwe.n(), cbt_infos.atk_infos().n());
assert_eq!(sk_glwe.n(), cbt_infos.tsk_infos().n());
let atk_infos: GGLWEAutomorphismKeyLayout = cbt_infos.layout_atk();
let brk_infos: BlindRotationKeyLayout = cbt_infos.layout_brk();
let trk_infos: GGLWETensorKeyLayout = cbt_infos.layout_tsk();
let atk_infos: GGLWEAutomorphismKeyLayout = cbt_infos.atk_infos();
let brk_infos: BlindRotationKeyLayout = cbt_infos.brk_infos();
let trk_infos: GGLWETensorKeyLayout = cbt_infos.tsk_infos();
let mut auto_keys: HashMap<i64, GGLWEAutomorphismKey<Vec<u8>>> = HashMap::new();
let gal_els: Vec<i64> = GLWECiphertext::trace_galois_elements(module);
@@ -159,36 +159,36 @@ pub struct CircuitBootstrappingKeyPrepared<D: Data, BRA: BlindRotationAlgo, B: B
}
impl<D: DataRef, BRA: BlindRotationAlgo, B: Backend> CircuitBootstrappingKeyInfos for CircuitBootstrappingKeyPrepared<D, BRA, B> {
fn layout_atk(&self) -> GGLWEAutomorphismKeyLayout {
fn atk_infos(&self) -> GGLWEAutomorphismKeyLayout {
let (_, atk) = self.atk.iter().next().expect("atk is empty");
GGLWEAutomorphismKeyLayout {
n: atk.n(),
base2k: atk.base2k(),
k: atk.k(),
rows: atk.rows(),
digits: atk.digits(),
dnum: atk.dnum(),
dsize: atk.dsize(),
rank: atk.rank(),
}
}
fn layout_brk(&self) -> BlindRotationKeyLayout {
fn brk_infos(&self) -> BlindRotationKeyLayout {
BlindRotationKeyLayout {
n_glwe: self.brk.n_glwe(),
n_lwe: self.brk.n_lwe(),
base2k: self.brk.base2k(),
k: self.brk.k(),
rows: self.brk.rows(),
dnum: self.brk.dnum(),
rank: self.brk.rank(),
}
}
fn layout_tsk(&self) -> GGLWETensorKeyLayout {
fn tsk_infos(&self) -> GGLWETensorKeyLayout {
GGLWETensorKeyLayout {
n: self.tsk.n(),
base2k: self.tsk.base2k(),
k: self.tsk.k(),
rows: self.tsk.rows(),
digits: self.tsk.digits(),
dnum: self.tsk.dnum(),
dsize: self.tsk.dsize(),
rank: self.tsk.rank(),
}
}

View File

@@ -14,8 +14,8 @@ use poulpy_hal::{
},
layouts::{Backend, Module, ScalarZnx, ScratchOwned, ZnxView, ZnxViewMut},
oep::{
ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeMatZnxImpl, TakeScalarZnxImpl, TakeSvpPPolImpl,
TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl,
ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeMatZnxImpl, TakeScalarZnxImpl, TakeSliceImpl,
TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl,
},
source::Source,
};
@@ -32,7 +32,7 @@ use crate::tfhe::{
};
use poulpy_core::layouts::{
Digits, GGLWEAutomorphismKeyLayout, GGLWETensorKeyLayout, GGSWCiphertextLayout, LWECiphertextLayout, prepared::PrepareAlloc,
Dsize, GGLWEAutomorphismKeyLayout, GGLWETensorKeyLayout, GGSWCiphertextLayout, LWECiphertextLayout, prepared::PrepareAlloc,
};
use poulpy_core::layouts::{
@@ -100,7 +100,8 @@ where
+ TakeVecZnxBigImpl<B>
+ TakeVecZnxDftSliceImpl<B>
+ TakeMatZnxImpl<B>
+ TakeVecZnxSliceImpl<B>,
+ TakeVecZnxSliceImpl<B>
+ TakeSliceImpl<B>,
BlindRotationKey<Vec<u8>, BRA>: PrepareAlloc<B, BlindRotationKeyPrepared<Vec<u8>, BRA, B>>,
BlindRotationKeyPrepared<Vec<u8>, BRA, B>: BlincRotationExecute<B>,
BlindRotationKey<Vec<u8>, BRA>: BlindRotationKeyAlloc + BlindRotationKeyEncryptSk<B>,
@@ -139,23 +140,23 @@ where
n_lwe: n_lwe.into(),
base2k: base2k.into(),
k: k_brk.into(),
rows: rows_brk.into(),
dnum: rows_brk.into(),
rank: rank.into(),
},
layout_atk: GGLWEAutomorphismKeyLayout {
n: n_glwe.into(),
base2k: base2k.into(),
k: k_atk.into(),
rows: rows_atk.into(),
dnum: rows_atk.into(),
rank: rank.into(),
digits: Digits(1),
dsize: Dsize(1),
},
layout_tsk: GGLWETensorKeyLayout {
n: n_glwe.into(),
base2k: base2k.into(),
k: k_tsk.into(),
rows: rows_tsk.into(),
digits: Digits(1),
dnum: rows_tsk.into(),
dsize: Dsize(1),
rank: rank.into(),
},
};
@@ -164,8 +165,8 @@ where
n: n_glwe.into(),
base2k: base2k.into(),
k: k_ggsw_res.into(),
rows: rows_ggsw_res.into(),
digits: Digits(1),
dnum: rows_ggsw_res.into(),
dsize: Dsize(1),
rank: rank.into(),
};
@@ -321,7 +322,8 @@ where
+ TakeVecZnxBigImpl<B>
+ TakeVecZnxDftSliceImpl<B>
+ TakeMatZnxImpl<B>
+ TakeVecZnxSliceImpl<B>,
+ TakeVecZnxSliceImpl<B>
+ TakeSliceImpl<B>,
BlindRotationKey<Vec<u8>, BRA>: PrepareAlloc<B, BlindRotationKeyPrepared<Vec<u8>, BRA, B>>,
BlindRotationKeyPrepared<Vec<u8>, BRA, B>: BlincRotationExecute<B>,
BlindRotationKey<Vec<u8>, BRA>: BlindRotationKeyAlloc + BlindRotationKeyEncryptSk<B>,
@@ -360,23 +362,23 @@ where
n_lwe: n_lwe.into(),
base2k: base2k.into(),
k: k_brk.into(),
rows: rows_brk.into(),
dnum: rows_brk.into(),
rank: rank.into(),
},
layout_atk: GGLWEAutomorphismKeyLayout {
n: n_glwe.into(),
base2k: base2k.into(),
k: k_atk.into(),
rows: rows_atk.into(),
dnum: rows_atk.into(),
rank: rank.into(),
digits: Digits(1),
dsize: Dsize(1),
},
layout_tsk: GGLWETensorKeyLayout {
n: n_glwe.into(),
base2k: base2k.into(),
k: k_tsk.into(),
rows: rows_tsk.into(),
digits: Digits(1),
dnum: rows_tsk.into(),
dsize: Dsize(1),
rank: rank.into(),
},
};
@@ -385,8 +387,8 @@ where
n: n_glwe.into(),
base2k: base2k.into(),
k: k_ggsw_res.into(),
rows: rows_ggsw_res.into(),
digits: Digits(1),
dnum: rows_ggsw_res.into(),
dsize: Dsize(1),
rank: rank.into(),
};

View File

@@ -1,2 +1,3 @@
pub mod bdd_arithmetic;
pub mod blind_rotation;
pub mod circuit_bootstrapping;