use core::panic;
use std::thread;

use itertools::Itertools;
use poulpy_core::{
    GLWECopy, GLWEExternalProductInternal, GLWENormalize, GLWESub, ScratchTakeCore,
    layouts::{
        GGSWInfos, GGSWPrepared, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos,
        prepared::GGSWPreparedToRef,
    },
};
use poulpy_hal::{
    api::{
        ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize,
        VecZnxBigNormalizeTmpBytes, VecZnxDftBytesOf,
    },
    layouts::{Backend, DataMut, Module, Scratch, VecZnxBig, ZnxZero},
};

use crate::tfhe::bdd_arithmetic::GetGGSWBit;

pub trait BitCircuitInfo: Sync {
    fn info(&self) -> (&[Node], usize);
}

pub trait GetBitCircuitInfo: Sync {
    fn input_size(&self) -> usize;
    fn output_size(&self) -> usize;
    fn get_circuit(&self, bit: usize) -> (&[Node], usize);
}

pub struct BitCircuit<const N: usize> {
    pub nodes: [Node; N],
    pub max_inter_state: usize,
}

pub trait BitCircuitFamily {
    const INPUT_BITS: usize;
    const OUTPUT_BITS: usize;
}

pub struct Circuit<C, const N: usize>(pub [C; N]);

impl<C, const N: usize> GetBitCircuitInfo for Circuit<C, N>
where
    C: BitCircuitInfo + BitCircuitFamily,
{
    fn input_size(&self) -> usize {
        C::INPUT_BITS
    }

    fn output_size(&self) -> usize {
        C::OUTPUT_BITS
    }

    fn get_circuit(&self, bit: usize) -> (&[Node], usize) {
        self.0[bit].info()
    }
}

pub trait ExecuteBDDCircuit<BE: Backend> {
    fn execute_bdd_circuit_tmp_bytes<R, G>(&self, res_infos: &R, state_size: usize, ggsw_infos: &G) -> usize
    where
        R: GLWEInfos,
        G: GGSWInfos;

    fn execute_bdd_circuit<O, G, C>(&self, out: &mut [GLWE<O>], inputs: &G, circuit: &C, scratch: &mut Scratch<BE>)
    where
        G: GetGGSWBit<BE> + BitSize,
        C: GetBitCircuitInfo,
        O: DataMut,
    {
        self.execute_bdd_circuit_multi_thread(1, out, inputs, circuit, scratch);
    }

    fn execute_bdd_circuit_multi_thread<O, G, C>(
        &self,
        threads: usize,
        out: &mut [GLWE<O>],
        inputs: &G,
        circuit: &C,
        scratch: &mut Scratch<BE>,
    ) where
        G: GetGGSWBit<BE> + BitSize,
        C: GetBitCircuitInfo,
        O: DataMut;
}

pub trait BitSize {
    fn bit_size(&self) -> usize;
}

impl<BE: Backend> ExecuteBDDCircuit<BE> for Module<BE>
where
    Self: Cmux<BE> + GLWECopy,
    Scratch<BE>: ScratchTakeCore<BE>,
{
    fn execute_bdd_circuit_tmp_bytes<R, G>(&self, res_infos: &R, state_size: usize, ggsw_infos: &G) -> usize
    where
        R: GLWEInfos,
        G: GGSWInfos,
    {
        2 * state_size * GLWE::bytes_of_from_infos(res_infos) + self.cmux_tmp_bytes(res_infos, res_infos, ggsw_infos)
    }

    fn execute_bdd_circuit_multi_thread<O, G, C>(
        &self,
        threads: usize,
        out: &mut [GLWE<O>],
        inputs: &G,
        circuit: &C,
        scratch: &mut Scratch<BE>,
    ) where
        G: GetGGSWBit<BE> + BitSize,
        C: GetBitCircuitInfo,
        O: DataMut,
    {
        #[cfg(debug_assertions)]
        {
            assert!(
                inputs.bit_size() >= circuit.input_size(),
                "inputs.bit_size(): {} < circuit.input_size(): {}",
                inputs.bit_size(),
                circuit.input_size()
            );
            assert!(
                out.len() >= circuit.output_size(),
                "out.len(): {} < circuit.output_size(): {}",
                out.len(),
                circuit.output_size()
            );
        }

        let mut max_state_size = 0;
        for i in 0..circuit.output_size() {
            let (_, state_size) = circuit.get_circuit(i);
            max_state_size = max_state_size.max(state_size)
        }

        let scratch_thread_size: usize =
            self.execute_bdd_circuit_tmp_bytes(&out[0], max_state_size, &inputs.get_bit(0));

        assert!(
            scratch.available() >= threads * scratch_thread_size,
            "scratch.available(): {} < threads: {threads} * scratch_thread_size: {scratch_thread_size}",
            scratch.available()
        );

        let (mut scratches, _) = scratch.split_mut(threads, scratch_thread_size);

        let chunk_size: usize = circuit.output_size().div_ceil(threads);

        thread::scope(|scope| {
            for (chunk_idx, (scratch_thread, out_chunk)) in scratches
                .iter_mut()
                .zip(out[..circuit.output_size()].chunks_mut(chunk_size))
                .enumerate()
            {
                // Capture chunk + thread scratch by move
                scope.spawn(move || {
                    for (idx, out_i) in out_chunk.iter_mut().enumerate() {
                        // Map the chunk-local index back to the global output-bit index.
                        let bit_idx = chunk_idx * chunk_size + idx;
                        let (nodes, state_size) = circuit.get_circuit(bit_idx);
                        if state_size == 0 {
                            out_i.data_mut().zero();
                        } else {
                            eval_level(self, out_i, inputs, nodes, state_size, *scratch_thread);
                        }
                    }
                });
            }
        });

        for out_i in out.iter_mut().skip(circuit.output_size()) {
            out_i.data_mut().zero();
        }
    }
}

fn eval_level<BE: Backend, M, R, G>(
    module: &M,
    res: &mut R,
    inputs: &G,
    nodes: &[Node],
    state_size: usize,
    scratch: &mut Scratch<BE>,
) where
    M: Cmux<BE> + GLWECopy,
    R: GLWEToMut,
    G: GetGGSWBit<BE> + BitSize,
    Scratch<BE>: ScratchTakeCore<BE>,
{
    assert!(nodes.len().is_multiple_of(state_size));

    let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();

    let (mut level, scratch_1) = scratch.take_glwe_slice(state_size * 2, res);
    level.iter_mut().for_each(|ct| ct.data_mut().zero());

    // TODO: implement API on GLWE
    level[1]
        .data_mut()
        .encode_coeff_i64(res.base2k().into(), 0, 2, 0, 1);

    let mut level_ref: Vec<&mut GLWE<&mut [u8]>> = level.iter_mut().collect_vec();
    let (mut prev_level, mut next_level) = level_ref.split_at_mut(state_size);

    let (all_but_last, last) = nodes.split_at(nodes.len() - state_size);

    for nodes_lvl in all_but_last.chunks_exact(state_size) {
        for (j, node) in nodes_lvl.iter().enumerate() {
            match node {
                Node::Cmux(in_idx, hi_idx, lo_idx) => {
                    module.cmux(
                        next_level[j],
                        prev_level[*hi_idx],
                        prev_level[*lo_idx],
                        &inputs.get_bit(*in_idx),
                        scratch_1,
                    );
                }
                Node::Copy => module.glwe_copy(next_level[j], prev_level[j]),
                /* Update BDD circuits to order Cmux -> Copy -> None so that mem swap can be used */
                Node::None => {}
            }
        }
        (prev_level, next_level) = (next_level, prev_level);
    }

    // The last chunk of max_inter_state nodes is always structured as
    // [CMUX, NONE, NONE, ..., NONE]
    match &last[0] {
        Node::Cmux(in_idx, hi_idx, lo_idx) => {
            module.cmux(
                res,
                prev_level[*hi_idx],
                prev_level[*lo_idx],
                &inputs.get_bit(*in_idx),
                scratch_1,
            );
        }
        _ => {
            panic!("invalid last node, should be CMUX")
        }
    }
}

impl<const N: usize> BitCircuit<N> {
    pub const fn new(nodes: [Node; N], max_inter_state: usize) -> Self {
        Self {
            nodes,
            max_inter_state,
        }
    }
}

impl<const N: usize> BitCircuitInfo for BitCircuit<N> {
    fn info(&self) -> (&[Node], usize) {
        (self.nodes.as_ref(), self.max_inter_state)
    }
}

#[derive(Debug)]
pub enum Node {
    Cmux(usize, usize, usize),
    Copy,
    None,
}

pub trait Cmux<BE: Backend>
where
    Self: Sized
        + GLWEExternalProductInternal
        + GLWESub
        + VecZnxBigAddSmallInplace
        + GLWENormalize
        + VecZnxDftBytesOf
        + VecZnxBigNormalize
        + VecZnxBigNormalizeTmpBytes,
{
    fn cmux_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
    where
        R: GLWEInfos,
        A: GLWEInfos,
        B: GGSWInfos,
    {
        let res_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), b_infos.size());
        res_dft
            + self
                .glwe_external_product_internal_tmp_bytes(res_infos, a_infos, b_infos)
                .max(self.vec_znx_big_normalize_tmp_bytes())
    }

    fn cmux<R, T, F, S>(&self, res: &mut R, t: &T, f: &F, s: &S, scratch: &mut Scratch<BE>)
    where
        R: GLWEToMut,
        T: GLWEToRef,
        F: GLWEToRef,
        S: GGSWPreparedToRef<BE> + GGSWInfos,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let s: &GGSWPrepared<&[u8], BE> = &s.to_ref();
        let f: GLWE<&[u8]> = f.to_ref();

        let res_base2k: usize = res.base2k().into();
        let ggsw_base2k: usize = s.base2k().into();

        self.glwe_sub(res, t, &f);

        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), s.size()); // TODO: optimise

        let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_external_product_internal(res_dft, res, s, scratch_1);

        for j in 0..(res.rank() + 1).into() {
            self.vec_znx_big_add_small_inplace(&mut res_big, j, f.data(), j);
            self.vec_znx_big_normalize(
                res_base2k,
                res.data_mut(),
                j,
                ggsw_base2k,
                &res_big,
                j,
                scratch_1,
            );
        }
    }

    fn cmux_inplace<R, A, S>(&self, res: &mut R, a: &A, s: &S, scratch: &mut Scratch<BE>)
    where
        R: GLWEToMut,
        A: GLWEToRef,
        S: GGSWPreparedToRef<BE> + GGSWInfos,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let s: &GGSWPrepared<&[u8], BE> = &s.to_ref();
        let a: GLWE<&[u8]> = a.to_ref();

        let res_base2k: usize = res.base2k().into();
        let ggsw_base2k: usize = s.base2k().into();

        self.glwe_sub_inplace(res, &a);

        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), s.size()); // TODO: optimise

        let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_external_product_internal(res_dft, res, s, scratch_1);

        for j in 0..(res.rank() + 1).into() {
            self.vec_znx_big_add_small_inplace(&mut res_big, j, a.data(), j);
            self.vec_znx_big_normalize(
                res_base2k,
                res.data_mut(),
                j,
                ggsw_base2k,
                &res_big,
                j,
                scratch_1,
            );
        }
    }
}

impl<BE: Backend> Cmux<BE> for Module<BE>
where
    Self: Sized
        + GLWEExternalProductInternal
        + GLWESub
        + VecZnxBigAddSmallInplace
        + GLWENormalize
        + VecZnxDftBytesOf
        + VecZnxBigNormalize
        + VecZnxBigNormalizeTmpBytes,
    Scratch<BE>: ScratchTakeCore<BE>,
{
}
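
// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the original module): how the single-bit
// identity function `out[0] = in[0]` could be encoded as a `BitCircuit`. The
// names `PassThrough` / `PASS_THROUGH` are hypothetical. It assumes the state
// convention used by `eval_level`: the initial level holds the two BDD
// terminals (state 0 = an encryption of 0, state 1 = the plaintext constant 1
// written by `encode_coeff_i64`), and the final level is a single CMUX padded
// with `Node::None`.
// ---------------------------------------------------------------------------
#[cfg(test)]
mod pass_through_example {
    use super::*;

    /// Hypothetical one-input, one-output circuit: out[0] = in[0].
    struct PassThrough(BitCircuit<2>);

    impl BitCircuitFamily for PassThrough {
        const INPUT_BITS: usize = 1;
        const OUTPUT_BITS: usize = 1;
    }

    impl BitCircuitInfo for PassThrough {
        fn info(&self) -> (&[Node], usize) {
            self.0.info()
        }
    }

    /// Only level = last level: [CMUX(input bit 0, hi = state 1, lo = state 0), NONE],
    /// i.e. the output is 1 when bit 0 is set and 0 otherwise.
    const PASS_THROUGH: PassThrough = PassThrough(BitCircuit::new(
        [Node::Cmux(0, 1, 0), Node::None],
        2,
    ));

    #[test]
    fn pass_through_shape() {
        let circuit = Circuit([PASS_THROUGH]);
        assert_eq!(circuit.input_size(), 1);
        assert_eq!(circuit.output_size(), 1);

        // Invariants relied upon by `eval_level`: the node list is a whole
        // number of levels, and the last level starts with a CMUX.
        let (nodes, state_size) = circuit.get_circuit(0);
        assert!(nodes.len().is_multiple_of(state_size));
        assert!(matches!(nodes[nodes.len() - state_size], Node::Cmux(..)));
    }
}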