use std::marker::PhantomData;

use poulpy_core::layouts::{
    Base2K, Dnum, Dsize, GGSWInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision, prepared::GGSWPrepared,
};
// Items below are only needed by `#[cfg(test)]` code in this file.
#[cfg(test)]
use poulpy_core::{
    TakeGGSW,
    layouts::{GGSW, prepared::GLWESecretPrepared},
};
use poulpy_hal::{
    api::VmpPMatAlloc,
    layouts::{Backend, Data, DataMut, DataRef, Module, Scratch},
};
// Backend capability traits required only by the `#[cfg(test)]` encryption and
// noise helpers in this file.
#[cfg(test)]
use poulpy_hal::{
    api::{
        ScratchAvailable, SvpApplyDftToDftInplace, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace,
        VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc,
        VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAlloc, VecZnxDftApply,
        VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize,
        VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VmpPrepare,
    },
    oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl},
    source::Source,
};

use crate::tfhe::bdd_arithmetic::{FheUintBlocks, FheUintPrepare, ToBits, UnsignedInteger};

/// Test-only counterpart of `FheUintBlocksPrep` that stores plain `GGSW`
/// ciphertexts (built with `GGSW::alloc`) instead of `GGSWPrepared` ones.
///
/// The plain ciphertexts can be inspected directly, e.g. by this type's `noise` helper.
#[cfg(test)]
pub(crate) struct FheUintBlocksPrepDebug {
    /// One GGSW ciphertext per bit of `T` (`T::WORD_SIZE` entries when built via `alloc_with`).
    pub(crate) blocks: Vec>,
    /// Always initialised to `1` by `alloc_with`.
    /// NOTE(review): its meaning is not visible in this file.
    pub(crate) _base: u8,
    pub(crate) _phantom: PhantomData,
}

// Allocation helpers for the debug representation.
#[cfg(test)]
impl FheUintBlocksPrepDebug, T> {
    /// Allocates a debug block set.
    ///
    /// The layout parameters (base2k, k, dnum, dsize, rank) are copied from `infos`.
    #[allow(dead_code)]
    pub(crate) fn alloc(module: &Module, infos: &A) -> Self
    where
        A: GGSWInfos,
    {
        Self::alloc_with(
            module,
            infos.base2k(),
            infos.k(),
            infos.dnum(),
            infos.dsize(),
            infos.rank(),
        )
    }

    /// Allocates `T::WORD_SIZE` GGSW ciphertexts, one per bit.
    ///
    /// Each has ring degree `module.n()` and the given layout parameters.
    ///
    /// NOTE(review): `GGSW::alloc` is called here as `(n, base2k, k, rank, dnum, dsize)`,
    /// while `FheUintBlocksPrep::alloc_with` calls `GGSWPrepared::alloc` as
    /// `(module, base2k, k, dnum, dsize, rank)`. Confirm each order against its callee's signature.
    #[allow(dead_code)]
    pub(crate) fn alloc_with(
        module: &Module,
        base2k: Base2K,
        k: TorusPrecision,
        dnum: Dnum,
        dsize: Dsize,
        rank: Rank,
    ) -> Self {
        Self {
            blocks: (0..T::WORD_SIZE)
                .map(|_| GGSW::alloc(module.n().into(), base2k, k, rank, dnum, dsize))
                .collect(),
            _base: 1,
            _phantom: PhantomData,
        }
    }
}

/// A prepared FHE ciphertext encrypting the bits of an [UnsignedInteger].
///
/// Each bit is held as its own prepared GGSW ciphertext in `blocks`.
/// There are `T::WORD_SIZE` of them when built via `alloc_with`, and block `i`
/// corresponds to bit `i`.
pub struct FheUintBlocksPrep {
    /// One prepared GGSW ciphertext per bit; `blocks[i]` holds bit `i`.
    pub(crate) blocks: Vec>,
    /// Always initialised to `1` by `alloc_with`.
    /// NOTE(review): its meaning is not visible in this file.
    pub(crate) _base: u8,
    pub(crate) _phantom: PhantomData,
}

