Add identity BDD

This commit is contained in:
Pro7ech
2025-11-16 14:38:40 +01:00
parent f9dcddcce1
commit 2613bf1450
6 changed files with 280 additions and 2 deletions

View File

@@ -0,0 +1,157 @@
use poulpy_core::{
GLWECopy, GLWEPacking, ScratchTakeCore,
layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWEAutomorphismKeyHelper, GetGaloisElement},
};
use poulpy_hal::{
api::ModuleLogN,
layouts::{Backend, DataMut, DataRef, Module, Scratch},
};
use crate::tfhe::bdd_arithmetic::{ExecuteBDDCircuit, FheUint, FheUintPrepared, GetBitCircuitInfo, UnsignedInteger, circuits};
impl<BE: Backend> ExecuteBDDCircuit1WTo1W<BE> for Module<BE> where Self: Sized + ExecuteBDDCircuit<BE> + GLWEPacking<BE> + GLWECopy
{}
pub trait ExecuteBDDCircuit1WTo1W<BE: Backend>
where
Self: Sized + ModuleLogN + ExecuteBDDCircuit<BE> + GLWEPacking<BE> + GLWECopy,
{
fn execute_bdd_circuit_1w_to_1w<R, C, A, K, H, T>(
&self,
out: &mut FheUint<R, T>,
circuit: &C,
a: &FheUintPrepared<A, T, BE>,
key: &H,
scratch: &mut Scratch<BE>,
) where
T: UnsignedInteger,
C: GetBitCircuitInfo,
R: DataMut,
A: DataRef,
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
H: GLWEAutomorphismKeyHelper<K, BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
self.execute_bdd_circuit_1w_to_1w_multi_thread(1, out, circuit, a, key, scratch);
}
#[allow(clippy::too_many_arguments)]
/// Operations Z x Z -> Z
fn execute_bdd_circuit_1w_to_1w_multi_thread<R, C, A, K, H, T>(
&self,
threads: usize,
out: &mut FheUint<R, T>,
circuit: &C,
a: &FheUintPrepared<A, T, BE>,
key: &H,
scratch: &mut Scratch<BE>,
) where
T: UnsignedInteger,
C: GetBitCircuitInfo,
R: DataMut,
A: DataRef,
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
H: GLWEAutomorphismKeyHelper<K, BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
let (mut out_bits, scratch_1) = scratch.take_glwe_slice(T::BITS as usize, out);
// Evaluates out[i] = circuit[i](a, b)
self.execute_bdd_circuit_multi_thread(threads, &mut out_bits, a, circuit, scratch_1);
// Repacks the bits
out.pack(self, out_bits, key, scratch_1);
}
}
#[macro_export]
macro_rules! define_bdd_1w_to_1w_trait {
($(#[$meta:meta])* $vis:vis $trait_name:ident, $method_name:ident) => {
paste::paste! {
$(#[$meta])*
$vis trait $trait_name<T: UnsignedInteger, BE: Backend> {
/// Single-threaded version
fn $method_name<A, M, K, H>(
&mut self,
module: &M,
a: &FheUintPrepared<A, T, BE>,
key: &H,
scratch: &mut Scratch<BE>,
) where
M: ExecuteBDDCircuit1WTo1W<BE>,
A: DataRef,
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
H: GLWEAutomorphismKeyHelper<K, BE>,
Scratch<BE>: ScratchTakeCore<BE>;
/// Multithreaded version same vis, method_name + "_multi_thread"
fn [<$method_name _multi_thread>]<A, M, K, H>(
&mut self,
threads: usize,
module: &M,
a: &FheUintPrepared<A, T, BE>,
key: &H,
scratch: &mut Scratch<BE>,
) where
M: ExecuteBDDCircuit1WTo1W<BE>,
A: DataRef,
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
H: GLWEAutomorphismKeyHelper<K, BE>,
Scratch<BE>: ScratchTakeCore<BE>;
}
}
};
}
#[macro_export]
macro_rules! impl_bdd_1w_to_1w_trait {
($trait_name:ident, $method_name:ident, $ty:ty, $circuit_ty:ty, $output_circuits:path) => {
paste::paste! {
impl<D: DataMut, BE: Backend> $trait_name<$ty, BE> for FheUint<D, $ty> {
fn $method_name<A, M, K, H>(
&mut self,
module: &M,
a: &FheUintPrepared<A, $ty, BE>,
key: &H,
scratch: &mut Scratch<BE>,
) where
M: ExecuteBDDCircuit1WTo1W<BE>,
A: DataRef,
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
H: GLWEAutomorphismKeyHelper<K, BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
module.execute_bdd_circuit_1w_to_1w(self, &$output_circuits, a, key, scratch)
}
fn [<$method_name _multi_thread>]<A, M, K, H>(
&mut self,
threads: usize,
module: &M,
a: &FheUintPrepared<A, $ty, BE>,
key: &H,
scratch: &mut Scratch<BE>,
) where
M: ExecuteBDDCircuit1WTo1W<BE>,
A: DataRef,
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
H: GLWEAutomorphismKeyHelper<K, BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
module.execute_bdd_circuit_1w_to_1w_multi_thread(threads, self, &$output_circuits, a, key, scratch)
}
}
}
};
}
define_bdd_1w_to_1w_trait!(pub Identity, identity);
impl_bdd_1w_to_1w_trait!(
Identity,
identity,
u32,
circuits::u32::identity_codgen::AnyBitCircuit,
circuits::u32::identity_codgen::OUTPUT_CIRCUITS
);

