mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
removed unused field in macro + fixed BDD circuit eval for bits with 0 nodes
This commit is contained in:
@@ -122,7 +122,7 @@ macro_rules! define_bdd_2w_to_1w_trait {
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! impl_bdd_2w_to_1w_trait {
|
||||
($trait_name:ident, $method_name:ident, $ty:ty, $n:literal, $circuit_ty:ty, $output_circuits:path) => {
|
||||
($trait_name:ident, $method_name:ident, $ty:ty, $circuit_ty:ty, $output_circuits:path) => {
|
||||
impl<D: DataMut, BE: Backend> $trait_name<$ty, BE> for FheUint<D, $ty> {
|
||||
fn $method_name<A, M, K, H, B>(
|
||||
&mut self,
|
||||
@@ -160,7 +160,6 @@ impl_bdd_2w_to_1w_trait!(
|
||||
Add,
|
||||
add,
|
||||
u32,
|
||||
32,
|
||||
circuits::u32::add_codegen::AnyBitCircuit,
|
||||
circuits::u32::add_codegen::OUTPUT_CIRCUITS
|
||||
);
|
||||
@@ -169,7 +168,6 @@ impl_bdd_2w_to_1w_trait!(
|
||||
Sub,
|
||||
sub,
|
||||
u32,
|
||||
32,
|
||||
circuits::u32::sub_codegen::AnyBitCircuit,
|
||||
circuits::u32::sub_codegen::OUTPUT_CIRCUITS
|
||||
);
|
||||
@@ -178,7 +176,6 @@ impl_bdd_2w_to_1w_trait!(
|
||||
Sll,
|
||||
sll,
|
||||
u32,
|
||||
32,
|
||||
circuits::u32::sll_codegen::AnyBitCircuit,
|
||||
circuits::u32::sll_codegen::OUTPUT_CIRCUITS
|
||||
);
|
||||
@@ -187,7 +184,6 @@ impl_bdd_2w_to_1w_trait!(
|
||||
Sra,
|
||||
sra,
|
||||
u32,
|
||||
32,
|
||||
circuits::u32::sra_codegen::AnyBitCircuit,
|
||||
circuits::u32::sra_codegen::OUTPUT_CIRCUITS
|
||||
);
|
||||
@@ -196,7 +192,6 @@ impl_bdd_2w_to_1w_trait!(
|
||||
Srl,
|
||||
srl,
|
||||
u32,
|
||||
32,
|
||||
circuits::u32::srl_codegen::AnyBitCircuit,
|
||||
circuits::u32::srl_codegen::OUTPUT_CIRCUITS
|
||||
);
|
||||
@@ -205,7 +200,6 @@ impl_bdd_2w_to_1w_trait!(
|
||||
Slt,
|
||||
slt,
|
||||
u32,
|
||||
1,
|
||||
circuits::u32::slt_codegen::AnyBitCircuit,
|
||||
circuits::u32::slt_codegen::OUTPUT_CIRCUITS
|
||||
);
|
||||
@@ -214,7 +208,6 @@ impl_bdd_2w_to_1w_trait!(
|
||||
Sltu,
|
||||
sltu,
|
||||
u32,
|
||||
1,
|
||||
circuits::u32::sltu_codegen::AnyBitCircuit,
|
||||
circuits::u32::sltu_codegen::OUTPUT_CIRCUITS
|
||||
);
|
||||
@@ -223,7 +216,6 @@ impl_bdd_2w_to_1w_trait!(
|
||||
And,
|
||||
and,
|
||||
u32,
|
||||
32,
|
||||
circuits::u32::and_codegen::AnyBitCircuit,
|
||||
circuits::u32::and_codegen::OUTPUT_CIRCUITS
|
||||
);
|
||||
@@ -232,7 +224,6 @@ impl_bdd_2w_to_1w_trait!(
|
||||
Or,
|
||||
or,
|
||||
u32,
|
||||
32,
|
||||
circuits::u32::or_codegen::AnyBitCircuit,
|
||||
circuits::u32::or_codegen::OUTPUT_CIRCUITS
|
||||
);
|
||||
@@ -241,7 +232,6 @@ impl_bdd_2w_to_1w_trait!(
|
||||
Xor,
|
||||
xor,
|
||||
u32,
|
||||
32,
|
||||
circuits::u32::xor_codegen::AnyBitCircuit,
|
||||
circuits::u32::xor_codegen::OUTPUT_CIRCUITS
|
||||
);
|
||||
|
||||
@@ -78,56 +78,60 @@ where
|
||||
for (i, out_i) in out.iter_mut().enumerate().take(circuit.output_size()) {
|
||||
let (nodes, max_inter_state) = circuit.get_circuit(i);
|
||||
|
||||
assert!(nodes.len().is_multiple_of(max_inter_state));
|
||||
if max_inter_state == 0 {
|
||||
out_i.data_mut().zero();
|
||||
} else {
|
||||
assert!(nodes.len().is_multiple_of(max_inter_state));
|
||||
|
||||
let (mut level, scratch_1) = scratch.take_glwe_slice(max_inter_state * 2, out_i);
|
||||
let (mut level, scratch_1) = scratch.take_glwe_slice(max_inter_state * 2, out_i);
|
||||
|
||||
level.iter_mut().for_each(|ct| ct.data_mut().zero());
|
||||
level.iter_mut().for_each(|ct| ct.data_mut().zero());
|
||||
|
||||
// TODO: implement API on GLWE
|
||||
level[1]
|
||||
.data_mut()
|
||||
.encode_coeff_i64(out_i.base2k().into(), 0, 2, 0, 1);
|
||||
// TODO: implement API on GLWE
|
||||
level[1]
|
||||
.data_mut()
|
||||
.encode_coeff_i64(out_i.base2k().into(), 0, 2, 0, 1);
|
||||
|
||||
let mut level_ref = level.iter_mut().collect_vec();
|
||||
let (mut prev_level, mut next_level) = level_ref.split_at_mut(max_inter_state);
|
||||
let mut level_ref = level.iter_mut().collect_vec();
|
||||
let (mut prev_level, mut next_level) = level_ref.split_at_mut(max_inter_state);
|
||||
|
||||
let (all_but_last, last) = nodes.split_at(nodes.len() - max_inter_state);
|
||||
let (all_but_last, last) = nodes.split_at(nodes.len() - max_inter_state);
|
||||
|
||||
for nodes_lvl in all_but_last.chunks_exact(max_inter_state) {
|
||||
for (j, node) in nodes_lvl.iter().enumerate() {
|
||||
match node {
|
||||
Node::Cmux(in_idx, hi_idx, lo_idx) => {
|
||||
self.cmux(
|
||||
next_level[j],
|
||||
prev_level[*hi_idx],
|
||||
prev_level[*lo_idx],
|
||||
&inputs.get_bit(*in_idx),
|
||||
scratch_1,
|
||||
);
|
||||
for nodes_lvl in all_but_last.chunks_exact(max_inter_state) {
|
||||
for (j, node) in nodes_lvl.iter().enumerate() {
|
||||
match node {
|
||||
Node::Cmux(in_idx, hi_idx, lo_idx) => {
|
||||
self.cmux(
|
||||
next_level[j],
|
||||
prev_level[*hi_idx],
|
||||
prev_level[*lo_idx],
|
||||
&inputs.get_bit(*in_idx),
|
||||
scratch_1,
|
||||
);
|
||||
}
|
||||
Node::Copy => self.glwe_copy(next_level[j], prev_level[j]), /* Update BDD circuits to order Cmux -> Copy -> None so that mem swap can be used */
|
||||
Node::None => {}
|
||||
}
|
||||
Node::Copy => self.glwe_copy(next_level[j], prev_level[j]), /* Update BDD circuits to order Cmux -> Copy -> None so that mem swap can be used */
|
||||
Node::None => {}
|
||||
}
|
||||
|
||||
(prev_level, next_level) = (next_level, prev_level);
|
||||
}
|
||||
|
||||
(prev_level, next_level) = (next_level, prev_level);
|
||||
}
|
||||
|
||||
// Last chunck of max_inter_state Nodes is always structured as
|
||||
// [CMUX, NONE, NONE, ..., NONE]
|
||||
match &last[0] {
|
||||
Node::Cmux(in_idx, hi_idx, lo_idx) => {
|
||||
self.cmux(
|
||||
out_i,
|
||||
prev_level[*hi_idx],
|
||||
prev_level[*lo_idx],
|
||||
&inputs.get_bit(*in_idx),
|
||||
scratch_1,
|
||||
);
|
||||
}
|
||||
_ => {
|
||||
panic!("invalid last node, should be CMUX")
|
||||
// Last chunck of max_inter_state Nodes is always structured as
|
||||
// [CMUX, NONE, NONE, ..., NONE]
|
||||
match &last[0] {
|
||||
Node::Cmux(in_idx, hi_idx, lo_idx) => {
|
||||
self.cmux(
|
||||
out_i,
|
||||
prev_level[*hi_idx],
|
||||
prev_level[*lo_idx],
|
||||
&inputs.get_bit(*in_idx),
|
||||
scratch_1,
|
||||
);
|
||||
}
|
||||
_ => {
|
||||
panic!("invalid last node, should be CMUX")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user