diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs index 09564b5..b56f907 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs @@ -122,7 +122,7 @@ macro_rules! define_bdd_2w_to_1w_trait { #[macro_export] macro_rules! impl_bdd_2w_to_1w_trait { - ($trait_name:ident, $method_name:ident, $ty:ty, $n:literal, $circuit_ty:ty, $output_circuits:path) => { + ($trait_name:ident, $method_name:ident, $ty:ty, $circuit_ty:ty, $output_circuits:path) => { impl $trait_name<$ty, BE> for FheUint { fn $method_name( &mut self, @@ -160,7 +160,6 @@ impl_bdd_2w_to_1w_trait!( Add, add, u32, - 32, circuits::u32::add_codegen::AnyBitCircuit, circuits::u32::add_codegen::OUTPUT_CIRCUITS ); @@ -169,7 +168,6 @@ impl_bdd_2w_to_1w_trait!( Sub, sub, u32, - 32, circuits::u32::sub_codegen::AnyBitCircuit, circuits::u32::sub_codegen::OUTPUT_CIRCUITS ); @@ -178,7 +176,6 @@ impl_bdd_2w_to_1w_trait!( Sll, sll, u32, - 32, circuits::u32::sll_codegen::AnyBitCircuit, circuits::u32::sll_codegen::OUTPUT_CIRCUITS ); @@ -187,7 +184,6 @@ impl_bdd_2w_to_1w_trait!( Sra, sra, u32, - 32, circuits::u32::sra_codegen::AnyBitCircuit, circuits::u32::sra_codegen::OUTPUT_CIRCUITS ); @@ -196,7 +192,6 @@ impl_bdd_2w_to_1w_trait!( Srl, srl, u32, - 32, circuits::u32::srl_codegen::AnyBitCircuit, circuits::u32::srl_codegen::OUTPUT_CIRCUITS ); @@ -205,7 +200,6 @@ impl_bdd_2w_to_1w_trait!( Slt, slt, u32, - 1, circuits::u32::slt_codegen::AnyBitCircuit, circuits::u32::slt_codegen::OUTPUT_CIRCUITS ); @@ -214,7 +208,6 @@ impl_bdd_2w_to_1w_trait!( Sltu, sltu, u32, - 1, circuits::u32::sltu_codegen::AnyBitCircuit, circuits::u32::sltu_codegen::OUTPUT_CIRCUITS ); @@ -223,7 +216,6 @@ impl_bdd_2w_to_1w_trait!( And, and, u32, - 32, circuits::u32::and_codegen::AnyBitCircuit, circuits::u32::and_codegen::OUTPUT_CIRCUITS ); @@ -232,7 +224,6 @@ impl_bdd_2w_to_1w_trait!( Or, or, u32, - 32, circuits::u32::or_codegen::AnyBitCircuit, circuits::u32::or_codegen::OUTPUT_CIRCUITS ); @@ -241,7 +232,6 @@ impl_bdd_2w_to_1w_trait!( Xor, xor, u32, - 32, circuits::u32::xor_codegen::AnyBitCircuit, circuits::u32::xor_codegen::OUTPUT_CIRCUITS ); diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs index b845a22..5369903 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs @@ -78,56 +78,60 @@ where for (i, out_i) in out.iter_mut().enumerate().take(circuit.output_size()) { let (nodes, max_inter_state) = circuit.get_circuit(i); - assert!(nodes.len().is_multiple_of(max_inter_state)); + if max_inter_state == 0 { + out_i.data_mut().zero(); + } else { + assert!(nodes.len().is_multiple_of(max_inter_state)); - let (mut level, scratch_1) = scratch.take_glwe_slice(max_inter_state * 2, out_i); + let (mut level, scratch_1) = scratch.take_glwe_slice(max_inter_state * 2, out_i); - level.iter_mut().for_each(|ct| ct.data_mut().zero()); + level.iter_mut().for_each(|ct| ct.data_mut().zero()); - // TODO: implement API on GLWE - level[1] - .data_mut() - .encode_coeff_i64(out_i.base2k().into(), 0, 2, 0, 1); + // TODO: implement API on GLWE + level[1] + .data_mut() + .encode_coeff_i64(out_i.base2k().into(), 0, 2, 0, 1); - let mut level_ref = level.iter_mut().collect_vec(); - let (mut prev_level, mut next_level) = level_ref.split_at_mut(max_inter_state); + let mut level_ref = level.iter_mut().collect_vec(); + let (mut prev_level, mut next_level) = level_ref.split_at_mut(max_inter_state); - let (all_but_last, last) = nodes.split_at(nodes.len() - max_inter_state); + let (all_but_last, last) = nodes.split_at(nodes.len() - max_inter_state); - for nodes_lvl in all_but_last.chunks_exact(max_inter_state) { - for (j, node) in nodes_lvl.iter().enumerate() { - match node { - Node::Cmux(in_idx, hi_idx, lo_idx) => { - self.cmux( - next_level[j], - prev_level[*hi_idx], - prev_level[*lo_idx], - &inputs.get_bit(*in_idx), - scratch_1, - ); + for nodes_lvl in all_but_last.chunks_exact(max_inter_state) { + for (j, node) in nodes_lvl.iter().enumerate() { + match node { + Node::Cmux(in_idx, hi_idx, lo_idx) => { + self.cmux( + next_level[j], + prev_level[*hi_idx], + prev_level[*lo_idx], + &inputs.get_bit(*in_idx), + scratch_1, + ); + } + Node::Copy => self.glwe_copy(next_level[j], prev_level[j]), /* Update BDD circuits to order Cmux -> Copy -> None so that mem swap can be used */ + Node::None => {} } - Node::Copy => self.glwe_copy(next_level[j], prev_level[j]), /* Update BDD circuits to order Cmux -> Copy -> None so that mem swap can be used */ - Node::None => {} } + + (prev_level, next_level) = (next_level, prev_level); } - (prev_level, next_level) = (next_level, prev_level); - } - - // Last chunck of max_inter_state Nodes is always structured as - // [CMUX, NONE, NONE, ..., NONE] - match &last[0] { - Node::Cmux(in_idx, hi_idx, lo_idx) => { - self.cmux( - out_i, - prev_level[*hi_idx], - prev_level[*lo_idx], - &inputs.get_bit(*in_idx), - scratch_1, - ); - } - _ => { - panic!("invalid last node, should be CMUX") + // Last chunck of max_inter_state Nodes is always structured as + // [CMUX, NONE, NONE, ..., NONE] + match &last[0] { + Node::Cmux(in_idx, hi_idx, lo_idx) => { + self.cmux( + out_i, + prev_level[*hi_idx], + prev_level[*lo_idx], + &inputs.get_bit(*in_idx), + scratch_1, + ); + } + _ => { + panic!("invalid last node, should be CMUX") + } } } }