use std::marker::PhantomData; use poulpy_core::layouts::{ Base2K, Dnum, Dsize, GGSWInfos, GGSWPreparedFactory, GLWEInfos, LWEInfos, Rank, TorusPrecision, prepared::GGSWPrepared, }; use poulpy_core::layouts::{GGSWPreparedToMut, GGSWPreparedToRef}; use poulpy_core::{GGSWEncryptSk, ScratchTakeCore, layouts::GLWESecretPreparedToRef}; use poulpy_hal::layouts::{Backend, Data, DataRef, Module}; use poulpy_hal::{ api::ModuleN, layouts::{DataMut, Scratch}, source::Source, }; use crate::tfhe::bdd_arithmetic::ToBits; use crate::tfhe::bdd_arithmetic::UnsignedInteger; /// A prepared FHE ciphertext encrypting the bits of an [UnsignedInteger]. pub struct FheUintPrepared { pub(crate) bits: Vec>, pub(crate) _phantom: PhantomData, } impl FheUintBlocksPreparedFactory for Module where Self: Sized + GGSWPreparedFactory { } pub trait GetGGSWBit { fn get_bit(&self, bit: usize) -> GGSWPrepared<&[u8], BE>; } impl GetGGSWBit for FheUintPrepared { fn get_bit(&self, bit: usize) -> GGSWPrepared<&[u8], BE> { assert!(bit <= self.bits.len()); self.bits[bit].to_ref() } } pub trait GetGGSWBitMut { fn get_bit(&mut self, bit: usize) -> GGSWPrepared<&mut [u8], BE>; } impl GetGGSWBitMut for FheUintPrepared { fn get_bit(&mut self, bit: usize) -> GGSWPrepared<&mut [u8], BE> { assert!(bit <= self.bits.len()); self.bits[bit].to_mut() } } pub trait FheUintBlocksPreparedFactory where Self: Sized + GGSWPreparedFactory, { fn alloc_fhe_uint_prepared( &self, base2k: Base2K, k: TorusPrecision, dnum: Dnum, dsize: Dsize, rank: Rank, ) -> FheUintPrepared, T, BE> { FheUintPrepared { bits: (0..T::WORD_SIZE) .map(|_| GGSWPrepared::alloc(self, base2k, k, dnum, dsize, rank)) .collect(), _phantom: PhantomData, } } fn alloc_fhe_uint_prepared_from_infos(&self, infos: &A) -> FheUintPrepared, T, BE> where A: GGSWInfos, { self.alloc_fhe_uint_prepared( infos.base2k(), infos.k(), infos.dnum(), infos.dsize(), infos.rank(), ) } } impl FheUintPrepared, T, BE> { pub fn alloc(module: &M, infos: &A) -> Self where A: GGSWInfos, M: FheUintBlocksPreparedFactory, { module.alloc_fhe_uint_prepared_from_infos(infos) } pub fn alloc_with(module: &M, base2k: Base2K, k: TorusPrecision, dnum: Dnum, dsize: Dsize, rank: Rank) -> Self where M: FheUintBlocksPreparedFactory, { module.alloc_fhe_uint_prepared(base2k, k, dnum, dsize, rank) } } impl FheUintBlocksPreparedEncryptSk for Module where Self: Sized + ModuleN + GGSWEncryptSk + GGSWPreparedFactory { } pub trait FheUintBlocksPreparedEncryptSk where Self: Sized + ModuleN + GGSWEncryptSk + GGSWPreparedFactory, { fn fhe_uint_prepared_encrypt_sk( &self, res: &mut FheUintPrepared, value: T, sk: &S, source_xa: &mut Source, source_xe: &mut Source, scratch: &mut Scratch, ) where DM: DataMut, S: GLWESecretPreparedToRef + GLWEInfos, Scratch: ScratchTakeCore, { use poulpy_hal::{api::ScratchTakeBasic, layouts::ZnxZero}; assert!(self.n().is_multiple_of(T::WORD_SIZE)); assert_eq!(res.n(), self.n() as u32); assert_eq!(sk.n(), self.n() as u32); let (mut tmp_ggsw, scratch_1) = scratch.take_ggsw(res); let (mut pt, scratch_2) = scratch_1.take_scalar_znx(self.n(), 1); pt.zero(); for i in 0..T::WORD_SIZE { use poulpy_hal::layouts::ZnxViewMut; pt.at_mut(0, 0)[0] = value.bit(i) as i64; tmp_ggsw.encrypt_sk(self, &pt, sk, source_xa, source_xe, scratch_2); res.bits[i].prepare(self, &tmp_ggsw, scratch_2); } } } impl FheUintPrepared { pub fn encrypt_sk( &mut self, module: &M, value: T, sk: &S, source_xa: &mut Source, source_xe: &mut Source, scratch: &mut Scratch, ) where S: GLWESecretPreparedToRef + GLWEInfos, M: FheUintBlocksPreparedEncryptSk, Scratch: ScratchTakeCore, { module.fhe_uint_prepared_encrypt_sk(self, value, sk, source_xa, source_xe, scratch); } } impl LWEInfos for FheUintPrepared { fn base2k(&self) -> poulpy_core::layouts::Base2K { self.bits[0].base2k() } fn k(&self) -> poulpy_core::layouts::TorusPrecision { self.bits[0].k() } fn n(&self) -> poulpy_core::layouts::Degree { self.bits[0].n() } } impl GLWEInfos for FheUintPrepared { fn rank(&self) -> poulpy_core::layouts::Rank { self.bits[0].rank() } } impl GGSWInfos for FheUintPrepared { fn dsize(&self) -> poulpy_core::layouts::Dsize { self.bits[0].dsize() } fn dnum(&self) -> poulpy_core::layouts::Dnum { self.bits[0].dnum() } }