Files
poulpy/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs
Pro7ech cedf7b9c59 clippy
2025-10-22 16:43:46 +02:00

186 lines
4.8 KiB
Rust

use itertools::Itertools;
use poulpy_core::layouts::prepared::GGSWPreparedToRef;
use poulpy_hal::layouts::{Backend, DataMut, DataRef, Module, Scratch};
use crate::tfhe::bdd_arithmetic::{
ExecuteBDDCircuit, FheUintBlocks, FheUintBlocksPrepared, GetBitCircuitInfo, UnsignedInteger, circuits,
};
impl<T: UnsignedInteger, BE: Backend> ExecuteBDDCircuit2WTo1W<T, BE> for Module<BE> where Self: Sized + ExecuteBDDCircuit<T, BE> {}
pub trait ExecuteBDDCircuit2WTo1W<T: UnsignedInteger, BE: Backend>
where
Self: Sized + ExecuteBDDCircuit<T, BE>,
{
/// Operations Z x Z -> Z
fn execute_bdd_circuit_2w_to_1w<R, C, A, B>(
&self,
out: &mut FheUintBlocks<R, T>,
circuit: &C,
a: &FheUintBlocksPrepared<A, T, BE>,
b: &FheUintBlocksPrepared<B, T, BE>,
scratch: &mut Scratch<BE>,
) where
C: GetBitCircuitInfo<T>,
R: DataMut,
A: DataRef,
B: DataRef,
{
assert_eq!(out.blocks.len(), T::WORD_SIZE);
assert_eq!(b.blocks.len(), T::WORD_SIZE);
assert_eq!(b.blocks.len(), T::WORD_SIZE);
// Collects inputs into a single array
let inputs: Vec<&dyn GGSWPreparedToRef<BE>> = a
.blocks
.iter()
.map(|x| x as &dyn GGSWPreparedToRef<BE>)
.chain(b.blocks.iter().map(|x| x as &dyn GGSWPreparedToRef<BE>))
.collect_vec();
// Evaluates out[i] = circuit[i](a, b)
self.execute_bdd_circuit(&mut out.blocks, &inputs, circuit, scratch);
}
}
#[macro_export]
macro_rules! define_bdd_2w_to_1w_trait {
($(#[$meta:meta])* $vis:vis $trait_name:ident, $method_name:ident) => {
$(#[$meta])*
$vis trait $trait_name<T: UnsignedInteger, BE: Backend> {
fn $method_name<A, M, B>(
&mut self,
module: &M,
a: &FheUintBlocksPrepared<A, T, BE>,
b: &FheUintBlocksPrepared<B, T, BE>,
scratch: &mut Scratch<BE>,
) where
M: ExecuteBDDCircuit2WTo1W<T, BE>,
A: DataRef,
B: DataRef;
}
};
}
#[macro_export]
macro_rules! impl_bdd_2w_to_1w_trait {
($trait_name:ident, $method_name:ident, $ty:ty, $n:literal, $circuit_ty:ty, $output_circuits:path) => {
impl<D: DataMut, BE: Backend> $trait_name<$ty, BE> for FheUintBlocks<D, $ty> {
fn $method_name<A, M, B>(
&mut self,
module: &M,
a: &FheUintBlocksPrepared<A, $ty, BE>,
b: &FheUintBlocksPrepared<B, $ty, BE>,
scratch: &mut Scratch<BE>,
) where
M: ExecuteBDDCircuit2WTo1W<$ty, BE>,
A: DataRef,
B: DataRef,
{
module.execute_bdd_circuit_2w_to_1w(self, &$output_circuits, a, b, scratch)
}
}
};
}
define_bdd_2w_to_1w_trait!(pub Add, add);
define_bdd_2w_to_1w_trait!(pub Sub, sub);
define_bdd_2w_to_1w_trait!(pub Sll, sll);
define_bdd_2w_to_1w_trait!(pub Sra, sra);
define_bdd_2w_to_1w_trait!(pub Srl, srl);
define_bdd_2w_to_1w_trait!(pub Slt, slt);
define_bdd_2w_to_1w_trait!(pub Sltu, sltu);
define_bdd_2w_to_1w_trait!(pub Or, or);
define_bdd_2w_to_1w_trait!(pub And, and);
define_bdd_2w_to_1w_trait!(pub Xor, xor);
impl_bdd_2w_to_1w_trait!(
Add,
add,
u32,
32,
circuits::u32::add_codegen::AnyBitCircuit,
circuits::u32::add_codegen::OUTPUT_CIRCUITS
);
impl_bdd_2w_to_1w_trait!(
Sub,
sub,
u32,
32,
circuits::u32::sub_codegen::AnyBitCircuit,
circuits::u32::sub_codegen::OUTPUT_CIRCUITS
);
impl_bdd_2w_to_1w_trait!(
Sll,
sll,
u32,
32,
circuits::u32::sll_codegen::AnyBitCircuit,
circuits::u32::sll_codegen::OUTPUT_CIRCUITS
);
impl_bdd_2w_to_1w_trait!(
Sra,
sra,
u32,
32,
circuits::u32::sra_codegen::AnyBitCircuit,
circuits::u32::sra_codegen::OUTPUT_CIRCUITS
);
impl_bdd_2w_to_1w_trait!(
Srl,
srl,
u32,
32,
circuits::u32::srl_codegen::AnyBitCircuit,
circuits::u32::srl_codegen::OUTPUT_CIRCUITS
);
impl_bdd_2w_to_1w_trait!(
Slt,
slt,
u32,
1,
circuits::u32::slt_codegen::AnyBitCircuit,
circuits::u32::slt_codegen::OUTPUT_CIRCUITS
);
impl_bdd_2w_to_1w_trait!(
Sltu,
sltu,
u32,
1,
circuits::u32::sltu_codegen::AnyBitCircuit,
circuits::u32::sltu_codegen::OUTPUT_CIRCUITS
);
impl_bdd_2w_to_1w_trait!(
And,
and,
u32,
32,
circuits::u32::and_codegen::AnyBitCircuit,
circuits::u32::and_codegen::OUTPUT_CIRCUITS
);
impl_bdd_2w_to_1w_trait!(
Or,
or,
u32,
32,
circuits::u32::or_codegen::AnyBitCircuit,
circuits::u32::or_codegen::OUTPUT_CIRCUITS
);
impl_bdd_2w_to_1w_trait!(
Xor,
xor,
u32,
32,
circuits::u32::xor_codegen::AnyBitCircuit,
circuits::u32::xor_codegen::OUTPUT_CIRCUITS
);