removed unused field in macro + fixed BDD circuit eval for bits with 0 nodes

This commit is contained in:
Jean-Philippe Bossuat
2025-10-31 10:35:23 +01:00
parent 578ed45b9a
commit 2feda14b63
2 changed files with 45 additions and 51 deletions

View File

@@ -78,56 +78,60 @@ where
for (i, out_i) in out.iter_mut().enumerate().take(circuit.output_size()) {
let (nodes, max_inter_state) = circuit.get_circuit(i);
assert!(nodes.len().is_multiple_of(max_inter_state));
if max_inter_state == 0 {
out_i.data_mut().zero();
} else {
assert!(nodes.len().is_multiple_of(max_inter_state));
let (mut level, scratch_1) = scratch.take_glwe_slice(max_inter_state * 2, out_i);
let (mut level, scratch_1) = scratch.take_glwe_slice(max_inter_state * 2, out_i);
level.iter_mut().for_each(|ct| ct.data_mut().zero());
level.iter_mut().for_each(|ct| ct.data_mut().zero());
// TODO: implement API on GLWE
level[1]
.data_mut()
.encode_coeff_i64(out_i.base2k().into(), 0, 2, 0, 1);
// TODO: implement API on GLWE
level[1]
.data_mut()
.encode_coeff_i64(out_i.base2k().into(), 0, 2, 0, 1);
let mut level_ref = level.iter_mut().collect_vec();
let (mut prev_level, mut next_level) = level_ref.split_at_mut(max_inter_state);
let mut level_ref = level.iter_mut().collect_vec();
let (mut prev_level, mut next_level) = level_ref.split_at_mut(max_inter_state);
let (all_but_last, last) = nodes.split_at(nodes.len() - max_inter_state);
let (all_but_last, last) = nodes.split_at(nodes.len() - max_inter_state);
for nodes_lvl in all_but_last.chunks_exact(max_inter_state) {
for (j, node) in nodes_lvl.iter().enumerate() {
match node {
Node::Cmux(in_idx, hi_idx, lo_idx) => {
self.cmux(
next_level[j],
prev_level[*hi_idx],
prev_level[*lo_idx],
&inputs.get_bit(*in_idx),
scratch_1,
);
for nodes_lvl in all_but_last.chunks_exact(max_inter_state) {
for (j, node) in nodes_lvl.iter().enumerate() {
match node {
Node::Cmux(in_idx, hi_idx, lo_idx) => {
self.cmux(
next_level[j],
prev_level[*hi_idx],
prev_level[*lo_idx],
&inputs.get_bit(*in_idx),
scratch_1,
);
}
Node::Copy => self.glwe_copy(next_level[j], prev_level[j]), /* Update BDD circuits to order Cmux -> Copy -> None so that mem swap can be used */
Node::None => {}
}
Node::Copy => self.glwe_copy(next_level[j], prev_level[j]), /* Update BDD circuits to order Cmux -> Copy -> None so that mem swap can be used */
Node::None => {}
}
(prev_level, next_level) = (next_level, prev_level);
}
(prev_level, next_level) = (next_level, prev_level);
}
// Last chunck of max_inter_state Nodes is always structured as
// [CMUX, NONE, NONE, ..., NONE]
match &last[0] {
Node::Cmux(in_idx, hi_idx, lo_idx) => {
self.cmux(
out_i,
prev_level[*hi_idx],
prev_level[*lo_idx],
&inputs.get_bit(*in_idx),
scratch_1,
);
}
_ => {
panic!("invalid last node, should be CMUX")
// Last chunck of max_inter_state Nodes is always structured as
// [CMUX, NONE, NONE, ..., NONE]
match &last[0] {
Node::Cmux(in_idx, hi_idx, lo_idx) => {
self.cmux(
out_i,
prev_level[*hi_idx],
prev_level[*lo_idx],
&inputs.get_bit(*in_idx),
scratch_1,
);
}
_ => {
panic!("invalid last node, should be CMUX")
}
}
}
}