Add multi-thread bdd eval

This commit is contained in:
Pro7ech
2025-11-12 11:02:37 +01:00
parent 6924ffd94a
commit 1423de1c46
22 changed files with 336 additions and 273 deletions

View File

@@ -1,4 +1,5 @@
use core::panic;
use std::thread;
use itertools::Itertools;
use poulpy_core::{
@@ -6,17 +7,20 @@ use poulpy_core::{
layouts::{GGSWInfos, GGSWPrepared, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef},
};
use poulpy_hal::{
api::{ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftBytesOf},
api::{
ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
VecZnxDftBytesOf,
},
layouts::{Backend, DataMut, Module, Scratch, VecZnxBig, ZnxZero},
};
use crate::tfhe::bdd_arithmetic::{GetGGSWBit, UnsignedInteger};
pub trait BitCircuitInfo {
pub trait BitCircuitInfo: Sync {
fn info(&self) -> (&[Node], usize);
}
pub trait GetBitCircuitInfo<T: UnsignedInteger> {
pub trait GetBitCircuitInfo<T: UnsignedInteger>: Sync {
fn input_size(&self) -> usize;
fn output_size(&self) -> usize;
fn get_circuit(&self, bit: usize) -> (&[Node], usize);
@@ -49,9 +53,34 @@ where
}
}
pub trait ExecuteBDDCircuit<T: UnsignedInteger, BE: Backend> {
fn execute_bdd_circuit<C, G, O>(&self, out: &mut [GLWE<O>], inputs: &G, circuit: &C, scratch: &mut Scratch<BE>)
pub trait ExecuteBDDCircuit<BE: Backend> {
fn execute_bdd_circuit_tmp_bytes<R, G>(&self, res_infos: &R, state_size: usize, ggsw_infos: &G) -> usize
where
R: GLWEInfos,
G: GGSWInfos;
fn execute_bdd_circuit<C, G, O, T: UnsignedInteger>(
&self,
out: &mut [GLWE<O>],
inputs: &G,
circuit: &C,
scratch: &mut Scratch<BE>,
) where
G: GetGGSWBit<BE> + BitSize,
C: GetBitCircuitInfo<T>,
O: DataMut,
{
self.execute_bdd_circuit_multi_thread(1, out, inputs, circuit, scratch);
}
fn execute_bdd_circuit_multi_thread<C, G, O, T: UnsignedInteger>(
&self,
threads: usize,
out: &mut [GLWE<O>],
inputs: &G,
circuit: &C,
scratch: &mut Scratch<BE>,
) where
G: GetGGSWBit<BE> + BitSize,
C: GetBitCircuitInfo<T>,
O: DataMut;
@@ -61,13 +90,27 @@ pub trait BitSize {
fn bit_size(&self) -> usize;
}
impl<T: UnsignedInteger, BE: Backend> ExecuteBDDCircuit<T, BE> for Module<BE>
impl<BE: Backend> ExecuteBDDCircuit<BE> for Module<BE>
where
Self: Cmux<BE> + GLWECopy,
Scratch<BE>: ScratchTakeCore<BE>,
{
fn execute_bdd_circuit<C, G, O>(&self, out: &mut [GLWE<O>], inputs: &G, circuit: &C, scratch: &mut Scratch<BE>)
fn execute_bdd_circuit_tmp_bytes<R, G>(&self, res_infos: &R, state_size: usize, ggsw_infos: &G) -> usize
where
R: GLWEInfos,
G: GGSWInfos,
{
2 * state_size * GLWE::bytes_of_from_infos(res_infos) + self.cmux_tmp_bytes(res_infos, res_infos, ggsw_infos)
}
fn execute_bdd_circuit_multi_thread<C, G, O, T: UnsignedInteger>(
&self,
threads: usize,
out: &mut [GLWE<O>],
inputs: &G,
circuit: &C,
scratch: &mut Scratch<BE>,
) where
G: GetGGSWBit<BE> + BitSize,
C: GetBitCircuitInfo<T>,
O: DataMut,
@@ -88,66 +131,43 @@ where
);
}
for (i, out_i) in out.iter_mut().enumerate().take(circuit.output_size()) {
let (nodes, max_inter_state) = circuit.get_circuit(i);
let mut max_state_size = 0;
for i in 0..circuit.output_size() {
let (_, state_size) = circuit.get_circuit(i);
max_state_size = max_state_size.max(state_size)
}
if max_inter_state == 0 {
out_i.data_mut().zero();
} else {
assert!(nodes.len().is_multiple_of(max_inter_state));
let scratch_thread_size: usize = self.execute_bdd_circuit_tmp_bytes(&out[0], max_state_size, &inputs.get_bit(0));
let (mut level, scratch_1) = scratch.take_glwe_slice(max_inter_state * 2, out_i);
assert!(
scratch.available() >= threads * scratch_thread_size,
"scratch.available(): {} < threads:{threads} * scratch_thread_size: {scratch_thread_size}",
scratch.available()
);
level.iter_mut().for_each(|ct| ct.data_mut().zero());
let (mut scratches, _) = scratch.split_mut(threads, scratch_thread_size);
// TODO: implement API on GLWE
level[1]
.data_mut()
.encode_coeff_i64(out_i.base2k().into(), 0, 2, 0, 1);
let chunk_size: usize = circuit.output_size().div_ceil(threads);
let mut level_ref = level.iter_mut().collect_vec();
let (mut prev_level, mut next_level) = level_ref.split_at_mut(max_inter_state);
thread::scope(|scope| {
for (scratch_thread, out_chunk) in scratches
.iter_mut()
.zip(out[..circuit.output_size()].chunks_mut(chunk_size))
{
// Capture chunk + thread scratch by move
scope.spawn(move || {
for (idx, out_i) in out_chunk.iter_mut().enumerate() {
let (nodes, state_size) = circuit.get_circuit(idx);
let (all_but_last, last) = nodes.split_at(nodes.len() - max_inter_state);
for nodes_lvl in all_but_last.chunks_exact(max_inter_state) {
for (j, node) in nodes_lvl.iter().enumerate() {
match node {
Node::Cmux(in_idx, hi_idx, lo_idx) => {
self.cmux(
next_level[j],
prev_level[*hi_idx],
prev_level[*lo_idx],
&inputs.get_bit(*in_idx),
scratch_1,
);
}
Node::Copy => self.glwe_copy(next_level[j], prev_level[j]), /* Update BDD circuits to order Cmux -> Copy -> None so that mem swap can be used */
Node::None => {}
if state_size == 0 {
out_i.data_mut().zero();
} else {
eval_level(self, out_i, inputs, nodes, state_size, *scratch_thread);
}
}
(prev_level, next_level) = (next_level, prev_level);
}
// Last chunck of max_inter_state Nodes is always structured as
// [CMUX, NONE, NONE, ..., NONE]
match &last[0] {
Node::Cmux(in_idx, hi_idx, lo_idx) => {
self.cmux(
out_i,
prev_level[*hi_idx],
prev_level[*lo_idx],
&inputs.get_bit(*in_idx),
scratch_1,
);
}
_ => {
panic!("invalid last node, should be CMUX")
}
}
});
}
}
});
for out_i in out.iter_mut().skip(circuit.output_size()) {
out_i.data_mut().zero();
@@ -155,6 +175,74 @@ where
}
}
/// Evaluates one BDD circuit and writes its result into `res`.
///
/// `nodes` is a flattened sequence of levels, each made of exactly
/// `state_size` nodes. Every level except the last is evaluated into a double
/// buffer of `2 * state_size` temporary GLWE ciphertexts taken from `scratch`
/// (the two halves are swapped after each level). The last level is evaluated
/// directly into `res`, using only its first node.
///
/// # Arguments
///
/// * `module` - provides the `cmux` and `glwe_copy` operations.
/// * `res` - receives the output; it is also handed to `take_glwe_slice` as
///   the layout descriptor for the temporaries.
/// * `inputs` - source of the selector bits (`inputs.get_bit(i)`) consumed by
///   `Node::Cmux` nodes.
/// * `nodes` - the circuit, `state_size` nodes per level.
/// * `state_size` - number of nodes per level.
/// * `scratch` - scratch space holding the temporaries and whatever `cmux`
///   needs on top of them.
///
/// # Panics
///
/// Panics if `state_size` is zero, if `nodes` is empty, if `nodes.len()` is
/// not a multiple of `state_size`, if a `Cmux` node references a state index
/// `>= state_size`, or if the first node of the last level is not a
/// `Node::Cmux`.
fn eval_level<M, R, G, BE: Backend>(
    module: &M,
    res: &mut R,
    inputs: &G,
    nodes: &[Node],
    state_size: usize,
    scratch: &mut Scratch<BE>,
) where
    M: Cmux<BE> + GLWECopy,
    R: GLWEToMut,
    G: GetGGSWBit<BE> + BitSize,
    Scratch<BE>: ScratchTakeCore<BE>,
{
    // Explicit preconditions: without them a zero `state_size` fails on the
    // `level[1]` index below and an empty `nodes` underflows
    // `nodes.len() - state_size`, both with unhelpful diagnostics.
    assert!(state_size > 0, "state_size must be greater than zero");
    assert!(
        !nodes.is_empty(),
        "nodes must contain at least one level of state_size nodes"
    );
    assert!(nodes.len().is_multiple_of(state_size));

    let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();

    // Double buffer: `state_size` ciphertexts for the previous level and
    // `state_size` for the level being computed.
    let (mut level, scratch_1) = scratch.take_glwe_slice(state_size * 2, res);

    level.iter_mut().for_each(|ct| ct.data_mut().zero());

    // TODO: implement API on GLWE
    // NOTE(review): presumably seeds the second temporary with the constant 1
    // (the other temporaries stay zero) -- confirm against `encode_coeff_i64`.
    level[1]
        .data_mut()
        .encode_coeff_i64(res.base2k().into(), 0, 2, 0, 1);

    let mut level_ref: Vec<&mut GLWE<&mut [u8]>> = level.iter_mut().collect_vec();
    let (mut prev_level, mut next_level) = level_ref.split_at_mut(state_size);

    // Cannot underflow: `nodes` is non-empty and its length is a multiple of
    // `state_size`, hence at least `state_size`.
    let (all_but_last, last) = nodes.split_at(nodes.len() - state_size);

    for nodes_lvl in all_but_last.chunks_exact(state_size) {
        for (j, node) in nodes_lvl.iter().enumerate() {
            match node {
                Node::Cmux(in_idx, hi_idx, lo_idx) => {
                    module.cmux(
                        next_level[j],
                        prev_level[*hi_idx],
                        prev_level[*lo_idx],
                        &inputs.get_bit(*in_idx),
                        scratch_1,
                    );
                }
                // TODO: update BDD circuits to order Cmux -> Copy -> None so
                // that a mem swap can be used instead of a copy.
                Node::Copy => module.glwe_copy(next_level[j], prev_level[j]),
                Node::None => {}
            }
        }

        // The level just written becomes the input of the next one.
        (prev_level, next_level) = (next_level, prev_level);
    }

    // Last chunk of `state_size` nodes is always structured as
    // [CMUX, NONE, NONE, ..., NONE]; only its first node is evaluated.
    match &last[0] {
        Node::Cmux(in_idx, hi_idx, lo_idx) => {
            module.cmux(
                res,
                prev_level[*hi_idx],
                prev_level[*lo_idx],
                &inputs.get_bit(*in_idx),
                scratch_1,
            );
        }
        Node::Copy | Node::None => {
            panic!("invalid last node, should be CMUX")
        }
    }
}
impl<const N: usize> BitCircuit<N> {
pub const fn new(nodes: [Node; N], max_inter_state: usize) -> Self {
Self {