mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
update BDD circuits & fix non-zero scratch related bug
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
use core::panic;
|
||||
|
||||
use itertools::Itertools;
|
||||
use poulpy_core::{
|
||||
GLWEAdd, GLWECopy, GLWEExternalProduct, GLWESub, ScratchTakeCore,
|
||||
@@ -11,18 +13,17 @@ use poulpy_hal::layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero};
|
||||
use crate::tfhe::bdd_arithmetic::UnsignedInteger;
|
||||
|
||||
pub trait BitCircuitInfo {
|
||||
fn info(&self) -> (&[Node], &[usize], usize);
|
||||
fn info(&self) -> (&[Node], usize);
|
||||
}
|
||||
|
||||
pub trait GetBitCircuitInfo<T: UnsignedInteger> {
|
||||
fn input_size(&self) -> usize;
|
||||
fn output_size(&self) -> usize;
|
||||
fn get_circuit(&self, bit: usize) -> (&[Node], &[usize], usize);
|
||||
fn get_circuit(&self, bit: usize) -> (&[Node], usize);
|
||||
}
|
||||
|
||||
pub(crate) struct BitCircuit<const N: usize, const K: usize> {
|
||||
pub(crate) struct BitCircuit<const N: usize> {
|
||||
pub(crate) nodes: [Node; N],
|
||||
pub(crate) levels: [usize; K],
|
||||
pub(crate) max_inter_state: usize,
|
||||
}
|
||||
|
||||
@@ -62,7 +63,9 @@ where
|
||||
}
|
||||
|
||||
for (i, out_i) in out.iter_mut().enumerate().take(circuit.output_size()) {
|
||||
let (nodes, levels, max_inter_state) = circuit.get_circuit(i);
|
||||
let (nodes, max_inter_state) = circuit.get_circuit(i);
|
||||
|
||||
assert!(nodes.len().is_multiple_of(max_inter_state));
|
||||
|
||||
let (mut level, scratch_1) = scratch.take_glwe_slice(max_inter_state * 2, out_i);
|
||||
|
||||
@@ -76,39 +79,44 @@ where
|
||||
let mut level_ref = level.iter_mut().collect_vec();
|
||||
let (mut prev_level, mut next_level) = level_ref.split_at_mut(max_inter_state);
|
||||
|
||||
for i in 0..levels.len() - 1 {
|
||||
let start: usize = levels[i];
|
||||
let end: usize = levels[i + 1];
|
||||
|
||||
let nodes_lvl: &[Node] = &nodes[start..end];
|
||||
let (all_but_last, last) = nodes.split_at(nodes.len() - max_inter_state);
|
||||
|
||||
for nodes_lvl in all_but_last.chunks_exact(max_inter_state) {
|
||||
for (j, node) in nodes_lvl.iter().enumerate() {
|
||||
if node.low_index == node.high_index {
|
||||
self.glwe_copy(next_level[j], prev_level[node.low_index]);
|
||||
} else {
|
||||
self.cmux(
|
||||
next_level[j],
|
||||
prev_level[node.high_index],
|
||||
prev_level[node.low_index],
|
||||
&inputs[node.input_index].to_ref(),
|
||||
scratch_1,
|
||||
);
|
||||
match node {
|
||||
Node::Cmux(in_idx, hi_idx, lo_idx) => {
|
||||
self.cmux(
|
||||
next_level[j],
|
||||
prev_level[*hi_idx],
|
||||
prev_level[*lo_idx],
|
||||
&inputs[*in_idx].to_ref(),
|
||||
scratch_1,
|
||||
);
|
||||
}
|
||||
Node::Copy => self.glwe_copy(next_level[j], prev_level[j]), /* Update BDD circuits to order Cmux -> Copy -> None so that mem swap can be used */
|
||||
Node::None => {}
|
||||
}
|
||||
}
|
||||
|
||||
(prev_level, next_level) = (next_level, prev_level);
|
||||
}
|
||||
|
||||
// handle last output
|
||||
// there's always only 1 node at last level
|
||||
let node: &Node = nodes.last().unwrap();
|
||||
self.cmux(
|
||||
out_i,
|
||||
prev_level[node.high_index],
|
||||
prev_level[node.low_index],
|
||||
&inputs[node.input_index].to_ref(),
|
||||
scratch_1,
|
||||
);
|
||||
// Last chunck of max_inter_state Nodes is always structured as
|
||||
// [CMUX, NONE, NONE, ..., NONE]
|
||||
match &last[0] {
|
||||
Node::Cmux(in_idx, hi_idx, lo_idx) => {
|
||||
self.cmux(
|
||||
out_i,
|
||||
prev_level[*hi_idx],
|
||||
prev_level[*lo_idx],
|
||||
&inputs[*in_idx].to_ref(),
|
||||
scratch_1,
|
||||
);
|
||||
}
|
||||
_ => {
|
||||
panic!("invalid last node, should be CMUX")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for out_i in out.iter_mut().skip(circuit.output_size()) {
|
||||
@@ -117,40 +125,25 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<const N: usize, const K: usize> BitCircuit<N, K> {
|
||||
pub(crate) const fn new(nodes: [Node; N], levels: [usize; K], max_inter_state: usize) -> Self {
|
||||
impl<const N: usize> BitCircuit<N> {
|
||||
pub(crate) const fn new(nodes: [Node; N], max_inter_state: usize) -> Self {
|
||||
Self {
|
||||
nodes,
|
||||
levels,
|
||||
max_inter_state,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<const N: usize, const K: usize> BitCircuitInfo for BitCircuit<N, K> {
|
||||
fn info(&self) -> (&[Node], &[usize], usize) {
|
||||
(
|
||||
self.nodes.as_ref(),
|
||||
self.levels.as_ref(),
|
||||
self.max_inter_state,
|
||||
)
|
||||
impl<const N: usize> BitCircuitInfo for BitCircuit<N> {
|
||||
fn info(&self) -> (&[Node], usize) {
|
||||
(self.nodes.as_ref(), self.max_inter_state)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Node {
|
||||
input_index: usize,
|
||||
high_index: usize,
|
||||
low_index: usize,
|
||||
}
|
||||
|
||||
impl Node {
|
||||
pub(crate) const fn new(input_index: usize, high_index: usize, low_index: usize) -> Self {
|
||||
Self {
|
||||
input_index,
|
||||
high_index,
|
||||
low_index,
|
||||
}
|
||||
}
|
||||
pub enum Node {
|
||||
Cmux(usize, usize, usize),
|
||||
Copy,
|
||||
None,
|
||||
}
|
||||
|
||||
pub trait Cmux<BE: Backend> {
|
||||
|
||||
Reference in New Issue
Block a user