From 6924ffd94acde6d6baad149f7d84f4a3d67f67ea Mon Sep 17 00:00:00 2001 From: Pro7ech Date: Tue, 11 Nov 2025 21:24:39 +0100 Subject: [PATCH 1/3] Add prepare multi thread --- poulpy-core/src/scratch.rs | 23 ++- poulpy-hal/src/api/scratch.rs | 19 ++- poulpy-hal/src/layouts/mod.rs | 4 +- poulpy-hal/src/layouts/module.rs | 2 +- poulpy-schemes/Cargo.toml | 2 +- poulpy-schemes/benches/fhe_uint_prepare.rs | 126 +++++++++++++++ .../ciphertexts/fhe_uint_prepared.rs | 153 ++++++++++++------ .../ciphertexts/fhe_uint_prepared_debug.rs | 12 +- poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs | 18 ++- poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs | 2 +- .../bdd_arithmetic/tests/test_suite/add.rs | 2 +- .../bdd_arithmetic/tests/test_suite/and.rs | 2 +- .../bdd_arithmetic/tests/test_suite/mod.rs | 12 +- .../bdd_arithmetic/tests/test_suite/or.rs | 2 +- .../tests/test_suite/prepare.rs | 6 +- .../bdd_arithmetic/tests/test_suite/sll.rs | 2 +- .../bdd_arithmetic/tests/test_suite/slt.rs | 2 +- .../bdd_arithmetic/tests/test_suite/sltu.rs | 2 +- .../bdd_arithmetic/tests/test_suite/sra.rs | 2 +- .../bdd_arithmetic/tests/test_suite/srl.rs | 2 +- .../bdd_arithmetic/tests/test_suite/sub.rs | 2 +- .../bdd_arithmetic/tests/test_suite/xor.rs | 2 +- .../src/tfhe/blind_rotation/algorithms/mod.rs | 2 +- .../src/tfhe/blind_rotation/layouts/key.rs | 3 +- .../src/tfhe/circuit_bootstrapping/circuit.rs | 30 +++- .../src/tfhe/circuit_bootstrapping/key.rs | 9 ++ .../circuit_bootstrapping/key_prepared.rs | 4 + 27 files changed, 361 insertions(+), 86 deletions(-) create mode 100644 poulpy-schemes/benches/fhe_uint_prepare.rs diff --git a/poulpy-core/src/scratch.rs b/poulpy-core/src/scratch.rs index 2976214..06b8a04 100644 --- a/poulpy-core/src/scratch.rs +++ b/poulpy-core/src/scratch.rs @@ -1,5 +1,5 @@ use poulpy_hal::{ - api::{ModuleN, ScratchAvailable, ScratchTakeBasic, SvpPPolBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf}, + api::{ModuleN, ScratchAvailable, ScratchFromBytes, ScratchTakeBasic, SvpPPolBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf}, layouts::{Backend, Scratch}, }; @@ -7,7 +7,7 @@ use crate::{ dist::Distribution, layouts::{ Degree, GGLWE, GGLWEInfos, GGLWELayout, GGSW, GGSWInfos, GLWE, GLWEAutomorphismKey, GLWEInfos, GLWEPlaintext, - GLWEPrepared, GLWEPublicKey, GLWESecret, GLWESecretTensor, GLWESwitchingKey, GLWETensorKey, Rank, + GLWEPrepared, GLWEPublicKey, GLWESecret, GLWESecretTensor, GLWESwitchingKey, GLWETensorKey, LWE, LWEInfos, Rank, prepared::{ GGLWEPrepared, GGSWPrepared, GLWEAutomorphismKeyPrepared, GLWEPublicKeyPrepared, GLWESecretPrepared, GLWESwitchingKeyPrepared, GLWETensorKeyPrepared, @@ -17,8 +17,23 @@ use crate::{ pub trait ScratchTakeCore where - Self: ScratchTakeBasic + ScratchAvailable, + Self: ScratchTakeBasic + ScratchAvailable + ScratchFromBytes, { + fn take_lwe(&mut self, infos: &A) -> (LWE<&mut [u8]>, &mut Self) + where + A: LWEInfos, + { + let (data, scratch) = self.take_zn(infos.n().into(), 1, infos.size()); + ( + LWE { + k: infos.k(), + base2k: infos.base2k(), + data, + }, + scratch, + ) + } + fn take_glwe(&mut self, infos: &A) -> (GLWE<&mut [u8]>, &mut Self) where A: GLWEInfos, @@ -367,4 +382,4 @@ where } } -impl ScratchTakeCore for Scratch where Self: ScratchTakeBasic + ScratchAvailable {} +impl ScratchTakeCore for Scratch where Self: ScratchTakeBasic + ScratchAvailable + ScratchFromBytes {} diff --git a/poulpy-hal/src/api/scratch.rs b/poulpy-hal/src/api/scratch.rs index 9e3c484..051f52d 100644 --- a/poulpy-hal/src/api/scratch.rs +++ b/poulpy-hal/src/api/scratch.rs @@ -1,6 +1,6 @@ use crate::{ api::{ModuleN, SvpPPolBytesOf, VecZnxBigBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf}, - layouts::{Backend, MatZnx, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}, + layouts::{Backend, MatZnx, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, Zn}, }; /// Allocates a new [crate::layouts::ScratchOwned] of `size` aligned bytes. @@ -28,7 +28,17 @@ pub trait TakeSlice { fn take_slice(&mut self, len: usize) -> (&mut [T], &mut Self); } -impl ScratchTakeBasic for Scratch where Self: TakeSlice {} +impl Scratch +where + Self: TakeSlice + ScratchFromBytes, +{ + pub fn split_at_mut(&mut self, len: usize) -> (&mut Scratch, &mut Self) { + let (take_slice, rem_slice) = self.take_slice(len); + (Self::from_bytes(take_slice), rem_slice) + } +} + +impl ScratchTakeBasic for Scratch where Self: TakeSlice + ScratchFromBytes {} pub trait ScratchTakeBasic where @@ -47,6 +57,11 @@ where (SvpPPol::from_data(take_slice, module.n(), cols), rem_slice) } + fn take_zn(&mut self, n: usize, cols: usize, size: usize) -> (Zn<&mut [u8]>, &mut Self) { + let (take_slice, rem_slice) = self.take_slice(Zn::bytes_of(n, cols, size)); + (Zn::from_data(take_slice, n, cols, size), rem_slice) + } + fn take_vec_znx(&mut self, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) { let (take_slice, rem_slice) = self.take_slice(VecZnx::bytes_of(n, cols, size)); (VecZnx::from_data(take_slice, n, cols, size), rem_slice) diff --git a/poulpy-hal/src/layouts/mod.rs b/poulpy-hal/src/layouts/mod.rs index 0553665..7d4600b 100644 --- a/poulpy-hal/src/layouts/mod.rs +++ b/poulpy-hal/src/layouts/mod.rs @@ -28,8 +28,8 @@ pub use zn::*; pub use znx_base::*; pub trait Data = PartialEq + Eq + Sized + Default; -pub trait DataRef = Data + AsRef<[u8]>; -pub trait DataMut = DataRef + AsMut<[u8]>; +pub trait DataRef = Data + AsRef<[u8]> + Sync; +pub trait DataMut = DataRef + AsMut<[u8]> + Send; pub trait ToOwnedDeep { type Owned; diff --git a/poulpy-hal/src/layouts/module.rs b/poulpy-hal/src/layouts/module.rs index 54e3ffa..2195e49 100644 --- a/poulpy-hal/src/layouts/module.rs +++ b/poulpy-hal/src/layouts/module.rs @@ -13,7 +13,7 @@ use crate::{ }; #[allow(clippy::missing_safety_doc)] -pub trait Backend: Sized { +pub trait Backend: Sized + Sync + Send { type ScalarBig: Copy + Zero + Display + Debug + Pod; type ScalarPrep: Copy + Zero + Display + Debug + Pod; type Handle: 'static; diff --git a/poulpy-schemes/Cargo.toml b/poulpy-schemes/Cargo.toml index 2c025c9..f4058d8 100644 --- a/poulpy-schemes/Cargo.toml +++ b/poulpy-schemes/Cargo.toml @@ -20,5 +20,5 @@ rand = "0.9.2" [[bench]] -name = "circuit_bootstrapping" +name = "fhe_uint_prepare" harness = false \ No newline at end of file diff --git a/poulpy-schemes/benches/fhe_uint_prepare.rs b/poulpy-schemes/benches/fhe_uint_prepare.rs new file mode 100644 index 0000000..d403fc2 --- /dev/null +++ b/poulpy-schemes/benches/fhe_uint_prepare.rs @@ -0,0 +1,126 @@ +use std::hint::black_box; + +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use poulpy_backend::{FFT64Avx, FFT64Ref}; +use poulpy_core::{ + GGSWNoise, GLWEDecrypt, GLWEEncryptSk, GLWENoise, ScratchTakeCore, + layouts::{GGSWLayout, GLWELayout, GLWESecretPreparedFactory, prepared::GLWESecretPrepared}, +}; +use poulpy_hal::{ + api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow}, + layouts::{Backend, Module, Scratch, ScratchOwned}, + source::Source, +}; +use rand::RngCore; + +use poulpy_schemes::tfhe::{ + bdd_arithmetic::{ + BDDKeyEncryptSk, BDDKeyPrepared, BDDKeyPreparedFactory, ExecuteBDDCircuit2WTo1W, FheUint, FheUintPrepare, + FheUintPrepareDebug, FheUintPrepared, FheUintPreparedEncryptSk, FheUintPreparedFactory, + tests::test_suite::TestContext, + }, + blind_rotation::{BlindRotationAlgo, BlindRotationKey, BlindRotationKeyFactory, CGGI}, +}; + +pub fn benc_bdd_prepare( + c: &mut Criterion, + label: &str, + test_context: &TestContext, +) where + Module: ModuleNew + + GLWESecretPreparedFactory + + GLWEDecrypt + + GLWENoise + + FheUintPreparedFactory + + FheUintPreparedEncryptSk + + FheUintPrepareDebug + + BDDKeyEncryptSk + + BDDKeyPreparedFactory + + GGSWNoise + + FheUintPrepare + + ExecuteBDDCircuit2WTo1W + + GLWEEncryptSk, + BlindRotationKey, BRA>: BlindRotationKeyFactory, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchTakeCore, +{ + let group_name: String = format!("bdd_prepare::{label}"); + + let mut group = c.benchmark_group(group_name); + + fn runner(test_context: &TestContext) -> impl FnMut() + where + Module: ModuleNew + + GLWESecretPreparedFactory + + GLWEDecrypt + + GLWENoise + + FheUintPreparedFactory + + FheUintPreparedEncryptSk + + FheUintPrepareDebug + + BDDKeyEncryptSk + + BDDKeyPreparedFactory + + GGSWNoise + + FheUintPrepare + + ExecuteBDDCircuit2WTo1W + + GLWEEncryptSk, + BlindRotationKey, BRA>: BlindRotationKeyFactory, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchTakeCore, + { + let glwe_infos: GLWELayout = test_context.glwe_infos(); + let ggsw_infos: GGSWLayout = test_context.ggsw_infos(); + + let module: &Module = &test_context.module; + let sk_glwe_prep: &GLWESecretPrepared, BE> = &test_context.sk_glwe; + let bdd_key_prepared: &BDDKeyPrepared, BRA, BE> = &test_context.bdd_key; + + let mut source: Source = Source::new([6u8; 32]); + + let mut source_xa: Source = Source::new([2u8; 32]); + let mut source_xe: Source = Source::new([3u8; 32]); + + let threads = 1; + + let mut scratch: ScratchOwned = ScratchOwned::alloc((1 << 22) * threads); + + // GLWE(value) + let mut c_enc: FheUint, u32> = FheUint::alloc_from_infos(&glwe_infos); + let value: u32 = source.next_u32(); + c_enc.encrypt_sk( + module, + value, + sk_glwe_prep, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + // GGSW(0) + let mut c_enc_prep: FheUintPrepared, u32, BE> = + FheUintPrepared::, u32, BE>::alloc_from_infos(module, &ggsw_infos); + + // GGSW(value) + move || { + c_enc_prep.prepare_custom_multi_thread(threads, module, &c_enc, 0, 32, bdd_key_prepared, scratch.borrow()); + black_box(()); + } + } + + let id: BenchmarkId = BenchmarkId::from_parameter(format!("n_glwe: {}", test_context.module.n())); + let mut runner = runner::(test_context); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + + group.finish(); +} + +fn bench_bdd_prepare_cpu_ref_fft64(c: &mut Criterion) { + benc_bdd_prepare::( + c, + "bdd_prepare_fft64_ref", + &TestContext::::new(), + ); +} + +criterion_group!(benches, bench_bdd_prepare_cpu_ref_fft64,); + +criterion_main!(benches); diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs index 32a03c5..07d6c46 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs @@ -1,15 +1,17 @@ use std::marker::PhantomData; +use std::thread; use poulpy_core::layouts::{ Base2K, Dnum, Dsize, GGSWInfos, GGSWPreparedFactory, GLWEInfos, LWEInfos, Rank, TorusPrecision, prepared::GGSWPrepared, }; use poulpy_core::layouts::{ - GGLWEInfos, GGLWEPreparedToRef, GGSWPreparedToMut, GGSWPreparedToRef, GLWEAutomorphismKeyHelper, GetGaloisElement, LWE, + GGLWEInfos, GGLWEPreparedToRef, GGSW, GGSWLayout, GGSWPreparedToMut, GGSWPreparedToRef, GLWEAutomorphismKeyHelper, + GetGaloisElement, LWE, }; use poulpy_core::{GLWECopy, GLWEDecrypt, GLWEPacking, LWEFromGLWE}; use poulpy_core::{GGSWEncryptSk, ScratchTakeCore, layouts::GLWESecretPreparedToRef}; -use poulpy_hal::api::ModuleLogN; +use poulpy_hal::api::{ModuleLogN, ScratchAvailable, ScratchFromBytes}; use poulpy_hal::layouts::{Backend, Data, DataRef, Module}; use poulpy_hal::{ @@ -21,7 +23,7 @@ use poulpy_hal::{ use crate::tfhe::bdd_arithmetic::{BDDKey, BDDKeyHelper, BDDKeyInfos, BDDKeyPrepared, BDDKeyPreparedFactory, FheUint, ToBits}; use crate::tfhe::bdd_arithmetic::{Cmux, FromBits, ScratchTakeBDD, UnsignedInteger}; use crate::tfhe::blind_rotation::BlindRotationAlgo; -use crate::tfhe::circuit_bootstrapping::CirtuitBootstrappingExecute; +use crate::tfhe::circuit_bootstrapping::{CircuitBootstrappingKeyInfos, CirtuitBootstrappingExecute}; /// A prepared FHE ciphertext encrypting the bits of an [UnsignedInteger]. pub struct FheUintPrepared { @@ -219,12 +221,13 @@ impl BDDKeyPrepared } } -pub trait FheUintPrepare { - fn fhe_uint_prepare_tmp_bytes(&self, block_size: usize, extension_factor: usize, res_infos: &R, infos: &A) -> usize +pub trait FheUintPrepare { + fn fhe_uint_prepare_tmp_bytes(&self, block_size: usize, extension_factor: usize, res_infos: &R, bits_infos: &A, bdd_infos: &B) -> usize where R: GGSWInfos, - A: BDDKeyInfos; - fn fhe_uint_prepare( + A: GLWEInfos, + B: BDDKeyInfos; + fn fhe_uint_prepare( &self, res: &mut FheUintPrepared, bits: &FheUint, @@ -234,79 +237,119 @@ pub trait FheUintPrepare; - fn fhe_uint_prepare_custom( + K: BDDKeyHelper + BDDKeyInfos, + Scratch: ScratchFromBytes, + { + self.fhe_uint_prepare_custom(res, bits, 0, T::BITS as usize, key, scratch); + } + fn fhe_uint_prepare_custom( &self, res: &mut FheUintPrepared, bits: &FheUint, bit_start: usize, - bit_end: usize, + bit_count: usize, key: &K, scratch: &mut Scratch, ) where DM: DataMut, DB: DataRef, DK: DataRef, - K: BDDKeyHelper; + K: BDDKeyHelper + BDDKeyInfos, + { + self.fhe_uint_prepare_custom_multi_thread(1, res, bits, bit_start, bit_count, key, scratch) + } + fn fhe_uint_prepare_custom_multi_thread( + &self, + threads: usize, + res: &mut FheUintPrepared, + bits: &FheUint, + bit_start: usize, + bit_count: usize, + key: &K, + scratch: &mut Scratch, + ) where + DM: DataMut, + DB: DataRef, + DK: DataRef, + K: BDDKeyHelper + BDDKeyInfos; } -impl FheUintPrepare for Module +impl FheUintPrepare for Module where Self: LWEFromGLWE + CirtuitBootstrappingExecute + GGSWPreparedFactory, Scratch: ScratchTakeCore, { - fn fhe_uint_prepare_tmp_bytes(&self, block_size: usize, extension_factor: usize, res_infos: &R, bdd_infos: &A) -> usize + fn fhe_uint_prepare_tmp_bytes(&self, block_size: usize, extension_factor: usize, res_infos: &R, bits_infos: &A, bdd_infos: &B) -> usize where R: GGSWInfos, - A: BDDKeyInfos, + A: GLWEInfos, + B: BDDKeyInfos, { self.circuit_bootstrapping_execute_tmp_bytes( block_size, extension_factor, res_infos, &bdd_infos.cbt_infos(), - ) + ) + GGSW::bytes_of_from_infos(res_infos) + + LWE::bytes_of_from_infos(bits_infos) } - fn fhe_uint_prepare( - &self, - res: &mut FheUintPrepared, - bits: &FheUint, - key: &K, - scratch: &mut Scratch, - ) where - DM: DataMut, - DB: DataRef, - DK: DataRef, - K: BDDKeyHelper, - { - self.fhe_uint_prepare_custom(res, bits, 0, T::BITS as usize, key, scratch); - } - - fn fhe_uint_prepare_custom( + fn fhe_uint_prepare_custom_multi_thread( &self, + threads: usize, res: &mut FheUintPrepared, bits: &FheUint, bit_start: usize, - bit_end: usize, + bit_count: usize, key: &K, - scratch: &mut Scratch, + mut scratch: &mut Scratch, ) where DM: DataMut, DB: DataRef, DK: DataRef, - K: BDDKeyHelper, + K: BDDKeyHelper + BDDKeyInfos, { + let bit_end = bit_start + bit_count; let (cbt, ks) = key.get_cbt_key(); - let mut lwe: LWE> = LWE::alloc_from_infos(bits); //TODO: add TakeLWE - let (mut tmp_ggsw, scratch_1) = scratch.take_ggsw(res); - for (bit, dst) in res.bits[bit_start..bit_end].iter_mut().enumerate() { - // TODO: set the rest of the bits to a prepared zero GGSW - bits.get_bit_lwe(self, bit, &mut lwe, ks, scratch_1); - cbt.execute_to_constant(self, &mut tmp_ggsw, &lwe, 1, 1, scratch_1); - dst.prepare(self, &tmp_ggsw, scratch_1); + assert!(bit_end <= T::BITS as usize); + + let scratch_thread_size = self.fhe_uint_prepare_tmp_bytes(cbt.block_size(), 1, res, bits, key); + + assert!(scratch.available() >= threads * scratch_thread_size); + + // How many bits we need to process + let chunk_size: usize = bit_count.div_ceil(threads); // ceil division + + let mut scratches = Vec::new(); + for _ in 0..(threads - 1) { + let (tmp, scratch_new) = scratch.split_at_mut(scratch_thread_size); + scratch = scratch_new; + scratches.push(tmp); } + scratches.push(scratch); + + let ggsw_infos: &GGSWLayout = &res.ggsw_layout(); + + thread::scope(|scope| { + for (thread_index, (scratch_thread, res_bits_chunk)) in scratches + .iter_mut() + .zip(res.bits[bit_start..bit_end].chunks_mut(chunk_size)) + .enumerate() + { + let start: usize = bit_start + thread_index * chunk_size; + + scope.spawn(move || { + let (mut tmp_ggsw, scratch_1) = scratch_thread.take_ggsw(ggsw_infos); + let (mut tmp_lwe, scratch_2) = scratch_1.take_lwe(bits); + for (local_bit, dst) in res_bits_chunk.iter_mut().enumerate() { + bits.get_bit_lwe(self, start + local_bit, &mut tmp_lwe, ks, scratch_2); + cbt.execute_to_constant(self, &mut tmp_ggsw, &tmp_lwe, 1, 1, scratch_2); + dst.prepare(self, &tmp_ggsw, scratch_2); + } + }); + } + }); for i in 0..bit_start { res.bits[i].zero(self); @@ -324,8 +367,8 @@ impl FheUintPrepared { BRA: BlindRotationAlgo, O: DataRef, DK: DataRef, - K: BDDKeyHelper, - M: FheUintPrepare, + K: BDDKeyHelper + BDDKeyInfos, + M: FheUintPrepare, Scratch: ScratchTakeCore, { module.fhe_uint_prepare(self, other, key, scratch); @@ -342,10 +385,30 @@ impl FheUintPrepared { BRA: BlindRotationAlgo, O: DataRef, DK: DataRef, - K: BDDKeyHelper, - M: FheUintPrepare, + K: BDDKeyHelper + BDDKeyInfos, + M: FheUintPrepare, Scratch: ScratchTakeCore, { module.fhe_uint_prepare_custom(self, other, bit_start, bit_end, key, scratch); } + + pub fn prepare_custom_multi_thread( + &mut self, + threads: usize, + module: &M, + other: &FheUint, + bit_start: usize, + bit_end: usize, + key: &K, + scratch: &mut Scratch, + ) where + BRA: BlindRotationAlgo, + O: DataRef, + DK: DataRef, + K: BDDKeyHelper + BDDKeyInfos, + M: FheUintPrepare, + Scratch: ScratchTakeCore, + { + module.fhe_uint_prepare_custom_multi_thread(threads, self, other, bit_start, bit_end, key, scratch); + } } diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared_debug.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared_debug.rs index cd7b7db..182365b 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared_debug.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared_debug.rs @@ -10,7 +10,7 @@ use poulpy_core::layouts::{Base2K, Dnum, Dsize, Rank, TorusPrecision}; use poulpy_core::layouts::{GGSW, GLWESecretPreparedToRef}; use poulpy_core::{ LWEFromGLWE, ScratchTakeCore, - layouts::{GGSWInfos, GGSWPreparedFactory, GLWEInfos, LWE, LWEInfos}, + layouts::{GGSWInfos, GGSWPreparedFactory, GLWEInfos, LWEInfos}, }; use poulpy_hal::api::ModuleN; @@ -124,11 +124,13 @@ where DM: DataMut, DR0: DataRef, DR1: DataRef, - { - let mut lwe: LWE> = LWE::alloc_from_infos(bits); //TODO: add TakeLWE + { + + let (_, scratch_1) = scratch.take_ggsw(res); + let (mut tmp_lwe, scratch_2) = scratch_1.take_lwe(bits); for (bit, dst) in res.bits.iter_mut().enumerate() { - bits.get_bit_lwe(self, bit, &mut lwe, &key.ks, scratch); - key.cbt.execute_to_constant(self, dst, &lwe, 1, 1, scratch); + bits.get_bit_lwe(self, bit, &mut tmp_lwe, &key.ks, scratch_2); + key.cbt.execute_to_constant(self, dst, &tmp_lwe, 1, 1, scratch_2); } } } diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs index 37b89ed..dbd52cf 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs @@ -1,4 +1,5 @@ use crate::tfhe::bdd_arithmetic::FheUintPreparedDebug; +use crate::tfhe::circuit_bootstrapping::CircuitBootstrappingKeyInfos; use crate::tfhe::{ bdd_arithmetic::{FheUint, UnsignedInteger}, blind_rotation::{BlindRotationAlgo, BlindRotationKey, BlindRotationKeyFactory}, @@ -8,7 +9,7 @@ use crate::tfhe::{ }, }; -use poulpy_core::layouts::{GLWEAutomorphismKeyHelper, GLWEAutomorphismKeyPrepared}; +use poulpy_core::layouts::{GGLWEInfos, GLWEAutomorphismKeyHelper, GLWEAutomorphismKeyPrepared}; use poulpy_core::{ GLWEToLWESwitchingKeyEncryptSk, GetDistribution, ScratchTakeCore, layouts::{ @@ -135,6 +136,21 @@ where pub(crate) ks: GLWEToLWEKeyPrepared, } +impl BDDKeyInfos for BDDKeyPrepared{ + fn cbt_infos(&self) -> CircuitBootstrappingKeyLayout { + CircuitBootstrappingKeyLayout { layout_brk: self.cbt.brk_infos(), layout_atk: self.cbt.atk_infos(), layout_tsk: self.cbt.tsk_infos() } + } + fn ks_infos(&self) -> GLWEToLWEKeyLayout { + GLWEToLWEKeyLayout{ + n: self.ks.n(), + base2k: self.ks.base2k(), + k: self.ks.k(), + rank_in: self.ks.rank_in(), + dnum: self.ks.dnum() + } + } +} + impl GLWEAutomorphismKeyHelper, BE> for BDDKeyPrepared { diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs index e3fb490..7b61bf2 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs @@ -16,7 +16,7 @@ pub use key::*; pub mod tests; -pub trait UnsignedInteger: Copy + 'static { +pub trait UnsignedInteger: Copy + Sync + Send + 'static { const BITS: u32; const LOG_BITS: u32; const LOG_BYTES: u32; diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/add.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/add.rs index 683c919..7810dd3 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/add.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/add.rs @@ -30,7 +30,7 @@ where + BDDKeyEncryptSk + BDDKeyPreparedFactory + GGSWNoise - + FheUintPrepare + + FheUintPrepare + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/and.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/and.rs index 928d83b..ee468f6 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/and.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/and.rs @@ -30,7 +30,7 @@ where + BDDKeyEncryptSk + BDDKeyPreparedFactory + GGSWNoise - + FheUintPrepare + + FheUintPrepare + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs index 59fd529..488967a 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs @@ -75,6 +75,14 @@ where } impl TestContext { + pub fn glwe_infos(&self) -> GLWELayout { + TEST_GLWE_INFOS + } + + pub fn ggsw_infos(&self) -> GGSWLayout { + TEST_GGSW_INFOS + } + pub fn new() -> Self where Module: ModuleNew @@ -125,8 +133,8 @@ impl TestContext { } } -pub(crate) const TEST_N_GLWE: u32 = 256; -pub(crate) const TEST_N_LWE: u32 = 77; +pub(crate) const TEST_N_GLWE: u32 = 1024; +pub(crate) const TEST_N_LWE: u32 = 574; pub(crate) const TEST_BASE2K: u32 = 13; pub(crate) const TEST_K_GLWE: u32 = 26; pub(crate) const TEST_K_GGSW: u32 = 39; diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/or.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/or.rs index 69b773c..91ce50f 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/or.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/or.rs @@ -30,7 +30,7 @@ where + BDDKeyEncryptSk + BDDKeyPreparedFactory + GGSWNoise - + FheUintPrepare + + FheUintPrepare + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/prepare.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/prepare.rs index 7da039e..b7ed299 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/prepare.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/prepare.rs @@ -30,7 +30,7 @@ where + BDDKeyEncryptSk + BDDKeyPreparedFactory + GGSWNoise - + FheUintPrepare + + FheUintPrepare + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, @@ -67,8 +67,10 @@ where let mut c_enc_prep_debug: FheUintPreparedDebug, u32> = FheUintPreparedDebug::, u32>::alloc_from_infos(module, &ggsw_infos); + let mut scratch_2 = ScratchOwned::alloc(module.fhe_uint_prepare_tmp_bytes(7, 1, &c_enc_prep_debug, &c_enc, bdd_key_prepared)); + // GGSW(value) - c_enc_prep_debug.prepare(module, &c_enc, bdd_key_prepared, scratch.borrow()); + c_enc_prep_debug.prepare(module, &c_enc, bdd_key_prepared, scratch_2.borrow()); let max_noise = |col_i: usize| { let mut noise: f64 = -(ggsw_infos.size() as f64 * TEST_BASE2K as f64) + SIGMA.log2() + 2.0; diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sll.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sll.rs index 9c47883..b37d816 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sll.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sll.rs @@ -30,7 +30,7 @@ where + BDDKeyEncryptSk + BDDKeyPreparedFactory + GGSWNoise - + FheUintPrepare + + FheUintPrepare + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/slt.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/slt.rs index fc9b5e1..dca9f6f 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/slt.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/slt.rs @@ -30,7 +30,7 @@ where + BDDKeyEncryptSk + BDDKeyPreparedFactory + GGSWNoise - + FheUintPrepare + + FheUintPrepare + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sltu.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sltu.rs index 444da03..aa7317d 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sltu.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sltu.rs @@ -30,7 +30,7 @@ where + BDDKeyEncryptSk + BDDKeyPreparedFactory + GGSWNoise - + FheUintPrepare + + FheUintPrepare + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sra.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sra.rs index abb4269..730c539 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sra.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sra.rs @@ -30,7 +30,7 @@ where + BDDKeyEncryptSk + BDDKeyPreparedFactory + GGSWNoise - + FheUintPrepare + + FheUintPrepare + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/srl.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/srl.rs index d9087fd..61c3fb6 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/srl.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/srl.rs @@ -30,7 +30,7 @@ where + BDDKeyEncryptSk + BDDKeyPreparedFactory + GGSWNoise - + FheUintPrepare + + FheUintPrepare + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sub.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sub.rs index 6df0c92..3a73fbf 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sub.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sub.rs @@ -30,7 +30,7 @@ where + BDDKeyEncryptSk + BDDKeyPreparedFactory + GGSWNoise - + FheUintPrepare + + FheUintPrepare + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/xor.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/xor.rs index e035246..40c05b8 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/xor.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/xor.rs @@ -30,7 +30,7 @@ where + BDDKeyEncryptSk + BDDKeyPreparedFactory + GGSWNoise - + FheUintPrepare + + FheUintPrepare + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, diff --git a/poulpy-schemes/src/tfhe/blind_rotation/algorithms/mod.rs b/poulpy-schemes/src/tfhe/blind_rotation/algorithms/mod.rs index f25fb51..0cc4859 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/algorithms/mod.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/algorithms/mod.rs @@ -11,7 +11,7 @@ use poulpy_hal::layouts::{Backend, DataMut, DataRef, Scratch, ZnxView}; use crate::tfhe::blind_rotation::{BlindRotationKeyInfos, BlindRotationKeyPrepared, LookUpTableRotationDirection, LookupTable}; -pub trait BlindRotationAlgo {} +pub trait BlindRotationAlgo: Sync {} pub trait BlindRotationExecute { fn blind_rotation_execute_tmp_bytes( diff --git a/poulpy-schemes/src/tfhe/blind_rotation/layouts/key.rs b/poulpy-schemes/src/tfhe/blind_rotation/layouts/key.rs index 182c973..e8748be 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/layouts/key.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/layouts/key.rs @@ -188,8 +188,7 @@ impl BlindRotationKeyInfos for BlindRotation } impl BlindRotationKey { - #[allow(dead_code)] - fn block_size(&self) -> usize { + pub fn block_size(&self) -> usize { match self.dist { Distribution::BinaryBlock(value) => value, _ => 1, diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs index 83be647..17371b1 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs @@ -1,15 +1,14 @@ use std::collections::HashMap; use poulpy_hal::{ - api::{ModuleLogN, ModuleN, ScratchOwnedAlloc, ScratchOwnedBorrow}, + api::{ModuleLogN, ModuleN, ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow}, layouts::{Backend, DataRef, Module, Scratch, ScratchOwned, ToOwnedDeep}, }; use poulpy_core::{ GGSWFromGGLWE, GLWEDecrypt, GLWEPacking, GLWERotate, GLWETrace, ScratchTakeCore, layouts::{ - Dsize, GGLWEInfos, GGLWELayout, GGLWEPreparedToRef, GGSWInfos, GGSWToMut, GLWEAutomorphismKeyHelper, GLWEInfos, - GLWESecretPreparedFactory, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos, LWEToRef, + Dsize, GGLWE, GGLWEInfos, GGLWELayout, GGLWEPreparedToRef, GGSWInfos, GGSWToMut, GLWEAutomorphismKeyHelper, GLWEInfos, GLWESecretPreparedFactory, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos, LWEToRef, Rank }, }; @@ -132,6 +131,17 @@ where R: GGSWInfos, A: CircuitBootstrappingKeyInfos, { + + let gglwe_infos: GGLWELayout = GGLWELayout { + n: res_infos.n(), + base2k: res_infos.base2k(), + k: res_infos.k(), + dnum: res_infos.dnum(), + dsize: Dsize(1), + rank_in: res_infos.rank().max(Rank(1)).into(), + rank_out: res_infos.rank(), + }; + self.blind_rotation_execute_tmp_bytes( block_size, extension_factor, @@ -139,7 +149,7 @@ where &cbt_infos.brk_infos(), ) .max(self.glwe_trace_tmp_bytes(res_infos, res_infos, &cbt_infos.atk_infos())) - .max(self.ggsw_from_gglwe_tmp_bytes(res_infos, &cbt_infos.tsk_infos())) + .max(self.ggsw_from_gglwe_tmp_bytes(res_infos, &cbt_infos.tsk_infos())) + GLWE::bytes_of_from_infos(res_infos) + GGLWE::bytes_of_from_infos(&gglwe_infos) } fn circuit_bootstrapping_execute_to_constant( @@ -154,7 +164,10 @@ where R: GGSWToMut + GGSWInfos, L: LWEToRef + LWEInfos, D: DataRef, - { + { + + assert!(scratch.available() >= self.circuit_bootstrapping_execute_tmp_bytes(key.block_size(), extension_factor, res, key)); + circuit_bootstrap_core( false, self, @@ -181,7 +194,10 @@ where R: GGSWToMut + GGSWInfos, L: LWEToRef + LWEInfos, D: DataRef, - { + { + + assert!(scratch.available() >= self.circuit_bootstrapping_execute_tmp_bytes(key.block_size(), extension_factor, res, key)); + circuit_bootstrap_core( true, self, @@ -223,7 +239,7 @@ pub fn circuit_bootstrap_core( + ModuleLogN, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchTakeCore, -{ +{ let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); let lwe: &LWE<&[u8]> = &lwe.to_ref(); diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs index 33da31e..633f080 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs @@ -19,6 +19,7 @@ use crate::tfhe::blind_rotation::{ }; pub trait CircuitBootstrappingKeyInfos { + fn block_size(&self) -> usize; fn brk_infos(&self) -> BlindRotationKeyLayout; fn atk_infos(&self) -> GLWEAutomorphismKeyLayout; fn tsk_infos(&self) -> GGLWEToGGSWKeyLayout; @@ -32,6 +33,10 @@ pub struct CircuitBootstrappingKeyLayout { } impl CircuitBootstrappingKeyInfos for CircuitBootstrappingKeyLayout { + fn block_size(&self) -> usize { + unimplemented!("unimplemented for CircuitBootstrappingKeyLayout") + } + fn atk_infos(&self) -> GLWEAutomorphismKeyLayout { self.layout_atk } @@ -164,6 +169,10 @@ where } impl CircuitBootstrappingKeyInfos for CircuitBootstrappingKey { + fn block_size(&self) -> usize { + self.brk.block_size() + } + fn atk_infos(&self) -> GLWEAutomorphismKeyLayout { let (_, atk) = self.atk.iter().next().expect("atk is empty"); GLWEAutomorphismKeyLayout { diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/key_prepared.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/key_prepared.rs index e5fca47..ee071bd 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/key_prepared.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/key_prepared.rs @@ -122,6 +122,10 @@ impl GLWEAutomorphismKeyHelper< } impl CircuitBootstrappingKeyInfos for CircuitBootstrappingKeyPrepared { + fn block_size(&self) -> usize { + self.brk.block_size() + } + fn atk_infos(&self) -> GLWEAutomorphismKeyLayout { let (_, atk) = self.atk.iter().next().expect("atk is empty"); GLWEAutomorphismKeyLayout { From 1423de1c46bc3d81099f75c5e8473ebe185b06e0 Mon Sep 17 00:00:00 2001 From: Pro7ech Date: Wed, 12 Nov 2025 11:02:37 +0100 Subject: [PATCH 2/3] Add multi-thread bdd eval --- Cargo.lock | 7 + poulpy-hal/src/api/scratch.rs | 14 +- poulpy-schemes/Cargo.toml | 4 +- poulpy-schemes/benches/fhe_uint_prepare.rs | 126 ----------- .../src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs | 149 +++++++++---- .../ciphertexts/fhe_uint_prepared.rs | 35 +-- .../ciphertexts/fhe_uint_prepared_debug.rs | 6 +- .../src/tfhe/bdd_arithmetic/eval.rs | 204 +++++++++++++----- poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs | 12 +- .../bdd_arithmetic/tests/test_suite/add.rs | 2 +- .../bdd_arithmetic/tests/test_suite/and.rs | 2 +- .../bdd_arithmetic/tests/test_suite/mod.rs | 4 +- .../bdd_arithmetic/tests/test_suite/or.rs | 2 +- .../tests/test_suite/prepare.rs | 2 +- .../bdd_arithmetic/tests/test_suite/sll.rs | 2 +- .../bdd_arithmetic/tests/test_suite/slt.rs | 2 +- .../bdd_arithmetic/tests/test_suite/sltu.rs | 2 +- .../bdd_arithmetic/tests/test_suite/sra.rs | 2 +- .../bdd_arithmetic/tests/test_suite/srl.rs | 2 +- .../bdd_arithmetic/tests/test_suite/sub.rs | 2 +- .../bdd_arithmetic/tests/test_suite/xor.rs | 2 +- .../src/tfhe/circuit_bootstrapping/circuit.rs | 26 ++- 22 files changed, 336 insertions(+), 273 deletions(-) delete mode 100644 poulpy-schemes/benches/fhe_uint_prepare.rs diff --git a/Cargo.lock b/Cargo.lock index d0d2744..ebc0f43 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -323,6 +323,12 @@ version = "11.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + [[package]] name = "plotters" version = "0.3.7" @@ -406,6 +412,7 @@ dependencies = [ "byteorder", "criterion", "itertools 0.14.0", + "paste", "poulpy-backend", "poulpy-core", "poulpy-hal", diff --git a/poulpy-hal/src/api/scratch.rs b/poulpy-hal/src/api/scratch.rs index 051f52d..aef02e1 100644 --- a/poulpy-hal/src/api/scratch.rs +++ b/poulpy-hal/src/api/scratch.rs @@ -30,12 +30,24 @@ pub trait TakeSlice { impl Scratch where - Self: TakeSlice + ScratchFromBytes, + Self: TakeSlice + ScratchAvailable + ScratchFromBytes, { pub fn split_at_mut(&mut self, len: usize) -> (&mut Scratch, &mut Self) { let (take_slice, rem_slice) = self.take_slice(len); (Self::from_bytes(take_slice), rem_slice) } + + pub fn split_mut(&mut self, n: usize, len: usize) -> (Vec<&mut Scratch>, &mut Self) { + assert!(self.available() >= n * len); + let mut scratches: Vec<&mut Scratch> = Vec::with_capacity(n); + let mut scratch: &mut Scratch = self; + for _ in 0..n { + let (tmp, scratch_new) = scratch.split_at_mut(len); + scratch = scratch_new; + scratches.push(tmp); + } + (scratches, scratch) + } } impl ScratchTakeBasic for Scratch where Self: TakeSlice + ScratchFromBytes {} diff --git a/poulpy-schemes/Cargo.toml b/poulpy-schemes/Cargo.toml index f4058d8..bf1ec39 100644 --- a/poulpy-schemes/Cargo.toml +++ b/poulpy-schemes/Cargo.toml @@ -17,8 +17,8 @@ criterion = {workspace = true} itertools = "0.14.0" byteorder = "1.5.0" rand = "0.9.2" - +paste = "1.0.15" [[bench]] -name = "fhe_uint_prepare" +name = "circuit_bootstrapping" harness = false \ No newline at end of file diff --git a/poulpy-schemes/benches/fhe_uint_prepare.rs b/poulpy-schemes/benches/fhe_uint_prepare.rs deleted file mode 100644 index d403fc2..0000000 --- a/poulpy-schemes/benches/fhe_uint_prepare.rs +++ /dev/null @@ -1,126 +0,0 @@ -use std::hint::black_box; - -use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; -use poulpy_backend::{FFT64Avx, FFT64Ref}; -use poulpy_core::{ - GGSWNoise, GLWEDecrypt, GLWEEncryptSk, GLWENoise, ScratchTakeCore, - layouts::{GGSWLayout, GLWELayout, GLWESecretPreparedFactory, prepared::GLWESecretPrepared}, -}; -use poulpy_hal::{ - api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow}, - layouts::{Backend, Module, Scratch, ScratchOwned}, - source::Source, -}; -use rand::RngCore; - -use poulpy_schemes::tfhe::{ - bdd_arithmetic::{ - BDDKeyEncryptSk, BDDKeyPrepared, BDDKeyPreparedFactory, ExecuteBDDCircuit2WTo1W, FheUint, FheUintPrepare, - FheUintPrepareDebug, FheUintPrepared, FheUintPreparedEncryptSk, FheUintPreparedFactory, - tests::test_suite::TestContext, - }, - blind_rotation::{BlindRotationAlgo, BlindRotationKey, BlindRotationKeyFactory, CGGI}, -}; - -pub fn benc_bdd_prepare( - c: &mut Criterion, - label: &str, - test_context: &TestContext, -) where - Module: ModuleNew - + GLWESecretPreparedFactory - + GLWEDecrypt - + GLWENoise - + FheUintPreparedFactory - + FheUintPreparedEncryptSk - + FheUintPrepareDebug - + BDDKeyEncryptSk - + BDDKeyPreparedFactory - + GGSWNoise - + FheUintPrepare - + ExecuteBDDCircuit2WTo1W - + GLWEEncryptSk, - BlindRotationKey, BRA>: BlindRotationKeyFactory, - ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, - Scratch: ScratchTakeCore, -{ - let group_name: String = format!("bdd_prepare::{label}"); - - let mut group = c.benchmark_group(group_name); - - fn runner(test_context: &TestContext) -> impl FnMut() - where - Module: ModuleNew - + GLWESecretPreparedFactory - + GLWEDecrypt - + GLWENoise - + FheUintPreparedFactory - + FheUintPreparedEncryptSk - + FheUintPrepareDebug - + BDDKeyEncryptSk - + BDDKeyPreparedFactory - + GGSWNoise - + FheUintPrepare - + ExecuteBDDCircuit2WTo1W - + GLWEEncryptSk, - BlindRotationKey, BRA>: BlindRotationKeyFactory, - ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, - Scratch: ScratchTakeCore, - { - let glwe_infos: GLWELayout = test_context.glwe_infos(); - let ggsw_infos: GGSWLayout = test_context.ggsw_infos(); - - let module: &Module = &test_context.module; - let sk_glwe_prep: &GLWESecretPrepared, BE> = &test_context.sk_glwe; - let bdd_key_prepared: &BDDKeyPrepared, BRA, BE> = &test_context.bdd_key; - - let mut source: Source = Source::new([6u8; 32]); - - let mut source_xa: Source = Source::new([2u8; 32]); - let mut source_xe: Source = Source::new([3u8; 32]); - - let threads = 1; - - let mut scratch: ScratchOwned = ScratchOwned::alloc((1 << 22) * threads); - - // GLWE(value) - let mut c_enc: FheUint, u32> = FheUint::alloc_from_infos(&glwe_infos); - let value: u32 = source.next_u32(); - c_enc.encrypt_sk( - module, - value, - sk_glwe_prep, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - // GGSW(0) - let mut c_enc_prep: FheUintPrepared, u32, BE> = - FheUintPrepared::, u32, BE>::alloc_from_infos(module, &ggsw_infos); - - // GGSW(value) - move || { - c_enc_prep.prepare_custom_multi_thread(threads, module, &c_enc, 0, 32, bdd_key_prepared, scratch.borrow()); - black_box(()); - } - } - - let id: BenchmarkId = BenchmarkId::from_parameter(format!("n_glwe: {}", test_context.module.n())); - let mut runner = runner::(test_context); - group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); - - group.finish(); -} - -fn bench_bdd_prepare_cpu_ref_fft64(c: &mut Criterion) { - benc_bdd_prepare::( - c, - "bdd_prepare_fft64_ref", - &TestContext::::new(), - ); -} - -criterion_group!(benches, bench_bdd_prepare_cpu_ref_fft64,); - -criterion_main!(benches); diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs index b56f907..6cd65b1 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs @@ -13,17 +13,14 @@ use crate::tfhe::bdd_arithmetic::{ BitSize, ExecuteBDDCircuit, FheUint, FheUintPrepared, GetBitCircuitInfo, GetGGSWBit, UnsignedInteger, circuits, }; -impl ExecuteBDDCircuit2WTo1W for Module where - Self: Sized + ExecuteBDDCircuit + GLWEPacking + GLWECopy -{ -} +impl ExecuteBDDCircuit2WTo1W for Module where Self: Sized + ExecuteBDDCircuit + GLWEPacking + GLWECopy +{} -pub trait ExecuteBDDCircuit2WTo1W +pub trait ExecuteBDDCircuit2WTo1W where - Self: Sized + ModuleLogN + ExecuteBDDCircuit + GLWEPacking + GLWECopy, + Self: Sized + ModuleLogN + ExecuteBDDCircuit + GLWEPacking + GLWECopy, { - /// Operations Z x Z -> Z - fn execute_bdd_circuit_2w_to_1w( + fn execute_bdd_circuit_2w_to_1w( &self, out: &mut FheUint, circuit: &C, @@ -32,6 +29,31 @@ where key: &H, scratch: &mut Scratch, ) where + T: UnsignedInteger, + C: GetBitCircuitInfo, + R: DataMut, + A: DataRef, + B: DataRef, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + H: GLWEAutomorphismKeyHelper, + Scratch: ScratchTakeCore, + { + self.execute_bdd_circuit_2w_to_1w_multi_thread(1, out, circuit, a, b, key, scratch); + } + + #[allow(clippy::too_many_arguments)] + /// Operations Z x Z -> Z + fn execute_bdd_circuit_2w_to_1w_multi_thread( + &self, + threads: usize, + out: &mut FheUint, + circuit: &C, + a: &FheUintPrepared, + b: &FheUintPrepared, + key: &H, + scratch: &mut Scratch, + ) where + T: UnsignedInteger, C: GetBitCircuitInfo, R: DataMut, A: DataRef, @@ -50,7 +72,7 @@ where let (mut out_bits, scratch_1) = scratch.take_glwe_slice(T::BITS as usize, out); // Evaluates out[i] = circuit[i](a, b) - self.execute_bdd_circuit(&mut out_bits, &helper, circuit, scratch_1); + self.execute_bdd_circuit_multi_thread(threads, &mut out_bits, &helper, circuit, scratch_1); // Repacks the bits out.pack(self, out_bits, key, scratch_1); @@ -100,22 +122,43 @@ where #[macro_export] macro_rules! define_bdd_2w_to_1w_trait { ($(#[$meta:meta])* $vis:vis $trait_name:ident, $method_name:ident) => { - $(#[$meta])* - $vis trait $trait_name { - fn $method_name( - &mut self, - module: &M, - a: &FheUintPrepared, - b: &FheUintPrepared, - key: &H, - scratch: &mut Scratch, - ) where - M: ExecuteBDDCircuit2WTo1W, - A: DataRef, - B: DataRef, - K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, - H: GLWEAutomorphismKeyHelper, - Scratch: ScratchTakeCore; + paste::paste! { + $(#[$meta])* + $vis trait $trait_name { + + /// Single-threaded version + fn $method_name( + &mut self, + module: &M, + a: &FheUintPrepared, + b: &FheUintPrepared, + key: &H, + scratch: &mut Scratch, + ) where + M: ExecuteBDDCircuit2WTo1W, + A: DataRef, + B: DataRef, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + H: GLWEAutomorphismKeyHelper, + Scratch: ScratchTakeCore; + + /// Multithreaded version – same vis, method_name + "_multi_thread" + fn [<$method_name _multi_thread>]( + &mut self, + threads: usize, + module: &M, + a: &FheUintPrepared, + b: &FheUintPrepared, + key: &H, + scratch: &mut Scratch, + ) where + M: ExecuteBDDCircuit2WTo1W, + A: DataRef, + B: DataRef, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + H: GLWEAutomorphismKeyHelper, + Scratch: ScratchTakeCore; + } } }; } @@ -123,23 +166,45 @@ macro_rules! define_bdd_2w_to_1w_trait { #[macro_export] macro_rules! impl_bdd_2w_to_1w_trait { ($trait_name:ident, $method_name:ident, $ty:ty, $circuit_ty:ty, $output_circuits:path) => { - impl $trait_name<$ty, BE> for FheUint { - fn $method_name( - &mut self, - module: &M, - a: &FheUintPrepared, - b: &FheUintPrepared, - key: &H, - scratch: &mut Scratch, - ) where - M: ExecuteBDDCircuit2WTo1W<$ty, BE>, - A: DataRef, - B: DataRef, - K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, - H: GLWEAutomorphismKeyHelper, - Scratch: ScratchTakeCore, - { - module.execute_bdd_circuit_2w_to_1w(self, &$output_circuits, a, b, key, scratch) + paste::paste! { + impl $trait_name<$ty, BE> for FheUint { + + fn $method_name( + &mut self, + module: &M, + a: &FheUintPrepared, + b: &FheUintPrepared, + key: &H, + scratch: &mut Scratch, + ) where + M: ExecuteBDDCircuit2WTo1W, + A: DataRef, + B: DataRef, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + H: GLWEAutomorphismKeyHelper, + Scratch: ScratchTakeCore, + { + module.execute_bdd_circuit_2w_to_1w(self, &$output_circuits, a, b, key, scratch) + } + + fn [<$method_name _multi_thread>]( + &mut self, + threads: usize, + module: &M, + a: &FheUintPrepared, + b: &FheUintPrepared, + key: &H, + scratch: &mut Scratch, + ) where + M: ExecuteBDDCircuit2WTo1W, + A: DataRef, + B: DataRef, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + H: GLWEAutomorphismKeyHelper, + Scratch: ScratchTakeCore, + { + module.execute_bdd_circuit_2w_to_1w_multi_thread(threads, self, &$output_circuits, a, b, key, scratch) + } } } }; diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs index 07d6c46..c4a232f 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs @@ -33,7 +33,7 @@ pub struct FheUintPrepared { impl FheUintPreparedFactory for Module where Self: Sized + GGSWPreparedFactory {} -pub trait GetGGSWBit { +pub trait GetGGSWBit: Sync { fn get_bit(&self, bit: usize) -> GGSWPrepared<&[u8], BE>; } @@ -222,7 +222,14 @@ impl BDDKeyPrepared } pub trait FheUintPrepare { - fn fhe_uint_prepare_tmp_bytes(&self, block_size: usize, extension_factor: usize, res_infos: &R, bits_infos: &A, bdd_infos: &B) -> usize + fn fhe_uint_prepare_tmp_bytes( + &self, + block_size: usize, + extension_factor: usize, + res_infos: &R, + bits_infos: &A, + bdd_infos: &B, + ) -> usize where R: GGSWInfos, A: GLWEInfos, @@ -258,6 +265,7 @@ pub trait FheUintPrepare { { self.fhe_uint_prepare_custom_multi_thread(1, res, bits, bit_start, bit_count, key, scratch) } + #[allow(clippy::too_many_arguments)] fn fhe_uint_prepare_custom_multi_thread( &self, threads: usize, @@ -279,7 +287,14 @@ where Self: LWEFromGLWE + CirtuitBootstrappingExecute + GGSWPreparedFactory, Scratch: ScratchTakeCore, { - fn fhe_uint_prepare_tmp_bytes(&self, block_size: usize, extension_factor: usize, res_infos: &R, bits_infos: &A, bdd_infos: &B) -> usize + fn fhe_uint_prepare_tmp_bytes( + &self, + block_size: usize, + extension_factor: usize, + res_infos: &R, + bits_infos: &A, + bdd_infos: &B, + ) -> usize where R: GGSWInfos, A: GLWEInfos, @@ -302,7 +317,7 @@ where bit_start: usize, bit_count: usize, key: &K, - mut scratch: &mut Scratch, + scratch: &mut Scratch, ) where DM: DataMut, DB: DataRef, @@ -318,16 +333,9 @@ where assert!(scratch.available() >= threads * scratch_thread_size); - // How many bits we need to process - let chunk_size: usize = bit_count.div_ceil(threads); // ceil division + let chunk_size: usize = bit_count.div_ceil(threads); - let mut scratches = Vec::new(); - for _ in 0..(threads - 1) { - let (tmp, scratch_new) = scratch.split_at_mut(scratch_thread_size); - scratch = scratch_new; - scratches.push(tmp); - } - scratches.push(scratch); + let (mut scratches, _) = scratch.split_mut(threads, scratch_thread_size); let ggsw_infos: &GGSWLayout = &res.ggsw_layout(); @@ -392,6 +400,7 @@ impl FheUintPrepared { module.fhe_uint_prepare_custom(self, other, bit_start, bit_end, key, scratch); } + #[allow(clippy::too_many_arguments)] pub fn prepare_custom_multi_thread( &mut self, threads: usize, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared_debug.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared_debug.rs index 182365b..d536d57 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared_debug.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared_debug.rs @@ -124,13 +124,13 @@ where DM: DataMut, DR0: DataRef, DR1: DataRef, - { - + { let (_, scratch_1) = scratch.take_ggsw(res); let (mut tmp_lwe, scratch_2) = scratch_1.take_lwe(bits); for (bit, dst) in res.bits.iter_mut().enumerate() { bits.get_bit_lwe(self, bit, &mut tmp_lwe, &key.ks, scratch_2); - key.cbt.execute_to_constant(self, dst, &tmp_lwe, 1, 1, scratch_2); + key.cbt + .execute_to_constant(self, dst, &tmp_lwe, 1, 1, scratch_2); } } } diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs index 618304f..af8c59b 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs @@ -1,4 +1,5 @@ use core::panic; +use std::thread; use itertools::Itertools; use poulpy_core::{ @@ -6,17 +7,20 @@ use poulpy_core::{ layouts::{GGSWInfos, GGSWPrepared, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef}, }; use poulpy_hal::{ - api::{ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftBytesOf}, + api::{ + ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxDftBytesOf, + }, layouts::{Backend, DataMut, Module, Scratch, VecZnxBig, ZnxZero}, }; use crate::tfhe::bdd_arithmetic::{GetGGSWBit, UnsignedInteger}; -pub trait BitCircuitInfo { +pub trait BitCircuitInfo: Sync { fn info(&self) -> (&[Node], usize); } -pub trait GetBitCircuitInfo { +pub trait GetBitCircuitInfo: Sync { fn input_size(&self) -> usize; fn output_size(&self) -> usize; fn get_circuit(&self, bit: usize) -> (&[Node], usize); @@ -49,9 +53,34 @@ where } } -pub trait ExecuteBDDCircuit { - fn execute_bdd_circuit(&self, out: &mut [GLWE], inputs: &G, circuit: &C, scratch: &mut Scratch) +pub trait ExecuteBDDCircuit { + fn execute_bdd_circuit_tmp_bytes(&self, res_infos: &R, state_size: usize, ggsw_infos: &G) -> usize where + R: GLWEInfos, + G: GGSWInfos; + + fn execute_bdd_circuit( + &self, + out: &mut [GLWE], + inputs: &G, + circuit: &C, + scratch: &mut Scratch, + ) where + G: GetGGSWBit + BitSize, + C: GetBitCircuitInfo, + O: DataMut, + { + self.execute_bdd_circuit_multi_thread(1, out, inputs, circuit, scratch); + } + + fn execute_bdd_circuit_multi_thread( + &self, + threads: usize, + out: &mut [GLWE], + inputs: &G, + circuit: &C, + scratch: &mut Scratch, + ) where G: GetGGSWBit + BitSize, C: GetBitCircuitInfo, O: DataMut; @@ -61,13 +90,27 @@ pub trait BitSize { fn bit_size(&self) -> usize; } -impl ExecuteBDDCircuit for Module +impl ExecuteBDDCircuit for Module where Self: Cmux + GLWECopy, Scratch: ScratchTakeCore, { - fn execute_bdd_circuit(&self, out: &mut [GLWE], inputs: &G, circuit: &C, scratch: &mut Scratch) + fn execute_bdd_circuit_tmp_bytes(&self, res_infos: &R, state_size: usize, ggsw_infos: &G) -> usize where + R: GLWEInfos, + G: GGSWInfos, + { + 2 * state_size * GLWE::bytes_of_from_infos(res_infos) + self.cmux_tmp_bytes(res_infos, res_infos, ggsw_infos) + } + + fn execute_bdd_circuit_multi_thread( + &self, + threads: usize, + out: &mut [GLWE], + inputs: &G, + circuit: &C, + scratch: &mut Scratch, + ) where G: GetGGSWBit + BitSize, C: GetBitCircuitInfo, O: DataMut, @@ -88,66 +131,43 @@ where ); } - for (i, out_i) in out.iter_mut().enumerate().take(circuit.output_size()) { - let (nodes, max_inter_state) = circuit.get_circuit(i); + let mut max_state_size = 0; + for i in 0..circuit.output_size() { + let (_, state_size) = circuit.get_circuit(i); + max_state_size = max_state_size.max(state_size) + } - if max_inter_state == 0 { - out_i.data_mut().zero(); - } else { - assert!(nodes.len().is_multiple_of(max_inter_state)); + let scratch_thread_size: usize = self.execute_bdd_circuit_tmp_bytes(&out[0], max_state_size, &inputs.get_bit(0)); - let (mut level, scratch_1) = scratch.take_glwe_slice(max_inter_state * 2, out_i); + assert!( + scratch.available() >= threads * scratch_thread_size, + "scratch.available(): {} < threads:{threads} * scratch_thread_size: {scratch_thread_size}", + scratch.available() + ); - level.iter_mut().for_each(|ct| ct.data_mut().zero()); + let (mut scratches, _) = scratch.split_mut(threads, scratch_thread_size); - // TODO: implement API on GLWE - level[1] - .data_mut() - .encode_coeff_i64(out_i.base2k().into(), 0, 2, 0, 1); + let chunk_size: usize = circuit.output_size().div_ceil(threads); - let mut level_ref = level.iter_mut().collect_vec(); - let (mut prev_level, mut next_level) = level_ref.split_at_mut(max_inter_state); + thread::scope(|scope| { + for (scratch_thread, out_chunk) in scratches + .iter_mut() + .zip(out[..circuit.output_size()].chunks_mut(chunk_size)) + { + // Capture chunk + thread scratch by move + scope.spawn(move || { + for (idx, out_i) in out_chunk.iter_mut().enumerate() { + let (nodes, state_size) = circuit.get_circuit(idx); - let (all_but_last, last) = nodes.split_at(nodes.len() - max_inter_state); - - for nodes_lvl in all_but_last.chunks_exact(max_inter_state) { - for (j, node) in nodes_lvl.iter().enumerate() { - match node { - Node::Cmux(in_idx, hi_idx, lo_idx) => { - self.cmux( - next_level[j], - prev_level[*hi_idx], - prev_level[*lo_idx], - &inputs.get_bit(*in_idx), - scratch_1, - ); - } - Node::Copy => self.glwe_copy(next_level[j], prev_level[j]), /* Update BDD circuits to order Cmux -> Copy -> None so that mem swap can be used */ - Node::None => {} + if state_size == 0 { + out_i.data_mut().zero(); + } else { + eval_level(self, out_i, inputs, nodes, state_size, *scratch_thread); } } - - (prev_level, next_level) = (next_level, prev_level); - } - - // Last chunck of max_inter_state Nodes is always structured as - // [CMUX, NONE, NONE, ..., NONE] - match &last[0] { - Node::Cmux(in_idx, hi_idx, lo_idx) => { - self.cmux( - out_i, - prev_level[*hi_idx], - prev_level[*lo_idx], - &inputs.get_bit(*in_idx), - scratch_1, - ); - } - _ => { - panic!("invalid last node, should be CMUX") - } - } + }); } - } + }); for out_i in out.iter_mut().skip(circuit.output_size()) { out_i.data_mut().zero(); @@ -155,6 +175,74 @@ where } } +fn eval_level( + module: &M, + res: &mut R, + inputs: &G, + nodes: &[Node], + state_size: usize, + scratch: &mut Scratch, +) where + M: Cmux + GLWECopy, + R: GLWEToMut, + G: GetGGSWBit + BitSize, + Scratch: ScratchTakeCore, +{ + assert!(nodes.len().is_multiple_of(state_size)); + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + + let (mut level, scratch_1) = scratch.take_glwe_slice(state_size * 2, res); + + level.iter_mut().for_each(|ct| ct.data_mut().zero()); + + // TODO: implement API on GLWE + level[1] + .data_mut() + .encode_coeff_i64(res.base2k().into(), 0, 2, 0, 1); + + let mut level_ref: Vec<&mut GLWE<&mut [u8]>> = level.iter_mut().collect_vec(); + let (mut prev_level, mut next_level) = level_ref.split_at_mut(state_size); + + let (all_but_last, last) = nodes.split_at(nodes.len() - state_size); + + for nodes_lvl in all_but_last.chunks_exact(state_size) { + for (j, node) in nodes_lvl.iter().enumerate() { + match node { + Node::Cmux(in_idx, hi_idx, lo_idx) => { + module.cmux( + next_level[j], + prev_level[*hi_idx], + prev_level[*lo_idx], + &inputs.get_bit(*in_idx), + scratch_1, + ); + } + Node::Copy => module.glwe_copy(next_level[j], prev_level[j]), /* Update BDD circuits to order Cmux -> Copy -> None so that mem swap can be used */ + Node::None => {} + } + } + + (prev_level, next_level) = (next_level, prev_level); + } + + // Last chunck of max_inter_state Nodes is always structured as + // [CMUX, NONE, NONE, ..., NONE] + match &last[0] { + Node::Cmux(in_idx, hi_idx, lo_idx) => { + module.cmux( + res, + prev_level[*hi_idx], + prev_level[*lo_idx], + &inputs.get_bit(*in_idx), + scratch_1, + ); + } + _ => { + panic!("invalid last node, should be CMUX") + } + } +} + impl BitCircuit { pub const fn new(nodes: [Node; N], max_inter_state: usize) -> Self { Self { diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs index dbd52cf..e3555f8 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs @@ -136,17 +136,21 @@ where pub(crate) ks: GLWEToLWEKeyPrepared, } -impl BDDKeyInfos for BDDKeyPrepared{ +impl BDDKeyInfos for BDDKeyPrepared { fn cbt_infos(&self) -> CircuitBootstrappingKeyLayout { - CircuitBootstrappingKeyLayout { layout_brk: self.cbt.brk_infos(), layout_atk: self.cbt.atk_infos(), layout_tsk: self.cbt.tsk_infos() } + CircuitBootstrappingKeyLayout { + layout_brk: self.cbt.brk_infos(), + layout_atk: self.cbt.atk_infos(), + layout_tsk: self.cbt.tsk_infos(), + } } fn ks_infos(&self) -> GLWEToLWEKeyLayout { - GLWEToLWEKeyLayout{ + GLWEToLWEKeyLayout { n: self.ks.n(), base2k: self.ks.base2k(), k: self.ks.k(), rank_in: self.ks.rank_in(), - dnum: self.ks.dnum() + dnum: self.ks.dnum(), } } } diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/add.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/add.rs index 7810dd3..d41f18b 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/add.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/add.rs @@ -31,7 +31,7 @@ where + BDDKeyPreparedFactory + GGSWNoise + FheUintPrepare - + ExecuteBDDCircuit2WTo1W + + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/and.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/and.rs index ee468f6..0c74fe9 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/and.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/and.rs @@ -31,7 +31,7 @@ where + BDDKeyPreparedFactory + GGSWNoise + FheUintPrepare - + ExecuteBDDCircuit2WTo1W + + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs index 488967a..08de35a 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs @@ -133,8 +133,8 @@ impl TestContext { } } -pub(crate) const TEST_N_GLWE: u32 = 1024; -pub(crate) const TEST_N_LWE: u32 = 574; +pub(crate) const TEST_N_GLWE: u32 = 256; +pub(crate) const TEST_N_LWE: u32 = 77; pub(crate) const TEST_BASE2K: u32 = 13; pub(crate) const TEST_K_GLWE: u32 = 26; pub(crate) const TEST_K_GGSW: u32 = 39; diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/or.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/or.rs index 91ce50f..2564ae7 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/or.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/or.rs @@ -31,7 +31,7 @@ where + BDDKeyPreparedFactory + GGSWNoise + FheUintPrepare - + ExecuteBDDCircuit2WTo1W + + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/prepare.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/prepare.rs index b7ed299..c99dd2a 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/prepare.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/prepare.rs @@ -31,7 +31,7 @@ where + BDDKeyPreparedFactory + GGSWNoise + FheUintPrepare - + ExecuteBDDCircuit2WTo1W + + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sll.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sll.rs index b37d816..b20ad5a 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sll.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sll.rs @@ -31,7 +31,7 @@ where + BDDKeyPreparedFactory + GGSWNoise + FheUintPrepare - + ExecuteBDDCircuit2WTo1W + + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/slt.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/slt.rs index dca9f6f..6e55738 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/slt.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/slt.rs @@ -31,7 +31,7 @@ where + BDDKeyPreparedFactory + GGSWNoise + FheUintPrepare - + ExecuteBDDCircuit2WTo1W + + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sltu.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sltu.rs index aa7317d..4195fbe 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sltu.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sltu.rs @@ -31,7 +31,7 @@ where + BDDKeyPreparedFactory + GGSWNoise + FheUintPrepare - + ExecuteBDDCircuit2WTo1W + + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sra.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sra.rs index 730c539..364c880 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sra.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sra.rs @@ -31,7 +31,7 @@ where + BDDKeyPreparedFactory + GGSWNoise + FheUintPrepare - + ExecuteBDDCircuit2WTo1W + + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/srl.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/srl.rs index 61c3fb6..ce9a65d 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/srl.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/srl.rs @@ -31,7 +31,7 @@ where + BDDKeyPreparedFactory + GGSWNoise + FheUintPrepare - + ExecuteBDDCircuit2WTo1W + + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sub.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sub.rs index 3a73fbf..445279d 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sub.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sub.rs @@ -31,7 +31,7 @@ where + BDDKeyPreparedFactory + GGSWNoise + FheUintPrepare - + ExecuteBDDCircuit2WTo1W + + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/xor.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/xor.rs index 40c05b8..8db9fcd 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/xor.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/xor.rs @@ -31,7 +31,7 @@ where + BDDKeyPreparedFactory + GGSWNoise + FheUintPrepare - + ExecuteBDDCircuit2WTo1W + + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs index 17371b1..62b706a 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs @@ -8,7 +8,8 @@ use poulpy_hal::{ use poulpy_core::{ GGSWFromGGLWE, GLWEDecrypt, GLWEPacking, GLWERotate, GLWETrace, ScratchTakeCore, layouts::{ - Dsize, GGLWE, GGLWEInfos, GGLWELayout, GGLWEPreparedToRef, GGSWInfos, GGSWToMut, GLWEAutomorphismKeyHelper, GLWEInfos, GLWESecretPreparedFactory, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos, LWEToRef, Rank + Dsize, GGLWE, GGLWEInfos, GGLWELayout, GGLWEPreparedToRef, GGSWInfos, GGSWToMut, GLWEAutomorphismKeyHelper, GLWEInfos, + GLWESecretPreparedFactory, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos, LWEToRef, Rank, }, }; @@ -131,14 +132,13 @@ where R: GGSWInfos, A: CircuitBootstrappingKeyInfos, { - let gglwe_infos: GGLWELayout = GGLWELayout { n: res_infos.n(), base2k: res_infos.base2k(), k: res_infos.k(), dnum: res_infos.dnum(), dsize: Dsize(1), - rank_in: res_infos.rank().max(Rank(1)).into(), + rank_in: res_infos.rank().max(Rank(1)), rank_out: res_infos.rank(), }; @@ -149,7 +149,9 @@ where &cbt_infos.brk_infos(), ) .max(self.glwe_trace_tmp_bytes(res_infos, res_infos, &cbt_infos.atk_infos())) - .max(self.ggsw_from_gglwe_tmp_bytes(res_infos, &cbt_infos.tsk_infos())) + GLWE::bytes_of_from_infos(res_infos) + GGLWE::bytes_of_from_infos(&gglwe_infos) + .max(self.ggsw_from_gglwe_tmp_bytes(res_infos, &cbt_infos.tsk_infos())) + + GLWE::bytes_of_from_infos(res_infos) + + GGLWE::bytes_of_from_infos(&gglwe_infos) } fn circuit_bootstrapping_execute_to_constant( @@ -164,9 +166,10 @@ where R: GGSWToMut + GGSWInfos, L: LWEToRef + LWEInfos, D: DataRef, - { - - assert!(scratch.available() >= self.circuit_bootstrapping_execute_tmp_bytes(key.block_size(), extension_factor, res, key)); + { + assert!( + scratch.available() >= self.circuit_bootstrapping_execute_tmp_bytes(key.block_size(), extension_factor, res, key) + ); circuit_bootstrap_core( false, @@ -194,9 +197,10 @@ where R: GGSWToMut + GGSWInfos, L: LWEToRef + LWEInfos, D: DataRef, - { - - assert!(scratch.available() >= self.circuit_bootstrapping_execute_tmp_bytes(key.block_size(), extension_factor, res, key)); + { + assert!( + scratch.available() >= self.circuit_bootstrapping_execute_tmp_bytes(key.block_size(), extension_factor, res, key) + ); circuit_bootstrap_core( true, @@ -239,7 +243,7 @@ pub fn circuit_bootstrap_core( + ModuleLogN, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchTakeCore, -{ +{ let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); let lwe: &LWE<&[u8]> = &lwe.to_ref(); From 33e1656368c92b40626a4447ae858f4fdcff2e21 Mon Sep 17 00:00:00 2001 From: Pro7ech Date: Wed, 12 Nov 2025 15:08:21 +0100 Subject: [PATCH 3/3] Remove T from GetBitCircuit --- .../src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs | 4 +-- .../src/tfhe/bdd_arithmetic/eval.rs | 25 ++++++++----------- 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs index 6cd65b1..f167657 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs @@ -30,7 +30,7 @@ where scratch: &mut Scratch, ) where T: UnsignedInteger, - C: GetBitCircuitInfo, + C: GetBitCircuitInfo, R: DataMut, A: DataRef, B: DataRef, @@ -54,7 +54,7 @@ where scratch: &mut Scratch, ) where T: UnsignedInteger, - C: GetBitCircuitInfo, + C: GetBitCircuitInfo, R: DataMut, A: DataRef, B: DataRef, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs index af8c59b..99eac90 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs @@ -14,13 +14,13 @@ use poulpy_hal::{ layouts::{Backend, DataMut, Module, Scratch, VecZnxBig, ZnxZero}, }; -use crate::tfhe::bdd_arithmetic::{GetGGSWBit, UnsignedInteger}; +use crate::tfhe::bdd_arithmetic::GetGGSWBit; pub trait BitCircuitInfo: Sync { fn info(&self) -> (&[Node], usize); } -pub trait GetBitCircuitInfo: Sync { +pub trait GetBitCircuitInfo: Sync { fn input_size(&self) -> usize; fn output_size(&self) -> usize; fn get_circuit(&self, bit: usize) -> (&[Node], usize); @@ -38,7 +38,7 @@ pub trait BitCircuitFamily { pub struct Circuit(pub [C; N]); -impl GetBitCircuitInfo for Circuit +impl GetBitCircuitInfo for Circuit where C: BitCircuitInfo + BitCircuitFamily, { @@ -59,21 +59,16 @@ pub trait ExecuteBDDCircuit { R: GLWEInfos, G: GGSWInfos; - fn execute_bdd_circuit( - &self, - out: &mut [GLWE], - inputs: &G, - circuit: &C, - scratch: &mut Scratch, - ) where + fn execute_bdd_circuit(&self, out: &mut [GLWE], inputs: &G, circuit: &C, scratch: &mut Scratch) + where G: GetGGSWBit + BitSize, - C: GetBitCircuitInfo, + C: GetBitCircuitInfo, O: DataMut, { self.execute_bdd_circuit_multi_thread(1, out, inputs, circuit, scratch); } - fn execute_bdd_circuit_multi_thread( + fn execute_bdd_circuit_multi_thread( &self, threads: usize, out: &mut [GLWE], @@ -82,7 +77,7 @@ pub trait ExecuteBDDCircuit { scratch: &mut Scratch, ) where G: GetGGSWBit + BitSize, - C: GetBitCircuitInfo, + C: GetBitCircuitInfo, O: DataMut; } @@ -103,7 +98,7 @@ where 2 * state_size * GLWE::bytes_of_from_infos(res_infos) + self.cmux_tmp_bytes(res_infos, res_infos, ggsw_infos) } - fn execute_bdd_circuit_multi_thread( + fn execute_bdd_circuit_multi_thread( &self, threads: usize, out: &mut [GLWE], @@ -112,7 +107,7 @@ where scratch: &mut Scratch, ) where G: GetGGSWBit + BitSize, - C: GetBitCircuitInfo, + C: GetBitCircuitInfo, O: DataMut, { #[cfg(debug_assertions)]