From 1d23dfc078beefd9f6a174000dd089e636065473 Mon Sep 17 00:00:00 2001 From: Pro7ech Date: Fri, 7 Nov 2025 08:49:32 +0100 Subject: [PATCH] Update to custom fheuint prepare --- poulpy-backend/src/cpu_fft64_avx/vmp.rs | 12 ++++- poulpy-backend/src/cpu_fft64_ref/vmp.rs | 12 ++++- .../src/cpu_spqlios/fft64/vmp_pmat.rs | 12 ++++- poulpy-core/src/layouts/prepared/ggsw.rs | 13 +++-- poulpy-core/src/scratch.rs | 6 +-- poulpy-hal/src/api/vmp_pmat.rs | 6 +++ poulpy-hal/src/delegates/vmp_pmat.rs | 16 +++++- poulpy-hal/src/oep/vmp_pmat.rs | 10 ++++ poulpy-hal/src/reference/fft64/vmp.rs | 9 +++- .../ciphertexts/fhe_uint_prepared.rs | 50 +++++++++++-------- 10 files changed, 112 insertions(+), 34 deletions(-) diff --git a/poulpy-backend/src/cpu_fft64_avx/vmp.rs b/poulpy-backend/src/cpu_fft64_avx/vmp.rs index fcb6236..b98ef1d 100644 --- a/poulpy-backend/src/cpu_fft64_avx/vmp.rs +++ b/poulpy-backend/src/cpu_fft64_avx/vmp.rs @@ -6,10 +6,11 @@ use poulpy_hal::{ }, oep::{ VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl, VmpApplyDftToDftTmpBytesImpl, - VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPrepareImpl, VmpPrepareTmpBytesImpl, + VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPrepareImpl, VmpPrepareTmpBytesImpl, VmpZeroImpl, }, reference::fft64::vmp::{ vmp_apply_dft_to_dft, vmp_apply_dft_to_dft_add, vmp_apply_dft_to_dft_tmp_bytes, vmp_prepare, vmp_prepare_tmp_bytes, + vmp_zero, }, }; @@ -141,3 +142,12 @@ unsafe impl VmpApplyDftToDftAddTmpBytesImpl for FFT64Avx { vmp_apply_dft_to_dft_tmp_bytes(a_size, b_rows, b_cols_in) } } + +unsafe impl VmpZeroImpl for FFT64Avx { + fn vmp_zero_impl(_module: &Module, res: &mut R) + where + R: VmpPMatToMut, + { + vmp_zero(res); + } +} diff --git a/poulpy-backend/src/cpu_fft64_ref/vmp.rs b/poulpy-backend/src/cpu_fft64_ref/vmp.rs index 34cbf07..3cee4c7 100644 --- a/poulpy-backend/src/cpu_fft64_ref/vmp.rs +++ b/poulpy-backend/src/cpu_fft64_ref/vmp.rs @@ -6,10 +6,11 @@ use poulpy_hal::{ }, oep::{ VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl, VmpApplyDftToDftTmpBytesImpl, - VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPrepareImpl, VmpPrepareTmpBytesImpl, + VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPrepareImpl, VmpPrepareTmpBytesImpl, VmpZeroImpl, }, reference::fft64::vmp::{ vmp_apply_dft_to_dft, vmp_apply_dft_to_dft_add, vmp_apply_dft_to_dft_tmp_bytes, vmp_prepare, vmp_prepare_tmp_bytes, + vmp_zero, }, }; @@ -141,3 +142,12 @@ unsafe impl VmpApplyDftToDftAddTmpBytesImpl for FFT64Ref { vmp_apply_dft_to_dft_tmp_bytes(a_size, b_rows, b_cols_in) } } + +unsafe impl VmpZeroImpl for FFT64Ref { + fn vmp_zero_impl(_module: &Module, res: &mut R) + where + R: VmpPMatToMut, + { + vmp_zero(res); + } +} diff --git a/poulpy-backend/src/cpu_spqlios/fft64/vmp_pmat.rs b/poulpy-backend/src/cpu_spqlios/fft64/vmp_pmat.rs index ff1eaa2..bd60680 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/vmp_pmat.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/vmp_pmat.rs @@ -6,8 +6,9 @@ use poulpy_hal::{ }, oep::{ VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl, VmpApplyDftToDftTmpBytesImpl, - VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPrepareImpl, VmpPrepareTmpBytesImpl, + VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPrepareImpl, VmpPrepareTmpBytesImpl, VmpZeroImpl, }, + reference::fft64::vmp::vmp_zero, }; use crate::cpu_spqlios::{ @@ -269,3 +270,12 @@ unsafe impl VmpApplyDftToDftAddImpl for FFT64Spqlios { } } } + +unsafe impl VmpZeroImpl for FFT64Spqlios { + fn vmp_zero_impl(_module: &Module, res: &mut R) + where + R: VmpPMatToMut, + { + vmp_zero(res); + } +} diff --git a/poulpy-core/src/layouts/prepared/ggsw.rs b/poulpy-core/src/layouts/prepared/ggsw.rs index 3115980..39bb726 100644 --- a/poulpy-core/src/layouts/prepared/ggsw.rs +++ b/poulpy-core/src/layouts/prepared/ggsw.rs @@ -1,5 +1,5 @@ use poulpy_hal::{ - api::{VmpPMatAlloc, VmpPMatBytesOf, VmpPrepare, VmpPrepareTmpBytes}, + api::{VmpPMatAlloc, VmpPMatBytesOf, VmpPrepare, VmpPrepareTmpBytes, VmpZero}, layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat, VmpPMatToMut, VmpPMatToRef, ZnxInfos}, }; @@ -51,7 +51,7 @@ impl GGSWInfos for GGSWPrepared { pub trait GGSWPreparedFactory where - Self: GetDegree + VmpPMatAlloc + VmpPMatBytesOf + VmpPrepareTmpBytes + VmpPrepare, + Self: GetDegree + VmpPMatAlloc + VmpPMatBytesOf + VmpPrepareTmpBytes + VmpPrepare + VmpZero, { fn alloc_ggsw_prepared( &self, @@ -163,7 +163,7 @@ where } impl GGSWPreparedFactory for Module where - Self: GetDegree + VmpPMatAlloc + VmpPMatBytesOf + VmpPrepareTmpBytes + VmpPrepare + Self: GetDegree + VmpPMatAlloc + VmpPMatBytesOf + VmpPrepareTmpBytes + VmpPrepare + VmpZero { } @@ -223,6 +223,13 @@ impl GGSWPrepared { { module.ggsw_prepare(self, other, scratch); } + + pub fn zero(&mut self, module: &M) + where + M: GGSWPreparedFactory, + { + module.vmp_zero(&mut self.data); + } } pub trait GGSWPreparedToMut { diff --git a/poulpy-core/src/scratch.rs b/poulpy-core/src/scratch.rs index 252385c..2976214 100644 --- a/poulpy-core/src/scratch.rs +++ b/poulpy-core/src/scratch.rs @@ -155,11 +155,7 @@ where ) } - fn take_ggsw_slice( - &mut self, - size: usize, - infos: &A, - ) -> (Vec>, &mut Self) + fn take_ggsw_slice(&mut self, size: usize, infos: &A) -> (Vec>, &mut Self) where A: GGSWInfos, { diff --git a/poulpy-hal/src/api/vmp_pmat.rs b/poulpy-hal/src/api/vmp_pmat.rs index de3433a..e5ebc64 100644 --- a/poulpy-hal/src/api/vmp_pmat.rs +++ b/poulpy-hal/src/api/vmp_pmat.rs @@ -111,3 +111,9 @@ pub trait VmpApplyDftToDftAdd { A: VecZnxDftToRef, C: VmpPMatToRef; } + +pub trait VmpZero { + fn vmp_zero(&self, res: &mut R) + where + R: VmpPMatToMut; +} diff --git a/poulpy-hal/src/delegates/vmp_pmat.rs b/poulpy-hal/src/delegates/vmp_pmat.rs index 2c65508..69598cb 100644 --- a/poulpy-hal/src/delegates/vmp_pmat.rs +++ b/poulpy-hal/src/delegates/vmp_pmat.rs @@ -1,7 +1,7 @@ use crate::{ api::{ VmpApplyDft, VmpApplyDftTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftAddTmpBytes, - VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPMatBytesOf, VmpPMatFromBytes, VmpPrepare, VmpPrepareTmpBytes, + VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPMatBytesOf, VmpPMatFromBytes, VmpPrepare, VmpPrepareTmpBytes, VmpZero, }, layouts::{ Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, VmpPMatOwned, VmpPMatToMut, @@ -10,7 +10,7 @@ use crate::{ oep::{ VmpApplyDftImpl, VmpApplyDftTmpBytesImpl, VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl, VmpApplyDftToDftTmpBytesImpl, VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPrepareImpl, - VmpPrepareTmpBytesImpl, + VmpPrepareTmpBytesImpl, VmpZeroImpl, }, }; @@ -161,3 +161,15 @@ where B::vmp_apply_dft_to_dft_add_impl(self, res, a, b, scale, scratch); } } + +impl VmpZero for Module +where + B: Backend + VmpZeroImpl, +{ + fn vmp_zero(&self, res: &mut R) + where + R: VmpPMatToMut, + { + B::vmp_zero_impl(self, res); + } +} diff --git a/poulpy-hal/src/oep/vmp_pmat.rs b/poulpy-hal/src/oep/vmp_pmat.rs index bdca416..8813f41 100644 --- a/poulpy-hal/src/oep/vmp_pmat.rs +++ b/poulpy-hal/src/oep/vmp_pmat.rs @@ -145,3 +145,13 @@ pub unsafe trait VmpApplyDftToDftAddImpl { A: VecZnxDftToRef, C: VmpPMatToRef; } + +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See TODO. +/// * See [crate::api::VmpZero] for corresponding public API. +/// # Safety [crate::doc::backend_safety] for safety contract. +pub unsafe trait VmpZeroImpl { + fn vmp_zero_impl(module: &Module, res: &mut R) + where + R: VmpPMatToMut; +} diff --git a/poulpy-hal/src/reference/fft64/vmp.rs b/poulpy-hal/src/reference/fft64/vmp.rs index 07e1a8d..ac401b3 100644 --- a/poulpy-hal/src/reference/fft64/vmp.rs +++ b/poulpy-hal/src/reference/fft64/vmp.rs @@ -1,6 +1,6 @@ use crate::{ cast_mut, - layouts::{MatZnx, MatZnxToRef, VecZnx, VecZnxToRef, VmpPMatToMut, ZnxView, ZnxViewMut}, + layouts::{DataViewMut, MatZnx, MatZnxToRef, VecZnx, VecZnxToRef, VmpPMatToMut, ZnxView, ZnxViewMut}, oep::VecZnxDftAllocBytesImpl, reference::fft64::{ reim::{ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimZero}, @@ -157,6 +157,13 @@ pub fn vmp_apply_dft_to_dft_tmp_bytes(a_size: usize, prows: usize, pcols_in: usi (16 + 8 * row_max * pcols_in) * size_of::() } +pub fn vmp_zero(res: &mut R) +where + R: VmpPMatToMut, +{ + res.to_mut().data_mut().fill(0); +} + pub fn vmp_apply_dft_to_dft(res: &mut R, a: &A, pmat: &M, tmp_bytes: &mut [f64]) where BE: Backend diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs index 3e6d447..594755a 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs @@ -217,18 +217,19 @@ pub trait FheUintPrepare; - fn fhe_uint_prepare_partial( + fn fhe_uint_prepare_custom( &self, res: &mut FheUintPrepared, bits: &FheUint, - count: usize, + bit_start: usize, + bit_end: usize, key: &K, scratch: &mut Scratch, ) where DM: DataMut, DB: DataRef, DK: DataRef, - K: BDDKeyHelper; + K: BDDKeyHelper; } impl FheUintPrepare for Module @@ -261,22 +262,15 @@ where DK: DataRef, K: BDDKeyHelper, { - let (cbt, ks) = key.get_cbt_key(); - - let mut lwe: LWE> = LWE::alloc_from_infos(bits); //TODO: add TakeLWE - let (mut tmp_ggsw, scratch_1) = scratch.take_ggsw(res); - for (bit, dst) in res.bits.iter_mut().enumerate() { - bits.get_bit_lwe(self, bit, &mut lwe, ks, scratch_1); - cbt.execute_to_constant(self, &mut tmp_ggsw, &lwe, 1, 1, scratch_1); - dst.prepare(self, &tmp_ggsw, scratch_1); - } + self.fhe_uint_prepare_custom(res, bits, 0, T::BITS as usize, key, scratch); } - fn fhe_uint_prepare_partial( + fn fhe_uint_prepare_custom( &self, res: &mut FheUintPrepared, bits: &FheUint, - count: usize, + bit_start: usize, + bit_end: usize, key: &K, scratch: &mut Scratch, ) where @@ -289,12 +283,21 @@ where let mut lwe: LWE> = LWE::alloc_from_infos(bits); //TODO: add TakeLWE let (mut tmp_ggsw, scratch_1) = scratch.take_ggsw(res); - for (bit, dst) in res.bits[0..count].iter_mut().enumerate() { // TODO: set the rest of the bits to a prepared zero GGSW + for (bit, dst) in res.bits[bit_start..bit_end].iter_mut().enumerate() { + // TODO: set the rest of the bits to a prepared zero GGSW bits.get_bit_lwe(self, bit, &mut lwe, ks, scratch_1); cbt.execute_to_constant(self, &mut tmp_ggsw, &lwe, 1, 1, scratch_1); dst.prepare(self, &tmp_ggsw, scratch_1); } - } + + for i in 0..bit_start { + res.bits[i].zero(self); + } + + for i in bit_end..T::BITS as usize { + res.bits[i].zero(self); + } + } } impl FheUintPrepared { @@ -309,8 +312,15 @@ impl FheUintPrepared { { module.fhe_uint_prepare(self, other, key, scratch); } - pub fn prepare_partial(&mut self, module: &M, other: &FheUint, count: usize, key: &K, scratch: &mut Scratch) - where + pub fn prepare_partial( + &mut self, + module: &M, + other: &FheUint, + bit_start: usize, + bit_end: usize, + key: &K, + scratch: &mut Scratch, + ) where BRA: BlindRotationAlgo, O: DataRef, DK: DataRef, @@ -318,6 +328,6 @@ impl FheUintPrepared { M: FheUintPrepare, Scratch: ScratchTakeCore, { - module.fhe_uint_prepare_partial(self, other, count, key, scratch); - } + module.fhe_uint_prepare_custom(self, other, bit_start, bit_end, key, scratch); + } }