mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
Add identity BDD
This commit is contained in:
157
poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_1w_to_1w.rs
Normal file
157
poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_1w_to_1w.rs
Normal file
@@ -0,0 +1,157 @@
|
||||
use poulpy_core::{
|
||||
GLWECopy, GLWEPacking, ScratchTakeCore,
|
||||
layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWEAutomorphismKeyHelper, GetGaloisElement},
|
||||
};
|
||||
use poulpy_hal::{
|
||||
api::ModuleLogN,
|
||||
layouts::{Backend, DataMut, DataRef, Module, Scratch},
|
||||
};
|
||||
|
||||
use crate::tfhe::bdd_arithmetic::{ExecuteBDDCircuit, FheUint, FheUintPrepared, GetBitCircuitInfo, UnsignedInteger, circuits};
|
||||
|
||||
impl<BE: Backend> ExecuteBDDCircuit1WTo1W<BE> for Module<BE> where Self: Sized + ExecuteBDDCircuit<BE> + GLWEPacking<BE> + GLWECopy
|
||||
{}
|
||||
|
||||
pub trait ExecuteBDDCircuit1WTo1W<BE: Backend>
|
||||
where
|
||||
Self: Sized + ModuleLogN + ExecuteBDDCircuit<BE> + GLWEPacking<BE> + GLWECopy,
|
||||
{
|
||||
fn execute_bdd_circuit_1w_to_1w<R, C, A, K, H, T>(
|
||||
&self,
|
||||
out: &mut FheUint<R, T>,
|
||||
circuit: &C,
|
||||
a: &FheUintPrepared<A, T, BE>,
|
||||
key: &H,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
T: UnsignedInteger,
|
||||
C: GetBitCircuitInfo,
|
||||
R: DataMut,
|
||||
A: DataRef,
|
||||
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
|
||||
H: GLWEAutomorphismKeyHelper<K, BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
self.execute_bdd_circuit_1w_to_1w_multi_thread(1, out, circuit, a, key, scratch);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// Operations Z x Z -> Z
|
||||
fn execute_bdd_circuit_1w_to_1w_multi_thread<R, C, A, K, H, T>(
|
||||
&self,
|
||||
threads: usize,
|
||||
out: &mut FheUint<R, T>,
|
||||
circuit: &C,
|
||||
a: &FheUintPrepared<A, T, BE>,
|
||||
key: &H,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
T: UnsignedInteger,
|
||||
C: GetBitCircuitInfo,
|
||||
R: DataMut,
|
||||
A: DataRef,
|
||||
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
|
||||
H: GLWEAutomorphismKeyHelper<K, BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
let (mut out_bits, scratch_1) = scratch.take_glwe_slice(T::BITS as usize, out);
|
||||
|
||||
// Evaluates out[i] = circuit[i](a, b)
|
||||
self.execute_bdd_circuit_multi_thread(threads, &mut out_bits, a, circuit, scratch_1);
|
||||
|
||||
// Repacks the bits
|
||||
out.pack(self, out_bits, key, scratch_1);
|
||||
}
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! define_bdd_1w_to_1w_trait {
|
||||
($(#[$meta:meta])* $vis:vis $trait_name:ident, $method_name:ident) => {
|
||||
paste::paste! {
|
||||
$(#[$meta])*
|
||||
$vis trait $trait_name<T: UnsignedInteger, BE: Backend> {
|
||||
|
||||
/// Single-threaded version
|
||||
fn $method_name<A, M, K, H>(
|
||||
&mut self,
|
||||
module: &M,
|
||||
a: &FheUintPrepared<A, T, BE>,
|
||||
key: &H,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
M: ExecuteBDDCircuit1WTo1W<BE>,
|
||||
A: DataRef,
|
||||
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
|
||||
H: GLWEAutomorphismKeyHelper<K, BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>;
|
||||
|
||||
/// Multithreaded version – same vis, method_name + "_multi_thread"
|
||||
fn [<$method_name _multi_thread>]<A, M, K, H>(
|
||||
&mut self,
|
||||
threads: usize,
|
||||
module: &M,
|
||||
a: &FheUintPrepared<A, T, BE>,
|
||||
key: &H,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
M: ExecuteBDDCircuit1WTo1W<BE>,
|
||||
A: DataRef,
|
||||
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
|
||||
H: GLWEAutomorphismKeyHelper<K, BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>;
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! impl_bdd_1w_to_1w_trait {
|
||||
($trait_name:ident, $method_name:ident, $ty:ty, $circuit_ty:ty, $output_circuits:path) => {
|
||||
paste::paste! {
|
||||
impl<D: DataMut, BE: Backend> $trait_name<$ty, BE> for FheUint<D, $ty> {
|
||||
|
||||
fn $method_name<A, M, K, H>(
|
||||
&mut self,
|
||||
module: &M,
|
||||
a: &FheUintPrepared<A, $ty, BE>,
|
||||
key: &H,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
M: ExecuteBDDCircuit1WTo1W<BE>,
|
||||
A: DataRef,
|
||||
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
|
||||
H: GLWEAutomorphismKeyHelper<K, BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
module.execute_bdd_circuit_1w_to_1w(self, &$output_circuits, a, key, scratch)
|
||||
}
|
||||
|
||||
fn [<$method_name _multi_thread>]<A, M, K, H>(
|
||||
&mut self,
|
||||
threads: usize,
|
||||
module: &M,
|
||||
a: &FheUintPrepared<A, $ty, BE>,
|
||||
key: &H,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
M: ExecuteBDDCircuit1WTo1W<BE>,
|
||||
A: DataRef,
|
||||
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
|
||||
H: GLWEAutomorphismKeyHelper<K, BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
module.execute_bdd_circuit_1w_to_1w_multi_thread(threads, self, &$output_circuits, a, key, scratch)
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
define_bdd_1w_to_1w_trait!(pub Identity, identity);
|
||||
|
||||
impl_bdd_1w_to_1w_trait!(
|
||||
Identity,
|
||||
identity,
|
||||
u32,
|
||||
circuits::u32::identity_codgen::AnyBitCircuit,
|
||||
circuits::u32::identity_codgen::OUTPUT_CIRCUITS
|
||||
);
|
||||
@@ -209,7 +209,6 @@ macro_rules! impl_bdd_2w_to_1w_trait {
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
define_bdd_2w_to_1w_trait!(pub Add, add);
|
||||
define_bdd_2w_to_1w_trait!(pub Sub, sub);
|
||||
define_bdd_2w_to_1w_trait!(pub Sll, sll);
|
||||
|
||||
@@ -20,7 +20,9 @@ use poulpy_hal::{
|
||||
source::Source,
|
||||
};
|
||||
|
||||
use crate::tfhe::bdd_arithmetic::{BDDKey, BDDKeyHelper, BDDKeyInfos, BDDKeyPrepared, BDDKeyPreparedFactory, FheUint, ToBits};
|
||||
use crate::tfhe::bdd_arithmetic::{
|
||||
BDDKey, BDDKeyHelper, BDDKeyInfos, BDDKeyPrepared, BDDKeyPreparedFactory, BitSize, FheUint, ToBits,
|
||||
};
|
||||
use crate::tfhe::bdd_arithmetic::{Cmux, FromBits, ScratchTakeBDD, UnsignedInteger};
|
||||
use crate::tfhe::blind_rotation::BlindRotationAlgo;
|
||||
use crate::tfhe::circuit_bootstrapping::{CircuitBootstrappingKeyInfos, CirtuitBootstrappingExecute};
|
||||
@@ -55,6 +57,12 @@ impl<D: DataMut, T: UnsignedInteger, BE: Backend> GetGGSWBitMut<T, BE> for FheUi
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, T: UnsignedInteger, BE: Backend> BitSize for FheUintPrepared<D, T, BE> {
|
||||
fn bit_size(&self) -> usize {
|
||||
T::BITS as usize
|
||||
}
|
||||
}
|
||||
|
||||
pub trait FheUintPreparedFactory<T: UnsignedInteger, BE: Backend>
|
||||
where
|
||||
Self: Sized + GGSWPreparedFactory<BE>,
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
use crate::tfhe::bdd_arithmetic::{BitCircuit, BitCircuitFamily, BitCircuitInfo, Circuit, Node};
|
||||
pub(crate) enum AnyBitCircuit {
|
||||
B0(BitCircuit<2>),
|
||||
B1(BitCircuit<2>),
|
||||
B2(BitCircuit<2>),
|
||||
B3(BitCircuit<2>),
|
||||
B4(BitCircuit<2>),
|
||||
B5(BitCircuit<2>),
|
||||
B6(BitCircuit<2>),
|
||||
B7(BitCircuit<2>),
|
||||
B8(BitCircuit<2>),
|
||||
B9(BitCircuit<2>),
|
||||
B10(BitCircuit<2>),
|
||||
B11(BitCircuit<2>),
|
||||
B12(BitCircuit<2>),
|
||||
B13(BitCircuit<2>),
|
||||
B14(BitCircuit<2>),
|
||||
B15(BitCircuit<2>),
|
||||
B16(BitCircuit<2>),
|
||||
B17(BitCircuit<2>),
|
||||
B18(BitCircuit<2>),
|
||||
B19(BitCircuit<2>),
|
||||
B20(BitCircuit<2>),
|
||||
B21(BitCircuit<2>),
|
||||
B22(BitCircuit<2>),
|
||||
B23(BitCircuit<2>),
|
||||
B24(BitCircuit<2>),
|
||||
B25(BitCircuit<2>),
|
||||
B26(BitCircuit<2>),
|
||||
B27(BitCircuit<2>),
|
||||
B28(BitCircuit<2>),
|
||||
B29(BitCircuit<2>),
|
||||
B30(BitCircuit<2>),
|
||||
B31(BitCircuit<2>),
|
||||
}
|
||||
impl BitCircuitFamily for AnyBitCircuit {
|
||||
const INPUT_BITS: usize = 32usize;
|
||||
const OUTPUT_BITS: usize = 32usize;
|
||||
}
|
||||
impl BitCircuitInfo for AnyBitCircuit {
|
||||
fn info(&self) -> (&[Node], usize) {
|
||||
match self {
|
||||
AnyBitCircuit::B0(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
|
||||
AnyBitCircuit::B1(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
|
||||
AnyBitCircuit::B2(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
|
||||
AnyBitCircuit::B3(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
|
||||
AnyBitCircuit::B4(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
|
||||
AnyBitCircuit::B5(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
|
||||
AnyBitCircuit::B6(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
|
||||
AnyBitCircuit::B7(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
|
||||
AnyBitCircuit::B8(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
|
||||
AnyBitCircuit::B9(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
|
||||
AnyBitCircuit::B10(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
|
||||
AnyBitCircuit::B11(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
|
||||
AnyBitCircuit::B12(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
|
||||
AnyBitCircuit::B13(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
|
||||
AnyBitCircuit::B14(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
|
||||
AnyBitCircuit::B15(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
|
||||
AnyBitCircuit::B16(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
|
||||
AnyBitCircuit::B17(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
|
||||
AnyBitCircuit::B18(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
|
||||
AnyBitCircuit::B19(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
|
||||
AnyBitCircuit::B20(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
|
||||
AnyBitCircuit::B21(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
|
||||
AnyBitCircuit::B22(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
|
||||
AnyBitCircuit::B23(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
|
||||
AnyBitCircuit::B24(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
|
||||
AnyBitCircuit::B25(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
|
||||
AnyBitCircuit::B26(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
|
||||
AnyBitCircuit::B27(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
|
||||
AnyBitCircuit::B28(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
|
||||
AnyBitCircuit::B29(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
|
||||
AnyBitCircuit::B30(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
|
||||
AnyBitCircuit::B31(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state),
|
||||
}
|
||||
}
|
||||
}
|
||||
pub(crate) static OUTPUT_CIRCUITS: Circuit<AnyBitCircuit, 32usize> = Circuit([
|
||||
AnyBitCircuit::B0(BitCircuit::new([Node::Cmux(0, 1, 0), Node::None], 2)),
|
||||
AnyBitCircuit::B1(BitCircuit::new([Node::Cmux(1, 1, 0), Node::None], 2)),
|
||||
AnyBitCircuit::B2(BitCircuit::new([Node::Cmux(2, 1, 0), Node::None], 2)),
|
||||
AnyBitCircuit::B3(BitCircuit::new([Node::Cmux(3, 1, 0), Node::None], 2)),
|
||||
AnyBitCircuit::B4(BitCircuit::new([Node::Cmux(4, 1, 0), Node::None], 2)),
|
||||
AnyBitCircuit::B5(BitCircuit::new([Node::Cmux(5, 1, 0), Node::None], 2)),
|
||||
AnyBitCircuit::B6(BitCircuit::new([Node::Cmux(6, 1, 0), Node::None], 2)),
|
||||
AnyBitCircuit::B7(BitCircuit::new([Node::Cmux(7, 1, 0), Node::None], 2)),
|
||||
AnyBitCircuit::B8(BitCircuit::new([Node::Cmux(8, 1, 0), Node::None], 2)),
|
||||
AnyBitCircuit::B9(BitCircuit::new([Node::Cmux(9, 1, 0), Node::None], 2)),
|
||||
AnyBitCircuit::B10(BitCircuit::new([Node::Cmux(10, 1, 0), Node::None], 2)),
|
||||
AnyBitCircuit::B11(BitCircuit::new([Node::Cmux(11, 1, 0), Node::None], 2)),
|
||||
AnyBitCircuit::B12(BitCircuit::new([Node::Cmux(12, 1, 0), Node::None], 2)),
|
||||
AnyBitCircuit::B13(BitCircuit::new([Node::Cmux(13, 1, 0), Node::None], 2)),
|
||||
AnyBitCircuit::B14(BitCircuit::new([Node::Cmux(14, 1, 0), Node::None], 2)),
|
||||
AnyBitCircuit::B15(BitCircuit::new([Node::Cmux(15, 1, 0), Node::None], 2)),
|
||||
AnyBitCircuit::B16(BitCircuit::new([Node::Cmux(16, 1, 0), Node::None], 2)),
|
||||
AnyBitCircuit::B17(BitCircuit::new([Node::Cmux(17, 1, 0), Node::None], 2)),
|
||||
AnyBitCircuit::B18(BitCircuit::new([Node::Cmux(18, 1, 0), Node::None], 2)),
|
||||
AnyBitCircuit::B19(BitCircuit::new([Node::Cmux(19, 1, 0), Node::None], 2)),
|
||||
AnyBitCircuit::B20(BitCircuit::new([Node::Cmux(20, 1, 0), Node::None], 2)),
|
||||
AnyBitCircuit::B21(BitCircuit::new([Node::Cmux(21, 1, 0), Node::None], 2)),
|
||||
AnyBitCircuit::B22(BitCircuit::new([Node::Cmux(22, 1, 0), Node::None], 2)),
|
||||
AnyBitCircuit::B23(BitCircuit::new([Node::Cmux(23, 1, 0), Node::None], 2)),
|
||||
AnyBitCircuit::B24(BitCircuit::new([Node::Cmux(24, 1, 0), Node::None], 2)),
|
||||
AnyBitCircuit::B25(BitCircuit::new([Node::Cmux(25, 1, 0), Node::None], 2)),
|
||||
AnyBitCircuit::B26(BitCircuit::new([Node::Cmux(26, 1, 0), Node::None], 2)),
|
||||
AnyBitCircuit::B27(BitCircuit::new([Node::Cmux(27, 1, 0), Node::None], 2)),
|
||||
AnyBitCircuit::B28(BitCircuit::new([Node::Cmux(28, 1, 0), Node::None], 2)),
|
||||
AnyBitCircuit::B29(BitCircuit::new([Node::Cmux(29, 1, 0), Node::None], 2)),
|
||||
AnyBitCircuit::B30(BitCircuit::new([Node::Cmux(30, 1, 0), Node::None], 2)),
|
||||
AnyBitCircuit::B31(BitCircuit::new([Node::Cmux(31, 1, 0), Node::None], 2)),
|
||||
]);
|
||||
@@ -1,5 +1,6 @@
|
||||
pub(crate) mod add_codegen;
|
||||
pub(crate) mod and_codegen;
|
||||
pub(crate) mod identity_codgen;
|
||||
pub(crate) mod or_codegen;
|
||||
pub(crate) mod sll_codegen;
|
||||
pub(crate) mod slt_codegen;
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
mod bdd_1w_to_1w;
|
||||
mod bdd_2w_to_1w;
|
||||
mod blind_retrieval;
|
||||
mod blind_rotation;
|
||||
@@ -7,6 +8,7 @@ mod circuits;
|
||||
mod eval;
|
||||
mod key;
|
||||
|
||||
pub use bdd_1w_to_1w::*;
|
||||
pub use bdd_2w_to_1w::*;
|
||||
pub use blind_retrieval::*;
|
||||
pub use blind_rotation::*;
|
||||
|
||||
Reference in New Issue
Block a user