Update FheUint ciphertext naming + circuit evaluation based on GetGGSWBit

This commit is contained in:
Pro7ech
2025-10-28 15:43:30 +01:00
parent a2aecfd380
commit 8c1cc354e3
23 changed files with 265 additions and 282 deletions

View File

@@ -1,9 +1,10 @@
use itertools::Itertools;
use poulpy_core::layouts::prepared::GGSWPreparedToRef;
use std::marker::PhantomData;
use poulpy_core::layouts::GGSWPrepared;
use poulpy_hal::layouts::{Backend, DataMut, DataRef, Module, Scratch};
use crate::tfhe::bdd_arithmetic::{
ExecuteBDDCircuit, FheUintBlocks, FheUintBlocksPrepared, GetBitCircuitInfo, UnsignedInteger, circuits,
BitSize, ExecuteBDDCircuit, FheUint, FheUintPrepared, GetBitCircuitInfo, GetGGSWBit, UnsignedInteger, circuits,
};
impl<T: UnsignedInteger, BE: Backend> ExecuteBDDCircuit2WTo1W<T, BE> for Module<BE> where Self: Sized + ExecuteBDDCircuit<T, BE> {}
@@ -15,10 +16,10 @@ where
/// Operations Z x Z -> Z
fn execute_bdd_circuit_2w_to_1w<R, C, A, B>(
&self,
out: &mut FheUintBlocks<R, T>,
out: &mut FheUint<R, T>,
circuit: &C,
a: &FheUintBlocksPrepared<A, T, BE>,
b: &FheUintBlocksPrepared<B, T, BE>,
a: &FheUintPrepared<A, T, BE>,
b: &FheUintPrepared<B, T, BE>,
scratch: &mut Scratch<BE>,
) where
C: GetBitCircuitInfo<T>,
@@ -26,20 +27,59 @@ where
A: DataRef,
B: DataRef,
{
assert_eq!(out.blocks.len(), T::WORD_SIZE);
assert_eq!(b.blocks.len(), T::WORD_SIZE);
assert_eq!(b.blocks.len(), T::WORD_SIZE);
assert_eq!(out.bits.len(), T::WORD_SIZE);
assert_eq!(b.bits.len(), T::WORD_SIZE);
assert_eq!(b.bits.len(), T::WORD_SIZE);
// Collects inputs into a single array
let inputs: Vec<&dyn GGSWPreparedToRef<BE>> = a
.blocks
.iter()
.map(|x| x as &dyn GGSWPreparedToRef<BE>)
.chain(b.blocks.iter().map(|x| x as &dyn GGSWPreparedToRef<BE>))
.collect_vec();
let inputs: Vec<&dyn GetGGSWBit<BE>> = [a as &dyn GetGGSWBit<BE>, b as &dyn GetGGSWBit<BE>].to_vec();
let helper: FheUintHelper<'_, T, BE> = FheUintHelper {
data: inputs,
_phantom: PhantomData,
};
// Evaluates out[i] = circuit[i](a, b)
self.execute_bdd_circuit(&mut out.blocks, &inputs, circuit, scratch);
self.execute_bdd_circuit(&mut out.bits, &helper, circuit, scratch);
}
}
struct FheUintHelper<'a, T: UnsignedInteger, BE: Backend> {
data: Vec<&'a dyn GetGGSWBit<BE>>,
_phantom: PhantomData<T>,
}
impl<'a, T: UnsignedInteger, BE: Backend> GetGGSWBit<BE> for FheUintHelper<'a, T, BE> {
fn get_bit(&self, bit: usize) -> GGSWPrepared<&[u8], BE> {
let lo: usize = bit % T::WORD_SIZE;
let hi: usize = bit / T::WORD_SIZE;
self.data[hi].get_bit(lo)
}
}
impl<'a, T: UnsignedInteger, BE: Backend> BitSize for FheUintHelper<'a, T, BE> {
fn bit_size(&self) -> usize {
T::WORD_SIZE * self.data.len()
}
}
pub struct JoinedBits<A, B> {
pub lo: A,
pub hi: B,
pub split: usize, // 32 in your example
}
impl<A, B, BE> GetGGSWBit<BE> for JoinedBits<A, B>
where
BE: Backend,
A: GetGGSWBit<BE>,
B: GetGGSWBit<BE>,
{
fn get_bit(&self, bit: usize) -> GGSWPrepared<&[u8], BE> {
if bit < self.split {
self.lo.get_bit(bit)
} else {
self.hi.get_bit(bit - self.split)
}
}
}
@@ -51,8 +91,8 @@ macro_rules! define_bdd_2w_to_1w_trait {
fn $method_name<A, M, B>(
&mut self,
module: &M,
a: &FheUintBlocksPrepared<A, T, BE>,
b: &FheUintBlocksPrepared<B, T, BE>,
a: &FheUintPrepared<A, T, BE>,
b: &FheUintPrepared<B, T, BE>,
scratch: &mut Scratch<BE>,
) where
M: ExecuteBDDCircuit2WTo1W<T, BE>,
@@ -65,12 +105,12 @@ macro_rules! define_bdd_2w_to_1w_trait {
#[macro_export]
macro_rules! impl_bdd_2w_to_1w_trait {
($trait_name:ident, $method_name:ident, $ty:ty, $n:literal, $circuit_ty:ty, $output_circuits:path) => {
impl<D: DataMut, BE: Backend> $trait_name<$ty, BE> for FheUintBlocks<D, $ty> {
impl<D: DataMut, BE: Backend> $trait_name<$ty, BE> for FheUint<D, $ty> {
fn $method_name<A, M, B>(
&mut self,
module: &M,
a: &FheUintBlocksPrepared<A, $ty, BE>,
b: &FheUintBlocksPrepared<B, $ty, BE>,
a: &FheUintPrepared<A, $ty, BE>,
b: &FheUintPrepared<B, $ty, BE>,
scratch: &mut Scratch<BE>,
) where
M: ExecuteBDDCircuit2WTo1W<$ty, BE>,