diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/add_codegen.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/add_codegen.rs index ad2d28f..056e482 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/add_codegen.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/add_codegen.rs @@ -1,4 +1,4 @@ -use crate::tfhe::bdd_arithmetic::{BitCircuit, BitCircuitInfo, Circuit, GetBitCircuitInfo, Node}; +use crate::tfhe::bdd_arithmetic::{BitCircuit, BitCircuitFamily, BitCircuitInfo, Circuit, Node}; pub(crate) enum AnyBitCircuit { B0(BitCircuit<4>), B1(BitCircuit<8>), @@ -72,16 +72,9 @@ impl BitCircuitInfo for AnyBitCircuit { } } -impl GetBitCircuitInfo for Circuit { - fn input_size(&self) -> usize { - 2 * u32::BITS as usize - } - fn output_size(&self) -> usize { - u32::BITS as usize - } - fn get_circuit(&self, bit: usize) -> (&[Node], usize) { - self.0[bit].info() - } +impl BitCircuitFamily for AnyBitCircuit { + const INPUT_BITS: usize = 64; + const OUTPUT_BITS: usize = 32; } pub(crate) static OUTPUT_CIRCUITS: Circuit = Circuit([ diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/and_codegen.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/and_codegen.rs index 07efe52..d23ef67 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/and_codegen.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/and_codegen.rs @@ -1,4 +1,4 @@ -use crate::tfhe::bdd_arithmetic::{BitCircuit, BitCircuitInfo, Circuit, GetBitCircuitInfo, Node}; +use crate::tfhe::bdd_arithmetic::{BitCircuit, BitCircuitFamily, BitCircuitInfo, Circuit, Node}; pub(crate) enum AnyBitCircuit { B0(BitCircuit<4>), B1(BitCircuit<4>), @@ -72,16 +72,9 @@ impl BitCircuitInfo for AnyBitCircuit { } } -impl GetBitCircuitInfo for Circuit { - fn input_size(&self) -> usize { - 2 * u32::BITS as usize - } - fn output_size(&self) -> usize { - u32::BITS as usize - } - fn get_circuit(&self, bit: usize) -> (&[Node], usize) { - self.0[bit].info() - } +impl BitCircuitFamily for AnyBitCircuit { + const INPUT_BITS: usize = 64; + const OUTPUT_BITS: usize = 32; } pub(crate) static OUTPUT_CIRCUITS: Circuit = Circuit([ diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/or_codegen.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/or_codegen.rs index e0c893f..a41c701 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/or_codegen.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/or_codegen.rs @@ -1,4 +1,4 @@ -use crate::tfhe::bdd_arithmetic::{BitCircuit, BitCircuitInfo, Circuit, GetBitCircuitInfo, Node}; +use crate::tfhe::bdd_arithmetic::{BitCircuit, BitCircuitFamily, BitCircuitInfo, Circuit, Node}; pub(crate) enum AnyBitCircuit { B0(BitCircuit<4>), B1(BitCircuit<4>), @@ -72,16 +72,9 @@ impl BitCircuitInfo for AnyBitCircuit { } } -impl GetBitCircuitInfo for Circuit { - fn input_size(&self) -> usize { - 2 * u32::BITS as usize - } - fn output_size(&self) -> usize { - u32::BITS as usize - } - fn get_circuit(&self, bit: usize) -> (&[Node], usize) { - self.0[bit].info() - } +impl BitCircuitFamily for AnyBitCircuit { + const INPUT_BITS: usize = 64; + const OUTPUT_BITS: usize = 32; } pub(crate) static OUTPUT_CIRCUITS: Circuit = Circuit([ diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/sll_codegen.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/sll_codegen.rs index 68a378c..ec2be77 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/sll_codegen.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/sll_codegen.rs @@ -1,4 +1,4 @@ -use crate::tfhe::bdd_arithmetic::{BitCircuit, BitCircuitInfo, Circuit, GetBitCircuitInfo, Node}; +use crate::tfhe::bdd_arithmetic::{BitCircuit, BitCircuitFamily, BitCircuitInfo, Circuit, Node}; pub(crate) enum AnyBitCircuit { B0(BitCircuit<12>), B1(BitCircuit<18>), @@ -72,16 +72,9 @@ impl BitCircuitInfo for AnyBitCircuit { } } -impl GetBitCircuitInfo for Circuit { - fn input_size(&self) -> usize { - 2 * u32::BITS as usize - } - fn output_size(&self) -> usize { - u32::BITS as usize - } - fn get_circuit(&self, bit: usize) -> (&[Node], usize) { - self.0[bit].info() - } +impl BitCircuitFamily for AnyBitCircuit { + const INPUT_BITS: usize = 37; + const OUTPUT_BITS: usize = 32; } pub(crate) static OUTPUT_CIRCUITS: Circuit = Circuit([ diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/slt_codegen.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/slt_codegen.rs index de19969..de2a69d 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/slt_codegen.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/slt_codegen.rs @@ -1,4 +1,4 @@ -use crate::tfhe::bdd_arithmetic::{BitCircuit, BitCircuitInfo, Circuit, GetBitCircuitInfo, Node}; +use crate::tfhe::bdd_arithmetic::{BitCircuit, BitCircuitFamily, BitCircuitInfo, Circuit, Node}; pub(crate) enum AnyBitCircuit { B0(BitCircuit<256>), } @@ -10,16 +10,9 @@ impl BitCircuitInfo for AnyBitCircuit { } } -impl GetBitCircuitInfo for Circuit { - fn input_size(&self) -> usize { - 2 * u32::BITS as usize - } - fn output_size(&self) -> usize { - 1 - } - fn get_circuit(&self, bit: usize) -> (&[Node], usize) { - self.0[bit].info() - } +impl BitCircuitFamily for AnyBitCircuit { + const INPUT_BITS: usize = 64; + const OUTPUT_BITS: usize = 1; } pub(crate) static OUTPUT_CIRCUITS: Circuit = Circuit([AnyBitCircuit::B0(BitCircuit::new( diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/sltu_codegen.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/sltu_codegen.rs index 96b9189..77cd451 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/sltu_codegen.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/sltu_codegen.rs @@ -1,4 +1,4 @@ -use crate::tfhe::bdd_arithmetic::{BitCircuit, BitCircuitInfo, Circuit, GetBitCircuitInfo, Node}; +use crate::tfhe::bdd_arithmetic::{BitCircuit, BitCircuitFamily, BitCircuitInfo, Circuit, Node}; pub(crate) enum AnyBitCircuit { B0(BitCircuit<256>), } @@ -10,16 +10,9 @@ impl BitCircuitInfo for AnyBitCircuit { } } -impl GetBitCircuitInfo for Circuit { - fn input_size(&self) -> usize { - 2 * u32::BITS as usize - } - fn output_size(&self) -> usize { - 1 - } - fn get_circuit(&self, bit: usize) -> (&[Node], usize) { - self.0[bit].info() - } +impl BitCircuitFamily for AnyBitCircuit { + const INPUT_BITS: usize = 64; + const OUTPUT_BITS: usize = 1; } pub(crate) static OUTPUT_CIRCUITS: Circuit = Circuit([AnyBitCircuit::B0(BitCircuit::new( diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/sra_codegen.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/sra_codegen.rs index a1a0c14..1e7f886 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/sra_codegen.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/sra_codegen.rs @@ -1,4 +1,4 @@ -use crate::tfhe::bdd_arithmetic::{BitCircuit, BitCircuitInfo, Circuit, GetBitCircuitInfo, Node}; +use crate::tfhe::bdd_arithmetic::{BitCircuit, BitCircuitFamily, BitCircuitInfo, Circuit, Node}; pub(crate) enum AnyBitCircuit { B0(BitCircuit<192>), B1(BitCircuit<186>), @@ -72,16 +72,9 @@ impl BitCircuitInfo for AnyBitCircuit { } } -impl GetBitCircuitInfo for Circuit { - fn input_size(&self) -> usize { - 2 * u32::BITS as usize - } - fn output_size(&self) -> usize { - u32::BITS as usize - } - fn get_circuit(&self, bit: usize) -> (&[Node], usize) { - self.0[bit].info() - } +impl BitCircuitFamily for AnyBitCircuit { + const INPUT_BITS: usize = 37; + const OUTPUT_BITS: usize = 32; } pub(crate) static OUTPUT_CIRCUITS: Circuit = Circuit([ diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/srl_codegen.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/srl_codegen.rs index d9ec736..e85c03d 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/srl_codegen.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/srl_codegen.rs @@ -1,4 +1,4 @@ -use crate::tfhe::bdd_arithmetic::{BitCircuit, BitCircuitInfo, Circuit, GetBitCircuitInfo, Node}; +use crate::tfhe::bdd_arithmetic::{BitCircuit, BitCircuitFamily, BitCircuitInfo, Circuit, Node}; pub(crate) enum AnyBitCircuit { B0(BitCircuit<192>), B1(BitCircuit<192>), @@ -72,16 +72,9 @@ impl BitCircuitInfo for AnyBitCircuit { } } -impl GetBitCircuitInfo for Circuit { - fn input_size(&self) -> usize { - 2 * u32::BITS as usize - } - fn output_size(&self) -> usize { - u32::BITS as usize - } - fn get_circuit(&self, bit: usize) -> (&[Node], usize) { - self.0[bit].info() - } +impl BitCircuitFamily for AnyBitCircuit { + const INPUT_BITS: usize = 37; + const OUTPUT_BITS: usize = 32; } pub(crate) static OUTPUT_CIRCUITS: Circuit = Circuit([ diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/sub_codegen.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/sub_codegen.rs index 84e8eef..5bcca74 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/sub_codegen.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/sub_codegen.rs @@ -1,4 +1,4 @@ -use crate::tfhe::bdd_arithmetic::{BitCircuit, BitCircuitInfo, Circuit, GetBitCircuitInfo, Node}; +use crate::tfhe::bdd_arithmetic::{BitCircuit, BitCircuitFamily, BitCircuitInfo, Circuit, Node}; pub(crate) enum AnyBitCircuit { B0(BitCircuit<4>), B1(BitCircuit<8>), @@ -72,16 +72,9 @@ impl BitCircuitInfo for AnyBitCircuit { } } -impl GetBitCircuitInfo for Circuit { - fn input_size(&self) -> usize { - 2 * u32::BITS as usize - } - fn output_size(&self) -> usize { - u32::BITS as usize - } - fn get_circuit(&self, bit: usize) -> (&[Node], usize) { - self.0[bit].info() - } +impl BitCircuitFamily for AnyBitCircuit { + const INPUT_BITS: usize = 64; + const OUTPUT_BITS: usize = 32; } pub(crate) static OUTPUT_CIRCUITS: Circuit = Circuit([ diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/xor_codegen.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/xor_codegen.rs index 1cd9f57..ecc9b62 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/xor_codegen.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/xor_codegen.rs @@ -1,4 +1,4 @@ -use crate::tfhe::bdd_arithmetic::{BitCircuit, BitCircuitInfo, Circuit, GetBitCircuitInfo, Node}; +use crate::tfhe::bdd_arithmetic::{BitCircuit, BitCircuitFamily, BitCircuitInfo, Circuit, Node}; pub(crate) enum AnyBitCircuit { B0(BitCircuit<4>), B1(BitCircuit<4>), @@ -72,16 +72,9 @@ impl BitCircuitInfo for AnyBitCircuit { } } -impl GetBitCircuitInfo for Circuit { - fn input_size(&self) -> usize { - 2 * u32::BITS as usize - } - fn output_size(&self) -> usize { - u32::BITS as usize - } - fn get_circuit(&self, bit: usize) -> (&[Node], usize) { - self.0[bit].info() - } +impl BitCircuitFamily for AnyBitCircuit { + const INPUT_BITS: usize = 64; + const OUTPUT_BITS: usize = 32; } pub(crate) static OUTPUT_CIRCUITS: Circuit = Circuit([ diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs index 2244eca..b845a22 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs @@ -19,13 +19,33 @@ pub trait GetBitCircuitInfo { fn get_circuit(&self, bit: usize) -> (&[Node], usize); } -pub(crate) struct BitCircuit { - pub(crate) nodes: [Node; N], - pub(crate) max_inter_state: usize, +pub struct BitCircuit { + pub nodes: [Node; N], + pub max_inter_state: usize, +} + +pub trait BitCircuitFamily { + const INPUT_BITS: usize; + const OUTPUT_BITS: usize; } pub struct Circuit(pub [C; N]); +impl GetBitCircuitInfo for Circuit +where + C: BitCircuitInfo + BitCircuitFamily, +{ + fn input_size(&self) -> usize { + C::INPUT_BITS + } + fn output_size(&self) -> usize { + C::OUTPUT_BITS + } + fn get_circuit(&self, bit: usize) -> (&[Node], usize) { + self.0[bit].info() + } +} + pub trait ExecuteBDDCircuit { fn execute_bdd_circuit(&self, out: &mut [GLWE], inputs: &G, circuit: &C, scratch: &mut Scratch) where @@ -51,7 +71,7 @@ where { #[cfg(debug_assertions)] { - assert_eq!(inputs.bit_size(), circuit.input_size()); + assert!(inputs.bit_size() >= circuit.input_size()); assert!(out.len() >= circuit.output_size()); } @@ -119,7 +139,7 @@ where } impl BitCircuit { - pub(crate) const fn new(nodes: [Node; N], max_inter_state: usize) -> Self { + pub const fn new(nodes: [Node; N], max_inter_state: usize) -> Self { Self { nodes, max_inter_state, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs index 66b0699..5b3661f 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs @@ -9,7 +9,7 @@ pub use bdd_2w_to_1w::*; pub use blind_rotation::*; pub use ciphertexts::*; pub(crate) use circuits::*; -pub(crate) use eval::*; +pub use eval::*; pub use key::*; pub mod tests;