// Allocation helpers; they require a module that can allocate prepared
// matrices (`VmpPMatAlloc`).
impl FheUintBlocksPrep, BE, T>
where
    Module: VmpPMatAlloc,
{
    /// Allocates a prepared block set.
    ///
    /// The layout parameters (base2k, k, dnum, dsize, rank) are copied from `infos`.
    #[allow(dead_code)]
    pub(crate) fn alloc(module: &Module, infos: &A) -> Self
    where
        A: GGSWInfos,
    {
        Self::alloc_with(
            module,
            infos.base2k(),
            infos.k(),
            infos.dnum(),
            infos.dsize(),
            infos.rank(),
        )
    }

    /// Allocates `T::WORD_SIZE` prepared GGSW ciphertexts, one per bit, with the given layout parameters.
    #[allow(dead_code)]
    pub(crate) fn alloc_with(module: &Module, base2k: Base2K, k: TorusPrecision, dnum: Dnum, dsize: Dsize, rank: Rank) -> Self
    where
        Module: VmpPMatAlloc,
    {
        Self {
            blocks: (0..T::WORD_SIZE)
                .map(|_| GGSWPrepared::alloc(module, base2k, k, dnum, dsize, rank))
                .collect(),
            _base: 1,
            _phantom: PhantomData,
        }
    }
}

impl FheUintBlocksPrep {
    /// Test helper: encrypts every bit of `value` under `sk` into `self`.
    ///
    /// For each bit index `i` in `0..T::WORD_SIZE`:
    /// - a temporary GGSW taken from `scratch` is encrypted with a plaintext
    ///   whose coefficient 0 is `value.bit(i)`;
    /// - that ciphertext is prepared into `self.blocks[i]`.
    ///
    /// # Panics
    /// With debug assertions enabled, panics if:
    /// - `module.n()` is not a multiple of `T::WORD_SIZE`, or
    /// - the ring degree of `self` or `sk` differs from `module.n()`.
    ///
    /// Also panics on indexing if `self.blocks` has fewer than `T::WORD_SIZE` entries.
    #[allow(dead_code)]
    #[cfg(test)]
    pub(crate) fn encrypt_sk(
        &mut self,
        module: &Module,
        value: T,
        sk: &GLWESecretPrepared,
        source_xa: &mut Source,
        source_xe: &mut Source,
        scratch: &mut Scratch,
    ) where
        S: DataRef,
        Module: VecZnxAddScalarInplace
            + VecZnxDftBytesOf
            + VecZnxBigNormalize
            + VecZnxDftApply
            + SvpApplyDftToDftInplace
            + VecZnxIdftApplyConsume
            + VecZnxNormalizeTmpBytes
            + VecZnxFillUniform
            + VecZnxSubInplace
            + VecZnxAddInplace
            + VecZnxNormalizeInplace
            + VecZnxAddNormal
            + VecZnxNormalize
            + VecZnxSub
            + VmpPrepare,
        Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeGGSW + TakeScalarZnx,
    {
        #[cfg(debug_assertions)]
        {
            assert!(module.n().is_multiple_of(T::WORD_SIZE));
            assert_eq!(self.n(), module.n() as u32);
            assert_eq!(sk.n(), module.n() as u32);
        }

        // Both buffers are carved out of `scratch`:
        // - a temporary non-prepared GGSW, presumably sized from `self`'s GGSW
        //   parameters (it is passed `self` as infos);
        // - a single-column scalar plaintext.
        let (mut tmp_ggsw, scratch_1) = scratch.take_ggsw(self);
        let (mut pt, scratch_2) = scratch_1.take_scalar_znx(module.n(), 1);

        // NOTE(review): only coefficient 0 of `pt` is ever written. If
        // `take_scalar_znx` does not zero its memory, the remaining
        // coefficients hold leftover scratch contents. Confirm this, or zero
        // `pt` before the loop.
        for i in 0..T::WORD_SIZE {
            use poulpy_core::layouts::prepared::Prepare;
            use poulpy_hal::layouts::ZnxViewMut;
            pt.at_mut(0, 0)[0] = value.bit(i) as i64;
            // `module` is already a reference, so `&module` is a `&&Module`
            // that relies on deref coercion.
            tmp_ggsw.encrypt_sk(&module, &pt, sk, source_xa, source_xe, scratch_2);
            self.blocks[i].prepare(module, &tmp_ggsw, scratch_2);
        }
    }

    /// Prepares [FheUintBlocks] into this [FheUintBlocksPrep].
    ///
    /// The work is delegated to `key`, an [FheUintPrepare] implementation.
    pub fn prepare(&mut self, module: &Module, bits: &FheUintBlocks, key: &KEY, scratch: &mut Scratch)
    where
        BIT: DataRef,
        KEY: FheUintPrepare, FheUintBlocks>,
    {
        key.prepare(module, self, bits, scratch);
    }
}

