From 8eafcaff1f287af929318faa7ef346a02e19800a Mon Sep 17 00:00:00 2001 From: Pro7ech Date: Wed, 8 Oct 2025 17:57:40 +0200 Subject: [PATCH] fix BDD Binary Circuits --- .../src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs | 4 +- .../circuits/u32/and_codegen.rs | 151 ++---- .../bdd_arithmetic/circuits/u32/or_codegen.rs | 360 +++++++++++++- .../circuits/u32/xor_codegen.rs | 448 +++++++++++++++++- 4 files changed, 823 insertions(+), 140 deletions(-) diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs index 3836067..938c45e 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs @@ -205,7 +205,7 @@ impl_bdd_2w_to_1w_trait!( Or, or, u32, - 1, + 32, circuits::u32::or_codegen::AnyBitCircuit, circuits::u32::or_codegen::OUTPUT_CIRCUITS ); @@ -214,7 +214,7 @@ impl_bdd_2w_to_1w_trait!( Xor, xor, u32, - 1, + 32, circuits::u32::xor_codegen::AnyBitCircuit, circuits::u32::xor_codegen::OUTPUT_CIRCUITS ); diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/and_codegen.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/and_codegen.rs index a37b5a6..c195ae1 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/and_codegen.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/and_codegen.rs @@ -33,7 +33,6 @@ pub(crate) enum AnyBitCircuit { B30(BitCircuit<3, 2>), B31(BitCircuit<3, 2>), } - impl BitCircuitInfo for AnyBitCircuit { fn info(&self) -> (&[Node], &[usize], usize) { match self { @@ -220,245 +219,157 @@ pub(crate) static OUTPUT_CIRCUITS: Circuit = Circuit([ 2, )), AnyBitCircuit::B1(BitCircuit::new( - [Node::new(1, 0, 0), Node::new(33, 1, 0), Node::new(1, 1, 0)], + [Node::new(0, 0, 0), Node::new(33, 1, 0), Node::new(1, 1, 0)], [0, 2], 2, )), AnyBitCircuit::B2(BitCircuit::new( - [Node::new(2, 0, 0), Node::new(34, 1, 0), Node::new(2, 1, 0)], + [Node::new(0, 0, 0), Node::new(34, 1, 0), Node::new(2, 1, 0)], [0, 2], 2, )), AnyBitCircuit::B3(BitCircuit::new( - [Node::new(3, 0, 0), Node::new(35, 1, 0), Node::new(3, 1, 0)], + [Node::new(0, 0, 0), Node::new(35, 1, 0), Node::new(3, 1, 0)], [0, 2], 2, )), AnyBitCircuit::B4(BitCircuit::new( - [Node::new(4, 0, 0), Node::new(36, 1, 0), Node::new(4, 1, 0)], + [Node::new(0, 0, 0), Node::new(36, 1, 0), Node::new(4, 1, 0)], [0, 2], 2, )), AnyBitCircuit::B5(BitCircuit::new( - [Node::new(5, 0, 0), Node::new(37, 1, 0), Node::new(5, 1, 0)], + [Node::new(0, 0, 0), Node::new(37, 1, 0), Node::new(5, 1, 0)], [0, 2], 2, )), AnyBitCircuit::B6(BitCircuit::new( - [Node::new(6, 0, 0), Node::new(38, 1, 0), Node::new(6, 1, 0)], + [Node::new(0, 0, 0), Node::new(38, 1, 0), Node::new(6, 1, 0)], [0, 2], 2, )), AnyBitCircuit::B7(BitCircuit::new( - [Node::new(7, 0, 0), Node::new(39, 1, 0), Node::new(7, 1, 0)], + [Node::new(0, 0, 0), Node::new(39, 1, 0), Node::new(7, 1, 0)], [0, 2], 2, )), AnyBitCircuit::B8(BitCircuit::new( - [Node::new(8, 0, 0), Node::new(40, 1, 0), Node::new(8, 1, 0)], + [Node::new(0, 0, 0), Node::new(40, 1, 0), Node::new(8, 1, 0)], [0, 2], 2, )), AnyBitCircuit::B9(BitCircuit::new( - [Node::new(9, 0, 0), Node::new(41, 1, 0), Node::new(9, 1, 0)], + [Node::new(0, 0, 0), Node::new(41, 1, 0), Node::new(9, 1, 0)], [0, 2], 2, )), AnyBitCircuit::B10(BitCircuit::new( - [ - Node::new(10, 0, 0), - Node::new(42, 1, 0), - Node::new(10, 1, 0), - ], + [Node::new(0, 0, 0), Node::new(42, 1, 0), Node::new(10, 1, 0)], [0, 2], 2, )), AnyBitCircuit::B11(BitCircuit::new( - [ - Node::new(11, 0, 0), - Node::new(43, 1, 0), - Node::new(11, 1, 0), - ], + [Node::new(0, 0, 0), Node::new(43, 1, 0), Node::new(11, 1, 0)], [0, 2], 2, )), AnyBitCircuit::B12(BitCircuit::new( - [ - Node::new(12, 0, 0), - Node::new(44, 1, 0), - Node::new(12, 1, 0), - ], + [Node::new(0, 0, 0), Node::new(44, 1, 0), Node::new(12, 1, 0)], [0, 2], 2, )), AnyBitCircuit::B13(BitCircuit::new( - [ - Node::new(13, 0, 0), - Node::new(45, 1, 0), - Node::new(13, 1, 0), - ], + [Node::new(0, 0, 0), Node::new(45, 1, 0), Node::new(13, 1, 0)], [0, 2], 2, )), AnyBitCircuit::B14(BitCircuit::new( - [ - Node::new(14, 0, 0), - Node::new(46, 1, 0), - Node::new(14, 1, 0), - ], + [Node::new(0, 0, 0), Node::new(46, 1, 0), Node::new(14, 1, 0)], [0, 2], 2, )), AnyBitCircuit::B15(BitCircuit::new( - [ - Node::new(15, 0, 0), - Node::new(47, 1, 0), - Node::new(15, 1, 0), - ], + [Node::new(0, 0, 0), Node::new(47, 1, 0), Node::new(15, 1, 0)], [0, 2], 2, )), AnyBitCircuit::B16(BitCircuit::new( - [ - Node::new(16, 0, 0), - Node::new(48, 1, 0), - Node::new(16, 1, 0), - ], + [Node::new(0, 0, 0), Node::new(48, 1, 0), Node::new(16, 1, 0)], [0, 2], 2, )), AnyBitCircuit::B17(BitCircuit::new( - [ - Node::new(17, 0, 0), - Node::new(49, 1, 0), - Node::new(17, 1, 0), - ], + [Node::new(0, 0, 0), Node::new(49, 1, 0), Node::new(17, 1, 0)], [0, 2], 2, )), AnyBitCircuit::B18(BitCircuit::new( - [ - Node::new(18, 0, 0), - Node::new(50, 1, 0), - Node::new(18, 1, 0), - ], + [Node::new(0, 0, 0), Node::new(50, 1, 0), Node::new(18, 1, 0)], [0, 2], 2, )), AnyBitCircuit::B19(BitCircuit::new( - [ - Node::new(19, 0, 0), - Node::new(51, 1, 0), - Node::new(19, 1, 0), - ], + [Node::new(0, 0, 0), Node::new(51, 1, 0), Node::new(19, 1, 0)], [0, 2], 2, )), AnyBitCircuit::B20(BitCircuit::new( - [ - Node::new(20, 0, 0), - Node::new(52, 1, 0), - Node::new(20, 1, 0), - ], + [Node::new(0, 0, 0), Node::new(52, 1, 0), Node::new(20, 1, 0)], [0, 2], 2, )), AnyBitCircuit::B21(BitCircuit::new( - [ - Node::new(21, 0, 0), - Node::new(53, 1, 0), - Node::new(21, 1, 0), - ], + [Node::new(0, 0, 0), Node::new(53, 1, 0), Node::new(21, 1, 0)], [0, 2], 2, )), AnyBitCircuit::B22(BitCircuit::new( - [ - Node::new(22, 0, 0), - Node::new(54, 1, 0), - Node::new(22, 1, 0), - ], + [Node::new(0, 0, 0), Node::new(54, 1, 0), Node::new(22, 1, 0)], [0, 2], 2, )), AnyBitCircuit::B23(BitCircuit::new( - [ - Node::new(23, 0, 0), - Node::new(55, 1, 0), - Node::new(23, 1, 0), - ], + [Node::new(0, 0, 0), Node::new(55, 1, 0), Node::new(23, 1, 0)], [0, 2], 2, )), AnyBitCircuit::B24(BitCircuit::new( - [ - Node::new(24, 0, 0), - Node::new(56, 1, 0), - Node::new(24, 1, 0), - ], + [Node::new(0, 0, 0), Node::new(56, 1, 0), Node::new(24, 1, 0)], [0, 2], 2, )), AnyBitCircuit::B25(BitCircuit::new( - [ - Node::new(25, 0, 0), - Node::new(57, 1, 0), - Node::new(25, 1, 0), - ], + [Node::new(0, 0, 0), Node::new(57, 1, 0), Node::new(25, 1, 0)], [0, 2], 2, )), AnyBitCircuit::B26(BitCircuit::new( - [ - Node::new(26, 0, 0), - Node::new(58, 1, 0), - Node::new(26, 1, 0), - ], + [Node::new(0, 0, 0), Node::new(58, 1, 0), Node::new(26, 1, 0)], [0, 2], 2, )), AnyBitCircuit::B27(BitCircuit::new( - [ - Node::new(27, 0, 0), - Node::new(59, 1, 0), - Node::new(27, 1, 0), - ], + [Node::new(0, 0, 0), Node::new(59, 1, 0), Node::new(27, 1, 0)], [0, 2], 2, )), AnyBitCircuit::B28(BitCircuit::new( - [ - Node::new(28, 0, 0), - Node::new(60, 1, 0), - Node::new(28, 1, 0), - ], + [Node::new(0, 0, 0), Node::new(60, 1, 0), Node::new(28, 1, 0)], [0, 2], 2, )), AnyBitCircuit::B29(BitCircuit::new( - [ - Node::new(29, 0, 0), - Node::new(61, 1, 0), - Node::new(29, 1, 0), - ], + [Node::new(0, 0, 0), Node::new(61, 1, 0), Node::new(29, 1, 0)], [0, 2], 2, )), AnyBitCircuit::B30(BitCircuit::new( - [ - Node::new(30, 0, 0), - Node::new(62, 1, 0), - Node::new(30, 1, 0), - ], + [Node::new(0, 0, 0), Node::new(62, 1, 0), Node::new(30, 1, 0)], [0, 2], 2, )), AnyBitCircuit::B31(BitCircuit::new( - [ - Node::new(31, 0, 0), - Node::new(63, 1, 0), - Node::new(31, 1, 0), - ], + [Node::new(0, 0, 0), Node::new(63, 1, 0), Node::new(31, 1, 0)], [0, 2], 2, )), diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/or_codegen.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/or_codegen.rs index 52b73e1..8f0efce 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/or_codegen.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/or_codegen.rs @@ -1,8 +1,38 @@ use crate::tfhe::bdd_arithmetic::{BitCircuit, BitCircuitInfo, Circuit, GetBitCircuitInfo, Node}; pub(crate) enum AnyBitCircuit { B0(BitCircuit<3, 2>), + B1(BitCircuit<3, 2>), + B2(BitCircuit<3, 2>), + B3(BitCircuit<3, 2>), + B4(BitCircuit<3, 2>), + B5(BitCircuit<3, 2>), + B6(BitCircuit<3, 2>), + B7(BitCircuit<3, 2>), + B8(BitCircuit<3, 2>), + B9(BitCircuit<3, 2>), + B10(BitCircuit<3, 2>), + B11(BitCircuit<3, 2>), + B12(BitCircuit<3, 2>), + B13(BitCircuit<3, 2>), + B14(BitCircuit<3, 2>), + B15(BitCircuit<3, 2>), + B16(BitCircuit<3, 2>), + B17(BitCircuit<3, 2>), + B18(BitCircuit<3, 2>), + B19(BitCircuit<3, 2>), + B20(BitCircuit<3, 2>), + B21(BitCircuit<3, 2>), + B22(BitCircuit<3, 2>), + B23(BitCircuit<3, 2>), + B24(BitCircuit<3, 2>), + B25(BitCircuit<3, 2>), + B26(BitCircuit<3, 2>), + B27(BitCircuit<3, 2>), + B28(BitCircuit<3, 2>), + B29(BitCircuit<3, 2>), + B30(BitCircuit<3, 2>), + B31(BitCircuit<3, 2>), } - impl BitCircuitInfo for AnyBitCircuit { fn info(&self) -> (&[Node], &[usize], usize) { match self { @@ -11,24 +41,336 @@ impl BitCircuitInfo for AnyBitCircuit { bit_circuit.levels.as_ref(), bit_circuit.max_inter_state, ), + AnyBitCircuit::B1(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B2(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B3(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B4(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B5(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B6(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B7(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B8(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B9(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B10(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B11(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B12(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B13(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B14(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B15(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B16(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B17(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B18(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B19(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B20(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B21(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B22(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B23(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B24(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B25(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B26(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B27(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B28(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B29(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B30(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B31(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), } } } -impl GetBitCircuitInfo for Circuit { +impl GetBitCircuitInfo for Circuit { fn input_size(&self) -> usize { 2 * u32::BITS as usize } fn output_size(&self) -> usize { u32::BITS as usize } - fn get_circuit(&self, _bit: usize) -> (&[Node], &[usize], usize) { - self.0[0].info() + fn get_circuit(&self, bit: usize) -> (&[Node], &[usize], usize) { + self.0[bit].info() } } -pub(crate) static OUTPUT_CIRCUITS: Circuit = Circuit([AnyBitCircuit::B0(BitCircuit::new( - [Node::new(0, 0, 0), Node::new(1, 1, 0), Node::new(0, 1, 1)], - [0, 2], - 2, -))]); +pub(crate) static OUTPUT_CIRCUITS: Circuit = Circuit([ + AnyBitCircuit::B0(BitCircuit::new( + [Node::new(0, 0, 0), Node::new(32, 1, 0), Node::new(0, 1, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B1(BitCircuit::new( + [Node::new(0, 0, 0), Node::new(33, 1, 0), Node::new(1, 1, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B2(BitCircuit::new( + [Node::new(0, 0, 0), Node::new(34, 1, 0), Node::new(2, 1, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B3(BitCircuit::new( + [Node::new(0, 0, 0), Node::new(35, 1, 0), Node::new(3, 1, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B4(BitCircuit::new( + [Node::new(0, 0, 0), Node::new(36, 1, 0), Node::new(4, 1, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B5(BitCircuit::new( + [Node::new(0, 0, 0), Node::new(37, 1, 0), Node::new(5, 1, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B6(BitCircuit::new( + [Node::new(0, 0, 0), Node::new(38, 1, 0), Node::new(6, 1, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B7(BitCircuit::new( + [Node::new(0, 0, 0), Node::new(39, 1, 0), Node::new(7, 1, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B8(BitCircuit::new( + [Node::new(0, 0, 0), Node::new(40, 1, 0), Node::new(8, 1, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B9(BitCircuit::new( + [Node::new(0, 0, 0), Node::new(41, 1, 0), Node::new(9, 1, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B10(BitCircuit::new( + [Node::new(0, 0, 0), Node::new(42, 1, 0), Node::new(10, 1, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B11(BitCircuit::new( + [Node::new(0, 0, 0), Node::new(43, 1, 0), Node::new(11, 1, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B12(BitCircuit::new( + [Node::new(0, 0, 0), Node::new(44, 1, 0), Node::new(12, 1, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B13(BitCircuit::new( + [Node::new(0, 0, 0), Node::new(45, 1, 0), Node::new(13, 1, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B14(BitCircuit::new( + [Node::new(0, 0, 0), Node::new(46, 1, 0), Node::new(14, 1, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B15(BitCircuit::new( + [Node::new(0, 0, 0), Node::new(47, 1, 0), Node::new(15, 1, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B16(BitCircuit::new( + [Node::new(0, 0, 0), Node::new(48, 1, 0), Node::new(16, 1, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B17(BitCircuit::new( + [Node::new(0, 0, 0), Node::new(49, 1, 0), Node::new(17, 1, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B18(BitCircuit::new( + [Node::new(0, 0, 0), Node::new(50, 1, 0), Node::new(18, 1, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B19(BitCircuit::new( + [Node::new(0, 0, 0), Node::new(51, 1, 0), Node::new(19, 1, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B20(BitCircuit::new( + [Node::new(0, 0, 0), Node::new(52, 1, 0), Node::new(20, 1, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B21(BitCircuit::new( + [Node::new(0, 0, 0), Node::new(53, 1, 0), Node::new(21, 1, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B22(BitCircuit::new( + [Node::new(0, 0, 0), Node::new(54, 1, 0), Node::new(22, 1, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B23(BitCircuit::new( + [Node::new(0, 0, 0), Node::new(55, 1, 0), Node::new(23, 1, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B24(BitCircuit::new( + [Node::new(0, 0, 0), Node::new(56, 1, 0), Node::new(24, 1, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B25(BitCircuit::new( + [Node::new(0, 0, 0), Node::new(57, 1, 0), Node::new(25, 1, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B26(BitCircuit::new( + [Node::new(0, 0, 0), Node::new(58, 1, 0), Node::new(26, 1, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B27(BitCircuit::new( + [Node::new(0, 0, 0), Node::new(59, 1, 0), Node::new(27, 1, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B28(BitCircuit::new( + [Node::new(0, 0, 0), Node::new(60, 1, 0), Node::new(28, 1, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B29(BitCircuit::new( + [Node::new(0, 0, 0), Node::new(61, 1, 0), Node::new(29, 1, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B30(BitCircuit::new( + [Node::new(0, 0, 0), Node::new(62, 1, 0), Node::new(30, 1, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B31(BitCircuit::new( + [Node::new(0, 0, 0), Node::new(63, 1, 0), Node::new(31, 1, 1)], + [0, 2], + 2, + )), +]); diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/xor_codegen.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/xor_codegen.rs index 2512c52..9b650bc 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/xor_codegen.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/circuits/u32/xor_codegen.rs @@ -1,8 +1,38 @@ use crate::tfhe::bdd_arithmetic::{BitCircuit, BitCircuitInfo, Circuit, GetBitCircuitInfo, Node}; pub(crate) enum AnyBitCircuit { B0(BitCircuit<3, 2>), + B1(BitCircuit<3, 2>), + B2(BitCircuit<3, 2>), + B3(BitCircuit<3, 2>), + B4(BitCircuit<3, 2>), + B5(BitCircuit<3, 2>), + B6(BitCircuit<3, 2>), + B7(BitCircuit<3, 2>), + B8(BitCircuit<3, 2>), + B9(BitCircuit<3, 2>), + B10(BitCircuit<3, 2>), + B11(BitCircuit<3, 2>), + B12(BitCircuit<3, 2>), + B13(BitCircuit<3, 2>), + B14(BitCircuit<3, 2>), + B15(BitCircuit<3, 2>), + B16(BitCircuit<3, 2>), + B17(BitCircuit<3, 2>), + B18(BitCircuit<3, 2>), + B19(BitCircuit<3, 2>), + B20(BitCircuit<3, 2>), + B21(BitCircuit<3, 2>), + B22(BitCircuit<3, 2>), + B23(BitCircuit<3, 2>), + B24(BitCircuit<3, 2>), + B25(BitCircuit<3, 2>), + B26(BitCircuit<3, 2>), + B27(BitCircuit<3, 2>), + B28(BitCircuit<3, 2>), + B29(BitCircuit<3, 2>), + B30(BitCircuit<3, 2>), + B31(BitCircuit<3, 2>), } - impl BitCircuitInfo for AnyBitCircuit { fn info(&self) -> (&[Node], &[usize], usize) { match self { @@ -11,24 +41,424 @@ impl BitCircuitInfo for AnyBitCircuit { bit_circuit.levels.as_ref(), bit_circuit.max_inter_state, ), + AnyBitCircuit::B1(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B2(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B3(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B4(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B5(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B6(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B7(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B8(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B9(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B10(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B11(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B12(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B13(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B14(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B15(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B16(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B17(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B18(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B19(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B20(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B21(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B22(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B23(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B24(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B25(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B26(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B27(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B28(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B29(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B30(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), + AnyBitCircuit::B31(bit_circuit) => ( + bit_circuit.nodes.as_ref(), + bit_circuit.levels.as_ref(), + bit_circuit.max_inter_state, + ), } } } -impl GetBitCircuitInfo for Circuit { +impl GetBitCircuitInfo for Circuit { fn input_size(&self) -> usize { 2 * u32::BITS as usize } fn output_size(&self) -> usize { u32::BITS as usize } - fn get_circuit(&self, _bit: usize) -> (&[Node], &[usize], usize) { - self.0[0].info() + fn get_circuit(&self, bit: usize) -> (&[Node], &[usize], usize) { + self.0[bit].info() } } -pub(crate) static OUTPUT_CIRCUITS: Circuit = Circuit([AnyBitCircuit::B0(BitCircuit::new( - [Node::new(1, 1, 0), Node::new(1, 0, 1), Node::new(0, 1, 0)], - [0, 2], - 2, -))]); +pub(crate) static OUTPUT_CIRCUITS: Circuit = Circuit([ + AnyBitCircuit::B0(BitCircuit::new( + [Node::new(32, 0, 1), Node::new(32, 1, 0), Node::new(0, 0, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B1(BitCircuit::new( + [Node::new(33, 0, 1), Node::new(33, 1, 0), Node::new(1, 0, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B2(BitCircuit::new( + [Node::new(34, 0, 1), Node::new(34, 1, 0), Node::new(2, 0, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B3(BitCircuit::new( + [Node::new(35, 1, 0), Node::new(35, 0, 1), Node::new(3, 1, 0)], + [0, 2], + 2, + )), + AnyBitCircuit::B4(BitCircuit::new( + [Node::new(36, 1, 0), Node::new(36, 0, 1), Node::new(4, 1, 0)], + [0, 2], + 2, + )), + AnyBitCircuit::B5(BitCircuit::new( + [Node::new(37, 0, 1), Node::new(37, 1, 0), Node::new(5, 0, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B6(BitCircuit::new( + [Node::new(38, 0, 1), Node::new(38, 1, 0), Node::new(6, 0, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B7(BitCircuit::new( + [Node::new(39, 0, 1), Node::new(39, 1, 0), Node::new(7, 0, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B8(BitCircuit::new( + [Node::new(40, 0, 1), Node::new(40, 1, 0), Node::new(8, 0, 1)], + [0, 2], + 2, + )), + AnyBitCircuit::B9(BitCircuit::new( + [Node::new(41, 1, 0), Node::new(41, 0, 1), Node::new(9, 1, 0)], + [0, 2], + 2, + )), + AnyBitCircuit::B10(BitCircuit::new( + [ + Node::new(42, 0, 1), + Node::new(42, 1, 0), + Node::new(10, 0, 1), + ], + [0, 2], + 2, + )), + AnyBitCircuit::B11(BitCircuit::new( + [ + Node::new(43, 1, 0), + Node::new(43, 0, 1), + Node::new(11, 1, 0), + ], + [0, 2], + 2, + )), + AnyBitCircuit::B12(BitCircuit::new( + [ + Node::new(44, 0, 1), + Node::new(44, 1, 0), + Node::new(12, 0, 1), + ], + [0, 2], + 2, + )), + AnyBitCircuit::B13(BitCircuit::new( + [ + Node::new(45, 1, 0), + Node::new(45, 0, 1), + Node::new(13, 1, 0), + ], + [0, 2], + 2, + )), + AnyBitCircuit::B14(BitCircuit::new( + [ + Node::new(46, 0, 1), + Node::new(46, 1, 0), + Node::new(14, 0, 1), + ], + [0, 2], + 2, + )), + AnyBitCircuit::B15(BitCircuit::new( + [ + Node::new(47, 1, 0), + Node::new(47, 0, 1), + Node::new(15, 1, 0), + ], + [0, 2], + 2, + )), + AnyBitCircuit::B16(BitCircuit::new( + [ + Node::new(48, 0, 1), + Node::new(48, 1, 0), + Node::new(16, 0, 1), + ], + [0, 2], + 2, + )), + AnyBitCircuit::B17(BitCircuit::new( + [ + Node::new(49, 1, 0), + Node::new(49, 0, 1), + Node::new(17, 1, 0), + ], + [0, 2], + 2, + )), + AnyBitCircuit::B18(BitCircuit::new( + [ + Node::new(50, 1, 0), + Node::new(50, 0, 1), + Node::new(18, 1, 0), + ], + [0, 2], + 2, + )), + AnyBitCircuit::B19(BitCircuit::new( + [ + Node::new(51, 0, 1), + Node::new(51, 1, 0), + Node::new(19, 0, 1), + ], + [0, 2], + 2, + )), + AnyBitCircuit::B20(BitCircuit::new( + [ + Node::new(52, 1, 0), + Node::new(52, 0, 1), + Node::new(20, 1, 0), + ], + [0, 2], + 2, + )), + AnyBitCircuit::B21(BitCircuit::new( + [ + Node::new(53, 1, 0), + Node::new(53, 0, 1), + Node::new(21, 1, 0), + ], + [0, 2], + 2, + )), + AnyBitCircuit::B22(BitCircuit::new( + [ + Node::new(54, 1, 0), + Node::new(54, 0, 1), + Node::new(22, 1, 0), + ], + [0, 2], + 2, + )), + AnyBitCircuit::B23(BitCircuit::new( + [ + Node::new(55, 1, 0), + Node::new(55, 0, 1), + Node::new(23, 1, 0), + ], + [0, 2], + 2, + )), + AnyBitCircuit::B24(BitCircuit::new( + [ + Node::new(56, 1, 0), + Node::new(56, 0, 1), + Node::new(24, 1, 0), + ], + [0, 2], + 2, + )), + AnyBitCircuit::B25(BitCircuit::new( + [ + Node::new(57, 1, 0), + Node::new(57, 0, 1), + Node::new(25, 1, 0), + ], + [0, 2], + 2, + )), + AnyBitCircuit::B26(BitCircuit::new( + [ + Node::new(58, 1, 0), + Node::new(58, 0, 1), + Node::new(26, 1, 0), + ], + [0, 2], + 2, + )), + AnyBitCircuit::B27(BitCircuit::new( + [ + Node::new(59, 1, 0), + Node::new(59, 0, 1), + Node::new(27, 1, 0), + ], + [0, 2], + 2, + )), + AnyBitCircuit::B28(BitCircuit::new( + [ + Node::new(60, 1, 0), + Node::new(60, 0, 1), + Node::new(28, 1, 0), + ], + [0, 2], + 2, + )), + AnyBitCircuit::B29(BitCircuit::new( + [ + Node::new(61, 0, 1), + Node::new(61, 1, 0), + Node::new(29, 0, 1), + ], + [0, 2], + 2, + )), + AnyBitCircuit::B30(BitCircuit::new( + [ + Node::new(62, 1, 0), + Node::new(62, 0, 1), + Node::new(30, 1, 0), + ], + [0, 2], + 2, + )), + AnyBitCircuit::B31(BitCircuit::new( + [ + Node::new(63, 1, 0), + Node::new(63, 0, 1), + Node::new(31, 1, 0), + ], + [0, 2], + 2, + )), +]);