use itertools::Itertools; use poulpy_core::layouts::prepared::GGSWPreparedToRef; use poulpy_hal::layouts::{Backend, DataMut, DataRef, Module, Scratch}; use crate::tfhe::bdd_arithmetic::{ ExecuteBDDCircuit, FheUintBlocks, FheUintBlocksPrepared, GetBitCircuitInfo, UnsignedInteger, circuits, }; impl ExecuteBDDCircuit2WTo1W for Module where Self: Sized + ExecuteBDDCircuit {} pub trait ExecuteBDDCircuit2WTo1W where Self: Sized + ExecuteBDDCircuit, { /// Operations Z x Z -> Z fn execute_bdd_circuit_2w_to_1w( &self, out: &mut FheUintBlocks, circuit: &C, a: &FheUintBlocksPrepared, b: &FheUintBlocksPrepared, scratch: &mut Scratch, ) where C: GetBitCircuitInfo, R: DataMut, A: DataRef, B: DataRef, { assert_eq!(out.blocks.len(), T::WORD_SIZE); assert_eq!(b.blocks.len(), T::WORD_SIZE); assert_eq!(b.blocks.len(), T::WORD_SIZE); // Collects inputs into a single array let inputs: Vec<&dyn GGSWPreparedToRef> = a .blocks .iter() .map(|x| x as &dyn GGSWPreparedToRef) .chain(b.blocks.iter().map(|x| x as &dyn GGSWPreparedToRef)) .collect_vec(); // Evaluates out[i] = circuit[i](a, b) self.execute_bdd_circuit(&mut out.blocks, &inputs, circuit, scratch); } } #[macro_export] macro_rules! define_bdd_2w_to_1w_trait { ($(#[$meta:meta])* $vis:vis $trait_name:ident, $method_name:ident) => { $(#[$meta])* $vis trait $trait_name { fn $method_name( &mut self, module: &M, a: &FheUintBlocksPrepared, b: &FheUintBlocksPrepared, scratch: &mut Scratch, ) where M: ExecuteBDDCircuit2WTo1W, A: DataRef, B: DataRef; } }; } #[macro_export] macro_rules! impl_bdd_2w_to_1w_trait { ($trait_name:ident, $method_name:ident, $ty:ty, $n:literal, $circuit_ty:ty, $output_circuits:path) => { impl $trait_name<$ty, BE> for FheUintBlocks { fn $method_name( &mut self, module: &M, a: &FheUintBlocksPrepared, b: &FheUintBlocksPrepared, scratch: &mut Scratch, ) where M: ExecuteBDDCircuit2WTo1W<$ty, BE>, A: DataRef, B: DataRef, { module.execute_bdd_circuit_2w_to_1w(self, &$output_circuits, a, b, scratch) } } }; } define_bdd_2w_to_1w_trait!(pub Add, add); define_bdd_2w_to_1w_trait!(pub Sub, sub); define_bdd_2w_to_1w_trait!(pub Sll, sll); define_bdd_2w_to_1w_trait!(pub Sra, sra); define_bdd_2w_to_1w_trait!(pub Srl, srl); define_bdd_2w_to_1w_trait!(pub Slt, slt); define_bdd_2w_to_1w_trait!(pub Sltu, sltu); define_bdd_2w_to_1w_trait!(pub Or, or); define_bdd_2w_to_1w_trait!(pub And, and); define_bdd_2w_to_1w_trait!(pub Xor, xor); impl_bdd_2w_to_1w_trait!( Add, add, u32, 32, circuits::u32::add_codegen::AnyBitCircuit, circuits::u32::add_codegen::OUTPUT_CIRCUITS ); impl_bdd_2w_to_1w_trait!( Sub, sub, u32, 32, circuits::u32::sub_codegen::AnyBitCircuit, circuits::u32::sub_codegen::OUTPUT_CIRCUITS ); impl_bdd_2w_to_1w_trait!( Sll, sll, u32, 32, circuits::u32::sll_codegen::AnyBitCircuit, circuits::u32::sll_codegen::OUTPUT_CIRCUITS ); impl_bdd_2w_to_1w_trait!( Sra, sra, u32, 32, circuits::u32::sra_codegen::AnyBitCircuit, circuits::u32::sra_codegen::OUTPUT_CIRCUITS ); impl_bdd_2w_to_1w_trait!( Srl, srl, u32, 32, circuits::u32::srl_codegen::AnyBitCircuit, circuits::u32::srl_codegen::OUTPUT_CIRCUITS ); impl_bdd_2w_to_1w_trait!( Slt, slt, u32, 1, circuits::u32::slt_codegen::AnyBitCircuit, circuits::u32::slt_codegen::OUTPUT_CIRCUITS ); impl_bdd_2w_to_1w_trait!( Sltu, sltu, u32, 1, circuits::u32::sltu_codegen::AnyBitCircuit, circuits::u32::sltu_codegen::OUTPUT_CIRCUITS ); impl_bdd_2w_to_1w_trait!( And, and, u32, 32, circuits::u32::and_codegen::AnyBitCircuit, circuits::u32::and_codegen::OUTPUT_CIRCUITS ); impl_bdd_2w_to_1w_trait!( Or, or, u32, 32, circuits::u32::or_codegen::AnyBitCircuit, circuits::u32::or_codegen::OUTPUT_CIRCUITS ); impl_bdd_2w_to_1w_trait!( Xor, xor, u32, 32, circuits::u32::xor_codegen::AnyBitCircuit, circuits::u32::xor_codegen::OUTPUT_CIRCUITS );