// Test-only: the same delegation as `FheUintBlocksPrep::prepare`, but targeting
// the debug representation (plain GGSW ciphertexts).
#[cfg(test)]
impl FheUintBlocksPrepDebug {
    /// Fills `self` from `bits` by delegating to `key`, an [FheUintPrepare] implementation.
    pub(crate) fn prepare(
        &mut self,
        module: &Module,
        bits: &FheUintBlocks,
        key: &KEY,
        scratch: &mut Scratch,
    ) where
        BIT: DataRef,
        KEY: FheUintPrepare, FheUintBlocks>,
    {
        key.prepare(module, self, bits, scratch);
    }
}

#[cfg(test)]
impl FheUintBlocksPrepDebug {
    /// Calls `print_noise` on every block, comparing it with the matching bit of `want`.
    ///
    /// For block `i`, a freshly allocated plaintext is built whose coefficient 0
    /// is `want.bit(i)`. That plaintext and `sk` are then passed to
    /// `print_noise` for the block.
    #[allow(dead_code)]
    pub(crate) fn noise(&self, module: &Module, sk: &GLWESecretPrepared, want: T)
    where
        Module: VecZnxDftBytesOf
            + VecZnxBigBytesOf
            + VecZnxDftApply
            + SvpApplyDftToDftInplace
            + VecZnxIdftApplyConsume
            + VecZnxBigAddInplace
            + VecZnxBigAddSmallInplace
            + VecZnxBigNormalize
            + VecZnxNormalizeTmpBytes
            + VecZnxBigAlloc
            + VecZnxDftAlloc
            + VecZnxBigNormalizeTmpBytes
            + VecZnxIdftApplyTmpA
            + VecZnxAddScalarInplace
            + VecZnxSubInplace,
        BE: Backend + TakeVecZnxDftImpl + TakeVecZnxBigImpl + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl,
    {
        for (i, ggsw) in self.blocks.iter().enumerate() {
            use poulpy_hal::layouts::{ScalarZnx, ZnxViewMut};
            let mut pt_want = ScalarZnx::alloc(self.n().into(), 1);
            pt_want.at_mut(0, 0)[0] = want.bit(i) as i64;
            ggsw.print_noise(module, sk, &pt_want);
        }
    }
}

// The layout metadata of the whole block set is read from `blocks[0]`, so these
// accessors panic if `blocks` is empty. They presume every block shares the same
// parameters, which holds for sets built by `alloc_with`.
impl LWEInfos for FheUintBlocksPrep {
    fn base2k(&self) -> poulpy_core::layouts::Base2K {
        self.blocks[0].base2k()
    }

    fn k(&self) -> poulpy_core::layouts::TorusPrecision {
        self.blocks[0].k()
    }

    fn n(&self) -> poulpy_core::layouts::RingDegree {
        self.blocks[0].n()
    }
}

impl GLWEInfos for FheUintBlocksPrep {
    fn rank(&self) -> poulpy_core::layouts::Rank {
        self.blocks[0].rank()
    }
}

impl GGSWInfos for FheUintBlocksPrep {
    fn dsize(&self) -> poulpy_core::layouts::Dsize {
        self.blocks[0].dsize()
    }

    fn dnum(&self) -> poulpy_core::layouts::Dnum {
        self.blocks[0].dnum()
    }
}

// Same `blocks[0]`-based metadata accessors for the test-only debug representation.
#[cfg(test)]
impl LWEInfos for FheUintBlocksPrepDebug {
    fn base2k(&self) -> poulpy_core::layouts::Base2K {
        self.blocks[0].base2k()
    }

    fn k(&self) -> poulpy_core::layouts::TorusPrecision {
        self.blocks[0].k()
    }

    fn n(&self) -> poulpy_core::layouts::RingDegree {
        self.blocks[0].n()
    }
}

#[cfg(test)]
impl GLWEInfos for FheUintBlocksPrepDebug {
    fn rank(&self) -> poulpy_core::layouts::Rank {
        self.blocks[0].rank()
    }
}

#[cfg(test)]
impl GGSWInfos for FheUintBlocksPrepDebug {
    fn dsize(&self) -> poulpy_core::layouts::Dsize {
        self.blocks[0].dsize()
    }

    fn dnum(&self) -> poulpy_core::layouts::Dnum {
        self.blocks[0].dnum()
    }
}