removed unused field in macro + fixed BDD circuit eval for bits with 0 nodes

This commit is contained in:
Jean-Philippe Bossuat
2025-10-31 10:35:23 +01:00
parent 578ed45b9a
commit 2feda14b63
2 changed files with 45 additions and 51 deletions

View File

@@ -122,7 +122,7 @@ macro_rules! define_bdd_2w_to_1w_trait {
#[macro_export] #[macro_export]
macro_rules! impl_bdd_2w_to_1w_trait { macro_rules! impl_bdd_2w_to_1w_trait {
($trait_name:ident, $method_name:ident, $ty:ty, $n:literal, $circuit_ty:ty, $output_circuits:path) => { ($trait_name:ident, $method_name:ident, $ty:ty, $circuit_ty:ty, $output_circuits:path) => {
impl<D: DataMut, BE: Backend> $trait_name<$ty, BE> for FheUint<D, $ty> { impl<D: DataMut, BE: Backend> $trait_name<$ty, BE> for FheUint<D, $ty> {
fn $method_name<A, M, K, H, B>( fn $method_name<A, M, K, H, B>(
&mut self, &mut self,
@@ -160,7 +160,6 @@ impl_bdd_2w_to_1w_trait!(
Add, Add,
add, add,
u32, u32,
32,
circuits::u32::add_codegen::AnyBitCircuit, circuits::u32::add_codegen::AnyBitCircuit,
circuits::u32::add_codegen::OUTPUT_CIRCUITS circuits::u32::add_codegen::OUTPUT_CIRCUITS
); );
@@ -169,7 +168,6 @@ impl_bdd_2w_to_1w_trait!(
Sub, Sub,
sub, sub,
u32, u32,
32,
circuits::u32::sub_codegen::AnyBitCircuit, circuits::u32::sub_codegen::AnyBitCircuit,
circuits::u32::sub_codegen::OUTPUT_CIRCUITS circuits::u32::sub_codegen::OUTPUT_CIRCUITS
); );
@@ -178,7 +176,6 @@ impl_bdd_2w_to_1w_trait!(
Sll, Sll,
sll, sll,
u32, u32,
32,
circuits::u32::sll_codegen::AnyBitCircuit, circuits::u32::sll_codegen::AnyBitCircuit,
circuits::u32::sll_codegen::OUTPUT_CIRCUITS circuits::u32::sll_codegen::OUTPUT_CIRCUITS
); );
@@ -187,7 +184,6 @@ impl_bdd_2w_to_1w_trait!(
Sra, Sra,
sra, sra,
u32, u32,
32,
circuits::u32::sra_codegen::AnyBitCircuit, circuits::u32::sra_codegen::AnyBitCircuit,
circuits::u32::sra_codegen::OUTPUT_CIRCUITS circuits::u32::sra_codegen::OUTPUT_CIRCUITS
); );
@@ -196,7 +192,6 @@ impl_bdd_2w_to_1w_trait!(
Srl, Srl,
srl, srl,
u32, u32,
32,
circuits::u32::srl_codegen::AnyBitCircuit, circuits::u32::srl_codegen::AnyBitCircuit,
circuits::u32::srl_codegen::OUTPUT_CIRCUITS circuits::u32::srl_codegen::OUTPUT_CIRCUITS
); );
@@ -205,7 +200,6 @@ impl_bdd_2w_to_1w_trait!(
Slt, Slt,
slt, slt,
u32, u32,
1,
circuits::u32::slt_codegen::AnyBitCircuit, circuits::u32::slt_codegen::AnyBitCircuit,
circuits::u32::slt_codegen::OUTPUT_CIRCUITS circuits::u32::slt_codegen::OUTPUT_CIRCUITS
); );
@@ -214,7 +208,6 @@ impl_bdd_2w_to_1w_trait!(
Sltu, Sltu,
sltu, sltu,
u32, u32,
1,
circuits::u32::sltu_codegen::AnyBitCircuit, circuits::u32::sltu_codegen::AnyBitCircuit,
circuits::u32::sltu_codegen::OUTPUT_CIRCUITS circuits::u32::sltu_codegen::OUTPUT_CIRCUITS
); );
@@ -223,7 +216,6 @@ impl_bdd_2w_to_1w_trait!(
And, And,
and, and,
u32, u32,
32,
circuits::u32::and_codegen::AnyBitCircuit, circuits::u32::and_codegen::AnyBitCircuit,
circuits::u32::and_codegen::OUTPUT_CIRCUITS circuits::u32::and_codegen::OUTPUT_CIRCUITS
); );
@@ -232,7 +224,6 @@ impl_bdd_2w_to_1w_trait!(
Or, Or,
or, or,
u32, u32,
32,
circuits::u32::or_codegen::AnyBitCircuit, circuits::u32::or_codegen::AnyBitCircuit,
circuits::u32::or_codegen::OUTPUT_CIRCUITS circuits::u32::or_codegen::OUTPUT_CIRCUITS
); );
@@ -241,7 +232,6 @@ impl_bdd_2w_to_1w_trait!(
Xor, Xor,
xor, xor,
u32, u32,
32,
circuits::u32::xor_codegen::AnyBitCircuit, circuits::u32::xor_codegen::AnyBitCircuit,
circuits::u32::xor_codegen::OUTPUT_CIRCUITS circuits::u32::xor_codegen::OUTPUT_CIRCUITS
); );

View File

@@ -78,6 +78,9 @@ where
for (i, out_i) in out.iter_mut().enumerate().take(circuit.output_size()) { for (i, out_i) in out.iter_mut().enumerate().take(circuit.output_size()) {
let (nodes, max_inter_state) = circuit.get_circuit(i); let (nodes, max_inter_state) = circuit.get_circuit(i);
if max_inter_state == 0 {
out_i.data_mut().zero();
} else {
assert!(nodes.len().is_multiple_of(max_inter_state)); assert!(nodes.len().is_multiple_of(max_inter_state));
let (mut level, scratch_1) = scratch.take_glwe_slice(max_inter_state * 2, out_i); let (mut level, scratch_1) = scratch.take_glwe_slice(max_inter_state * 2, out_i);
@@ -131,6 +134,7 @@ where
} }
} }
} }
}
for out_i in out.iter_mut().skip(circuit.output_size()) { for out_i in out.iter_mut().skip(circuit.output_size()) {
out_i.data_mut().zero(); out_i.data_mut().zero();