View File

@@ -209,7 +209,6 @@ macro_rules! impl_bdd_2w_to_1w_trait {
}
};
}
define_bdd_2w_to_1w_trait!(pub Add, add);
define_bdd_2w_to_1w_trait!(pub Sub, sub);
define_bdd_2w_to_1w_trait!(pub Sll, sll);

View File

@@ -20,7 +20,9 @@ use poulpy_hal::{
source::Source,
};
use crate::tfhe::bdd_arithmetic::{BDDKey, BDDKeyHelper, BDDKeyInfos, BDDKeyPrepared, BDDKeyPreparedFactory, FheUint, ToBits};
use crate::tfhe::bdd_arithmetic::{
BDDKey, BDDKeyHelper, BDDKeyInfos, BDDKeyPrepared, BDDKeyPreparedFactory, BitSize, FheUint, ToBits,
};
use crate::tfhe::bdd_arithmetic::{Cmux, FromBits, ScratchTakeBDD, UnsignedInteger};
use crate::tfhe::blind_rotation::BlindRotationAlgo;
use crate::tfhe::circuit_bootstrapping::{CircuitBootstrappingKeyInfos, CirtuitBootstrappingExecute};
@@ -55,6 +57,12 @@ impl<D: DataMut, T: UnsignedInteger, BE: Backend> GetGGSWBitMut<T, BE> for FheUi
}
}
impl<D: Data, T: UnsignedInteger, BE: Backend> BitSize for FheUintPrepared<D, T, BE> {
fn bit_size(&self) -> usize {
T::BITS as usize
}
}
pub trait FheUintPreparedFactory<T: UnsignedInteger, BE: Backend>
where
Self: Sized + GGSWPreparedFactory<BE>,

View File

