diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_1w_to_1w.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_1w_to_1w.rs new file mode 100644 index 0000000..7fc4ded --- /dev/null +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_1w_to_1w.rs @@ -0,0 +1,157 @@ +use poulpy_core::{ + GLWECopy, GLWEPacking, ScratchTakeCore, + layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWEAutomorphismKeyHelper, GetGaloisElement}, +}; +use poulpy_hal::{ + api::ModuleLogN, + layouts::{Backend, DataMut, DataRef, Module, Scratch}, +}; + +use crate::tfhe::bdd_arithmetic::{ExecuteBDDCircuit, FheUint, FheUintPrepared, GetBitCircuitInfo, UnsignedInteger, circuits}; + +impl ExecuteBDDCircuit1WTo1W for Module where Self: Sized + ExecuteBDDCircuit + GLWEPacking + GLWECopy +{} + +pub trait ExecuteBDDCircuit1WTo1W +where + Self: Sized + ModuleLogN + ExecuteBDDCircuit + GLWEPacking + GLWECopy, +{ + fn execute_bdd_circuit_1w_to_1w( + &self, + out: &mut FheUint, + circuit: &C, + a: &FheUintPrepared, + key: &H, + scratch: &mut Scratch, + ) where + T: UnsignedInteger, + C: GetBitCircuitInfo, + R: DataMut, + A: DataRef, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + H: GLWEAutomorphismKeyHelper, + Scratch: ScratchTakeCore, + { + self.execute_bdd_circuit_1w_to_1w_multi_thread(1, out, circuit, a, key, scratch); + } + + #[allow(clippy::too_many_arguments)] + /// Operations Z -> Z + fn execute_bdd_circuit_1w_to_1w_multi_thread( + &self, + threads: usize, + out: &mut FheUint, + circuit: &C, + a: &FheUintPrepared, + key: &H, + scratch: &mut Scratch, + ) where + T: UnsignedInteger, + C: GetBitCircuitInfo, + R: DataMut, + A: DataRef, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + H: GLWEAutomorphismKeyHelper, + Scratch: ScratchTakeCore, + { + let (mut out_bits, scratch_1) = scratch.take_glwe_slice(T::BITS as usize, out); + + // Evaluates out[i] = circuit[i](a) + self.execute_bdd_circuit_multi_thread(threads, &mut out_bits, a, circuit, scratch_1); + + // Repacks the bits + 
out.pack(self, out_bits, key, scratch_1); + } +} + +#[macro_export] +macro_rules! define_bdd_1w_to_1w_trait { + ($(#[$meta:meta])* $vis:vis $trait_name:ident, $method_name:ident) => { + paste::paste! { + $(#[$meta])* + $vis trait $trait_name { + + /// Single-threaded version + fn $method_name( + &mut self, + module: &M, + a: &FheUintPrepared, + key: &H, + scratch: &mut Scratch, + ) where + M: ExecuteBDDCircuit1WTo1W, + A: DataRef, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + H: GLWEAutomorphismKeyHelper, + Scratch: ScratchTakeCore; + + /// Multithreaded version – same vis, method_name + "_multi_thread" + fn [<$method_name _multi_thread>]( + &mut self, + threads: usize, + module: &M, + a: &FheUintPrepared, + key: &H, + scratch: &mut Scratch, + ) where + M: ExecuteBDDCircuit1WTo1W, + A: DataRef, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + H: GLWEAutomorphismKeyHelper, + Scratch: ScratchTakeCore; + } + } + }; +} + +#[macro_export] +macro_rules! impl_bdd_1w_to_1w_trait { + ($trait_name:ident, $method_name:ident, $ty:ty, $circuit_ty:ty, $output_circuits:path) => { + paste::paste! 
{ + impl $trait_name<$ty, BE> for FheUint { + + fn $method_name( + &mut self, + module: &M, + a: &FheUintPrepared, + key: &H, + scratch: &mut Scratch, + ) where + M: ExecuteBDDCircuit1WTo1W, + A: DataRef, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + H: GLWEAutomorphismKeyHelper, + Scratch: ScratchTakeCore, + { + module.execute_bdd_circuit_1w_to_1w(self, &$output_circuits, a, key, scratch) + } + + fn [<$method_name _multi_thread>]( + &mut self, + threads: usize, + module: &M, + a: &FheUintPrepared, + key: &H, + scratch: &mut Scratch, + ) where + M: ExecuteBDDCircuit1WTo1W, + A: DataRef, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + H: GLWEAutomorphismKeyHelper, + Scratch: ScratchTakeCore, + { + module.execute_bdd_circuit_1w_to_1w_multi_thread(threads, self, &$output_circuits, a, key, scratch) + } + } + } + }; +} +define_bdd_1w_to_1w_trait!(pub Identity, identity); + +impl_bdd_1w_to_1w_trait!( + Identity, + identity, + u32, + circuits::u32::identity_codgen::AnyBitCircuit, + circuits::u32::identity_codgen::OUTPUT_CIRCUITS +); diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs index f167657..59a745c 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs @@ -209,7 +209,6 @@ macro_rules! 
impl_bdd_2w_to_1w_trait { } }; } - define_bdd_2w_to_1w_trait!(pub Add, add); define_bdd_2w_to_1w_trait!(pub Sub, sub); define_bdd_2w_to_1w_trait!(pub Sll, sll); diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs index 2f26b06..bc4f74c 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs @@ -20,7 +20,9 @@ use poulpy_hal::{ source::Source, }; -use crate::tfhe::bdd_arithmetic::{BDDKey, BDDKeyHelper, BDDKeyInfos, BDDKeyPrepared, BDDKeyPreparedFactory, FheUint, ToBits}; +use crate::tfhe::bdd_arithmetic::{ + BDDKey, BDDKeyHelper, BDDKeyInfos, BDDKeyPrepared, BDDKeyPreparedFactory, BitSize, FheUint, ToBits, +}; use crate::tfhe::bdd_arithmetic::{Cmux, FromBits, ScratchTakeBDD, UnsignedInteger}; use crate::tfhe::blind_rotation::BlindRotationAlgo; use crate::tfhe::circuit_bootstrapping::{CircuitBootstrappingKeyInfos, CirtuitBootstrappingExecute}; @@ -55,6 +57,12 @@ impl GetGGSWBitMut for FheUi } } +impl BitSize for FheUintPrepared { + fn bit_size(&self) -> usize { + T::BITS as usize + } +} + pub trait FheUintPreparedFactory where Self: Sized + GGSWPreparedFactory, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/identity_codgen.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/identity_codgen.rs new file mode 100644 index 0000000..caae248 --- /dev/null +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/identity_codgen.rs @@ -0,0 +1,111 @@ +use crate::tfhe::bdd_arithmetic::{BitCircuit, BitCircuitFamily, BitCircuitInfo, Circuit, Node}; +pub(crate) enum AnyBitCircuit { + B0(BitCircuit<2>), + B1(BitCircuit<2>), + B2(BitCircuit<2>), + B3(BitCircuit<2>), + B4(BitCircuit<2>), + B5(BitCircuit<2>), + B6(BitCircuit<2>), + B7(BitCircuit<2>), + B8(BitCircuit<2>), + B9(BitCircuit<2>), + B10(BitCircuit<2>), + B11(BitCircuit<2>), + 
B12(BitCircuit<2>), + B13(BitCircuit<2>), + B14(BitCircuit<2>), + B15(BitCircuit<2>), + B16(BitCircuit<2>), + B17(BitCircuit<2>), + B18(BitCircuit<2>), + B19(BitCircuit<2>), + B20(BitCircuit<2>), + B21(BitCircuit<2>), + B22(BitCircuit<2>), + B23(BitCircuit<2>), + B24(BitCircuit<2>), + B25(BitCircuit<2>), + B26(BitCircuit<2>), + B27(BitCircuit<2>), + B28(BitCircuit<2>), + B29(BitCircuit<2>), + B30(BitCircuit<2>), + B31(BitCircuit<2>), +} +impl BitCircuitFamily for AnyBitCircuit { + const INPUT_BITS: usize = 32usize; + const OUTPUT_BITS: usize = 32usize; +} +impl BitCircuitInfo for AnyBitCircuit { + fn info(&self) -> (&[Node], usize) { + match self { + AnyBitCircuit::B0(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state), + AnyBitCircuit::B1(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state), + AnyBitCircuit::B2(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state), + AnyBitCircuit::B3(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state), + AnyBitCircuit::B4(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state), + AnyBitCircuit::B5(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state), + AnyBitCircuit::B6(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state), + AnyBitCircuit::B7(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state), + AnyBitCircuit::B8(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state), + AnyBitCircuit::B9(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state), + AnyBitCircuit::B10(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state), + AnyBitCircuit::B11(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state), + AnyBitCircuit::B12(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state), + AnyBitCircuit::B13(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state), + 
AnyBitCircuit::B14(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state), + AnyBitCircuit::B15(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state), + AnyBitCircuit::B16(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state), + AnyBitCircuit::B17(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state), + AnyBitCircuit::B18(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state), + AnyBitCircuit::B19(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state), + AnyBitCircuit::B20(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state), + AnyBitCircuit::B21(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state), + AnyBitCircuit::B22(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state), + AnyBitCircuit::B23(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state), + AnyBitCircuit::B24(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state), + AnyBitCircuit::B25(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state), + AnyBitCircuit::B26(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state), + AnyBitCircuit::B27(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state), + AnyBitCircuit::B28(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state), + AnyBitCircuit::B29(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state), + AnyBitCircuit::B30(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state), + AnyBitCircuit::B31(bit_circuit) => (bit_circuit.nodes.as_ref(), bit_circuit.max_inter_state), + } + } +} +pub(crate) static OUTPUT_CIRCUITS: Circuit = Circuit([ + AnyBitCircuit::B0(BitCircuit::new([Node::Cmux(0, 1, 0), Node::None], 2)), + AnyBitCircuit::B1(BitCircuit::new([Node::Cmux(1, 1, 0), Node::None], 2)), + AnyBitCircuit::B2(BitCircuit::new([Node::Cmux(2, 1, 
0), Node::None], 2)), + AnyBitCircuit::B3(BitCircuit::new([Node::Cmux(3, 1, 0), Node::None], 2)), + AnyBitCircuit::B4(BitCircuit::new([Node::Cmux(4, 1, 0), Node::None], 2)), + AnyBitCircuit::B5(BitCircuit::new([Node::Cmux(5, 1, 0), Node::None], 2)), + AnyBitCircuit::B6(BitCircuit::new([Node::Cmux(6, 1, 0), Node::None], 2)), + AnyBitCircuit::B7(BitCircuit::new([Node::Cmux(7, 1, 0), Node::None], 2)), + AnyBitCircuit::B8(BitCircuit::new([Node::Cmux(8, 1, 0), Node::None], 2)), + AnyBitCircuit::B9(BitCircuit::new([Node::Cmux(9, 1, 0), Node::None], 2)), + AnyBitCircuit::B10(BitCircuit::new([Node::Cmux(10, 1, 0), Node::None], 2)), + AnyBitCircuit::B11(BitCircuit::new([Node::Cmux(11, 1, 0), Node::None], 2)), + AnyBitCircuit::B12(BitCircuit::new([Node::Cmux(12, 1, 0), Node::None], 2)), + AnyBitCircuit::B13(BitCircuit::new([Node::Cmux(13, 1, 0), Node::None], 2)), + AnyBitCircuit::B14(BitCircuit::new([Node::Cmux(14, 1, 0), Node::None], 2)), + AnyBitCircuit::B15(BitCircuit::new([Node::Cmux(15, 1, 0), Node::None], 2)), + AnyBitCircuit::B16(BitCircuit::new([Node::Cmux(16, 1, 0), Node::None], 2)), + AnyBitCircuit::B17(BitCircuit::new([Node::Cmux(17, 1, 0), Node::None], 2)), + AnyBitCircuit::B18(BitCircuit::new([Node::Cmux(18, 1, 0), Node::None], 2)), + AnyBitCircuit::B19(BitCircuit::new([Node::Cmux(19, 1, 0), Node::None], 2)), + AnyBitCircuit::B20(BitCircuit::new([Node::Cmux(20, 1, 0), Node::None], 2)), + AnyBitCircuit::B21(BitCircuit::new([Node::Cmux(21, 1, 0), Node::None], 2)), + AnyBitCircuit::B22(BitCircuit::new([Node::Cmux(22, 1, 0), Node::None], 2)), + AnyBitCircuit::B23(BitCircuit::new([Node::Cmux(23, 1, 0), Node::None], 2)), + AnyBitCircuit::B24(BitCircuit::new([Node::Cmux(24, 1, 0), Node::None], 2)), + AnyBitCircuit::B25(BitCircuit::new([Node::Cmux(25, 1, 0), Node::None], 2)), + AnyBitCircuit::B26(BitCircuit::new([Node::Cmux(26, 1, 0), Node::None], 2)), + AnyBitCircuit::B27(BitCircuit::new([Node::Cmux(27, 1, 0), Node::None], 2)), + 
AnyBitCircuit::B28(BitCircuit::new([Node::Cmux(28, 1, 0), Node::None], 2)), + AnyBitCircuit::B29(BitCircuit::new([Node::Cmux(29, 1, 0), Node::None], 2)), + AnyBitCircuit::B30(BitCircuit::new([Node::Cmux(30, 1, 0), Node::None], 2)), + AnyBitCircuit::B31(BitCircuit::new([Node::Cmux(31, 1, 0), Node::None], 2)), +]); diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/mod.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/mod.rs index 72deb2a..bcc1e35 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/mod.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/mod.rs @@ -1,5 +1,6 @@ pub(crate) mod add_codegen; pub(crate) mod and_codegen; +pub(crate) mod identity_codgen; pub(crate) mod or_codegen; pub(crate) mod sll_codegen; pub(crate) mod slt_codegen; diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs index 7ef1921..d8cc783 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs @@ -1,3 +1,4 @@ +mod bdd_1w_to_1w; mod bdd_2w_to_1w; mod blind_retrieval; mod blind_rotation; @@ -7,6 +8,7 @@ mod circuits; mod eval; mod key; +pub use bdd_1w_to_1w::*; pub use bdd_2w_to_1w::*; pub use blind_retrieval::*; pub use blind_rotation::*;