@@ -0,0 +1,111 @@
use crate::tfhe::bdd_arithmetic::{BitCircuit, BitCircuitFamily, BitCircuitInfo, Circuit, Node};
pub(crate) enum AnyBitCircuit {
B0(BitCircuit<2>),
B1(BitCircuit<2>),
B2(BitCircuit<2>),
B3(BitCircuit<2>),
B4(BitCircuit<2>),
B5(BitCircuit<2>),
B6(BitCircuit<2>),
B7(BitCircuit<2>),
B8(BitCircuit<2>),
B9(BitCircuit<2>),
B10(BitCircuit<2>),
B11(BitCircuit<2>),
B12(BitCircuit<2>),
B13(BitCircuit<2>),
B14(BitCircuit<2>),
B15(BitCircuit<2>),
B16(BitCircuit<2>),
B17(BitCircuit<2>),
B18(BitCircuit<2>),
B19(BitCircuit<2>),
B20(BitCircuit<2>),
B21(BitCircuit<2>),
B22(BitCircuit<2>),
B23(BitCircuit<2>),
B24(BitCircuit<2>),
B25(BitCircuit<2>),
B26(BitCircuit<2>),
B27(BitCircuit<2>),
B28(BitCircuit<2>),
B29(BitCircuit<2>),
B30(BitCircuit<2>),
B31(BitCircuit<2>),
}
impl BitCircuitFamily for AnyBitCircuit {
const INPUT_BITS: usize = 32usize;
const OUTPUT_BITS: usize = 32usize;
}
impl BitCircuitInfo for AnyBitCircuit {
fn info(&self) -> (&[Node], usize) {
match self {
AnyBitCircuit::B0(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
AnyBitCircuit::B1(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
AnyBitCircuit::B2(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
AnyBitCircuit::B3(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
AnyBitCircuit::B4(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
AnyBitCircuit::B5(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
AnyBitCircuit::B6(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
AnyBitCircuit::B7(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
AnyBitCircuit::B8(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
AnyBitCircuit::B9(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
AnyBitCircuit::B10(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
AnyBitCircuit::B11(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
AnyBitCircuit::B12(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
AnyBitCircuit::B13(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
AnyBitCircuit::B14(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
AnyBitCircuit::B15(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
AnyBitCircuit::B16(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
AnyBitCircuit::B17(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
AnyBitCircuit::B18(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
AnyBitCircuit::B19(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
AnyBitCircuit::B20(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
AnyBitCircuit::B21(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
AnyBitCircuit::B22(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
AnyBitCircuit::B23(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
AnyBitCircuit::B24(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
AnyBitCircuit::B25(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
AnyBitCircuit::B26(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
AnyBitCircuit::B27(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
AnyBitCircuit::B28(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
AnyBitCircuit::B29(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
AnyBitCircuit::B30(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
AnyBitCircuit::B31(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
}
}
}
pub(crate) static OUTPUT_CIRCUITS: Circuit<AnyBitCircuit, 32usize> = Circuit([
AnyBitCircuit::B0(BitCircuit::new([Node::Cmux(0, 1, 0), Node::None], 2)),
AnyBitCircuit::B1(BitCircuit::new([Node::Cmux(1, 1, 0), Node::None], 2)),
AnyBitCircuit::B2(BitCircuit::new([Node::Cmux(2, 1, 0), Node::None], 2)),
AnyBitCircuit::B3(BitCircuit::new([Node::Cmux(3, 1, 0), Node::None], 2)),
AnyBitCircuit::B4(BitCircuit::new([Node::Cmux(4, 1, 0), Node::None], 2)),
AnyBitCircuit::B5(BitCircuit::new([Node::Cmux(5, 1, 0), Node::None], 2)),
AnyBitCircuit::B6(BitCircuit::new([Node::Cmux(6, 1, 0), Node::None], 2)),
AnyBitCircuit::B7(BitCircuit::new([Node::Cmux(7, 1, 0), Node::None], 2)),
AnyBitCircuit::B8(BitCircuit::new([Node::Cmux(8, 1, 0), Node::None], 2)),
AnyBitCircuit::B9(BitCircuit::new([Node::Cmux(9, 1, 0), Node::None], 2)),
AnyBitCircuit::B10(BitCircuit::new([Node::Cmux(10, 1, 0), Node::None], 2)),
AnyBitCircuit::B11(BitCircuit::new([Node::Cmux(11, 1, 0), Node::None], 2)),
AnyBitCircuit::B12(BitCircuit::new([Node::Cmux(12, 1, 0), Node::None], 2)),
AnyBitCircuit::B13(BitCircuit::new([Node::Cmux(13, 1, 0), Node::None], 2)),
AnyBitCircuit::B14(BitCircuit::new([Node::Cmux(14, 1, 0), Node::None], 2)),
AnyBitCircuit::B15(BitCircuit::new([Node::Cmux(15, 1, 0), Node::None], 2)),
AnyBitCircuit::B16(BitCircuit::new([Node::Cmux(16, 1, 0), Node::None], 2)),
AnyBitCircuit::B17(BitCircuit::new([Node::Cmux(17, 1, 0), Node::None], 2)),
AnyBitCircuit::B18(BitCircuit::new([Node::Cmux(18, 1, 0), Node::None], 2)),
AnyBitCircuit::B19(BitCircuit::new([Node::Cmux(19, 1, 0), Node::None], 2)),
AnyBitCircuit::B20(BitCircuit::new([Node::Cmux(20, 1, 0), Node::None], 2)),
AnyBitCircuit::B21(BitCircuit::new([Node::Cmux(21, 1, 0), Node::None], 2)),
AnyBitCircuit::B22(BitCircuit::new([Node::Cmux(22, 1, 0), Node::None], 2)),
AnyBitCircuit::B23(BitCircuit::new([Node::Cmux(23, 1, 0), Node::None], 2)),
AnyBitCircuit::B24(BitCircuit::new([Node::Cmux(24, 1, 0), Node::None], 2)),
AnyBitCircuit::B25(BitCircuit::new([Node::Cmux(25, 1, 0), Node::None], 2)),
AnyBitCircuit::B26(BitCircuit::new([Node::Cmux(26, 1, 0), Node::None], 2)),
AnyBitCircuit::B27(BitCircuit::new([Node::Cmux(27, 1, 0), Node::None], 2)),
AnyBitCircuit::B28(BitCircuit::new([Node::Cmux(28, 1, 0), Node::None], 2)),
AnyBitCircuit::B29(BitCircuit::new([Node::Cmux(29, 1, 0), Node::None], 2)),
AnyBitCircuit::B30(BitCircuit::new([Node::Cmux(30, 1, 0), Node::None], 2)),
AnyBitCircuit::B31(BitCircuit::new([Node::Cmux(31, 1, 0), Node::None], 2)),
]);

View File

@@ -1,5 +1,6 @@
pub(crate) mod add_codegen;
pub(crate) mod and_codegen;
pub(crate) mod identity_codgen;
pub(crate) mod or_codegen;
pub(crate) mod sll_codegen;
pub(crate) mod slt_codegen;

View File

@@ -1,3 +1,4 @@
mod bdd_1w_to_1w;
mod bdd_2w_to_1w;
mod blind_retrieval;
mod blind_rotation;
@@ -7,6 +8,7 @@ mod circuits;
mod eval;
mod key;
pub use bdd_1w_to_1w::*;
pub use bdd_2w_to_1w::*;
pub use blind_retrieval::*;
pub use blind_rotation::*;