From 1423de1c46bc3d81099f75c5e8473ebe185b06e0 Mon Sep 17 00:00:00 2001 From: Pro7ech Date: Wed, 12 Nov 2025 11:02:37 +0100 Subject: [PATCH] Add multi-thread bdd eval --- Cargo.lock | 7 + poulpy-hal/src/api/scratch.rs | 14 +- poulpy-schemes/Cargo.toml | 4 +- poulpy-schemes/benches/fhe_uint_prepare.rs | 126 ----------- .../src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs | 149 +++++++++---- .../ciphertexts/fhe_uint_prepared.rs | 35 +-- .../ciphertexts/fhe_uint_prepared_debug.rs | 6 +- .../src/tfhe/bdd_arithmetic/eval.rs | 204 +++++++++++++----- poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs | 12 +- .../bdd_arithmetic/tests/test_suite/add.rs | 2 +- .../bdd_arithmetic/tests/test_suite/and.rs | 2 +- .../bdd_arithmetic/tests/test_suite/mod.rs | 4 +- .../bdd_arithmetic/tests/test_suite/or.rs | 2 +- .../tests/test_suite/prepare.rs | 2 +- .../bdd_arithmetic/tests/test_suite/sll.rs | 2 +- .../bdd_arithmetic/tests/test_suite/slt.rs | 2 +- .../bdd_arithmetic/tests/test_suite/sltu.rs | 2 +- .../bdd_arithmetic/tests/test_suite/sra.rs | 2 +- .../bdd_arithmetic/tests/test_suite/srl.rs | 2 +- .../bdd_arithmetic/tests/test_suite/sub.rs | 2 +- .../bdd_arithmetic/tests/test_suite/xor.rs | 2 +- .../src/tfhe/circuit_bootstrapping/circuit.rs | 26 ++- 22 files changed, 336 insertions(+), 273 deletions(-) delete mode 100644 poulpy-schemes/benches/fhe_uint_prepare.rs diff --git a/Cargo.lock b/Cargo.lock index d0d2744..ebc0f43 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -323,6 +323,12 @@ version = "11.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + [[package]] name = "plotters" version = "0.3.7" @@ -406,6 +412,7 @@ dependencies = [ "byteorder", "criterion", "itertools 0.14.0", + "paste", "poulpy-backend", "poulpy-core", "poulpy-hal", diff --git a/poulpy-hal/src/api/scratch.rs b/poulpy-hal/src/api/scratch.rs index 051f52d..aef02e1 100644 --- a/poulpy-hal/src/api/scratch.rs +++ b/poulpy-hal/src/api/scratch.rs @@ -30,12 +30,24 @@ pub trait TakeSlice { impl Scratch where - Self: TakeSlice + ScratchFromBytes, + Self: TakeSlice + ScratchAvailable + ScratchFromBytes, { pub fn split_at_mut(&mut self, len: usize) -> (&mut Scratch, &mut Self) { let (take_slice, rem_slice) = self.take_slice(len); (Self::from_bytes(take_slice), rem_slice) } + + pub fn split_mut(&mut self, n: usize, len: usize) -> (Vec<&mut Scratch>, &mut Self) { + assert!(self.available() >= n * len); + let mut scratches: Vec<&mut Scratch> = Vec::with_capacity(n); + let mut scratch: &mut Scratch = self; + for _ in 0..n { + let (tmp, scratch_new) = scratch.split_at_mut(len); + scratch = scratch_new; + scratches.push(tmp); + } + (scratches, scratch) + } } impl ScratchTakeBasic for Scratch where Self: TakeSlice + ScratchFromBytes {} diff --git a/poulpy-schemes/Cargo.toml b/poulpy-schemes/Cargo.toml index f4058d8..bf1ec39 100644 --- a/poulpy-schemes/Cargo.toml +++ b/poulpy-schemes/Cargo.toml @@ -17,8 +17,8 @@ criterion = {workspace = true} itertools = "0.14.0" byteorder = "1.5.0" rand = "0.9.2" - +paste = "1.0.15" [[bench]] -name = "fhe_uint_prepare" +name = "circuit_bootstrapping" harness = false \ No newline at end of file diff --git a/poulpy-schemes/benches/fhe_uint_prepare.rs b/poulpy-schemes/benches/fhe_uint_prepare.rs deleted file mode 100644 index d403fc2..0000000 --- a/poulpy-schemes/benches/fhe_uint_prepare.rs +++ /dev/null @@ -1,126 +0,0 @@ -use std::hint::black_box; - -use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; -use poulpy_backend::{FFT64Avx, FFT64Ref}; -use poulpy_core::{ - GGSWNoise, GLWEDecrypt, GLWEEncryptSk, GLWENoise, ScratchTakeCore, - layouts::{GGSWLayout, GLWELayout, GLWESecretPreparedFactory, prepared::GLWESecretPrepared}, -}; -use poulpy_hal::{ - api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow}, - layouts::{Backend, Module, Scratch, ScratchOwned}, - source::Source, -}; -use rand::RngCore; - -use poulpy_schemes::tfhe::{ - bdd_arithmetic::{ - BDDKeyEncryptSk, BDDKeyPrepared, BDDKeyPreparedFactory, ExecuteBDDCircuit2WTo1W, FheUint, FheUintPrepare, - FheUintPrepareDebug, FheUintPrepared, FheUintPreparedEncryptSk, FheUintPreparedFactory, - tests::test_suite::TestContext, - }, - blind_rotation::{BlindRotationAlgo, BlindRotationKey, BlindRotationKeyFactory, CGGI}, -}; - -pub fn benc_bdd_prepare( - c: &mut Criterion, - label: &str, - test_context: &TestContext, -) where - Module: ModuleNew - + GLWESecretPreparedFactory - + GLWEDecrypt - + GLWENoise - + FheUintPreparedFactory - + FheUintPreparedEncryptSk - + FheUintPrepareDebug - + BDDKeyEncryptSk - + BDDKeyPreparedFactory - + GGSWNoise - + FheUintPrepare - + ExecuteBDDCircuit2WTo1W - + GLWEEncryptSk, - BlindRotationKey, BRA>: BlindRotationKeyFactory, - ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, - Scratch: ScratchTakeCore, -{ - let group_name: String = format!("bdd_prepare::{label}"); - - let mut group = c.benchmark_group(group_name); - - fn runner(test_context: &TestContext) -> impl FnMut() - where - Module: ModuleNew - + GLWESecretPreparedFactory - + GLWEDecrypt - + GLWENoise - + FheUintPreparedFactory - + FheUintPreparedEncryptSk - + FheUintPrepareDebug - + BDDKeyEncryptSk - + BDDKeyPreparedFactory - + GGSWNoise - + FheUintPrepare - + ExecuteBDDCircuit2WTo1W - + GLWEEncryptSk, - BlindRotationKey, BRA>: BlindRotationKeyFactory, - ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, - Scratch: ScratchTakeCore, - { - let glwe_infos: GLWELayout = test_context.glwe_infos(); - let ggsw_infos: GGSWLayout = test_context.ggsw_infos(); - - let module: &Module = &test_context.module; - let sk_glwe_prep: &GLWESecretPrepared, BE> = &test_context.sk_glwe; - let bdd_key_prepared: &BDDKeyPrepared, BRA, BE> = &test_context.bdd_key; - - let mut source: Source = Source::new([6u8; 32]); - - let mut source_xa: Source = Source::new([2u8; 32]); - let mut source_xe: Source = Source::new([3u8; 32]); - - let threads = 1; - - let mut scratch: ScratchOwned = ScratchOwned::alloc((1 << 22) * threads); - - // GLWE(value) - let mut c_enc: FheUint, u32> = FheUint::alloc_from_infos(&glwe_infos); - let value: u32 = source.next_u32(); - c_enc.encrypt_sk( - module, - value, - sk_glwe_prep, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - // GGSW(0) - let mut c_enc_prep: FheUintPrepared, u32, BE> = - FheUintPrepared::, u32, BE>::alloc_from_infos(module, &ggsw_infos); - - // GGSW(value) - move || { - c_enc_prep.prepare_custom_multi_thread(threads, module, &c_enc, 0, 32, bdd_key_prepared, scratch.borrow()); - black_box(()); - } - } - - let id: BenchmarkId = BenchmarkId::from_parameter(format!("n_glwe: {}", test_context.module.n())); - let mut runner = runner::(test_context); - group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); - - group.finish(); -} - -fn bench_bdd_prepare_cpu_ref_fft64(c: &mut Criterion) { - benc_bdd_prepare::( - c, - "bdd_prepare_fft64_ref", - &TestContext::::new(), - ); -} - -criterion_group!(benches, bench_bdd_prepare_cpu_ref_fft64,); - -criterion_main!(benches); diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs index b56f907..6cd65b1 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs @@ -13,17 +13,14 @@ use crate::tfhe::bdd_arithmetic::{ BitSize, ExecuteBDDCircuit, FheUint, FheUintPrepared, GetBitCircuitInfo, GetGGSWBit, UnsignedInteger, circuits, }; -impl ExecuteBDDCircuit2WTo1W for Module where - Self: Sized + ExecuteBDDCircuit + GLWEPacking + GLWECopy -{ -} +impl ExecuteBDDCircuit2WTo1W for Module where Self: Sized + ExecuteBDDCircuit + GLWEPacking + GLWECopy +{} -pub trait ExecuteBDDCircuit2WTo1W +pub trait ExecuteBDDCircuit2WTo1W where - Self: Sized + ModuleLogN + ExecuteBDDCircuit + GLWEPacking + GLWECopy, + Self: Sized + ModuleLogN + ExecuteBDDCircuit + GLWEPacking + GLWECopy, { - /// Operations Z x Z -> Z - fn execute_bdd_circuit_2w_to_1w( + fn execute_bdd_circuit_2w_to_1w( &self, out: &mut FheUint, circuit: &C, @@ -32,6 +29,31 @@ where key: &H, scratch: &mut Scratch, ) where + T: UnsignedInteger, + C: GetBitCircuitInfo, + R: DataMut, + A: DataRef, + B: DataRef, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + H: GLWEAutomorphismKeyHelper, + Scratch: ScratchTakeCore, + { + self.execute_bdd_circuit_2w_to_1w_multi_thread(1, out, circuit, a, b, key, scratch); + } + + #[allow(clippy::too_many_arguments)] + /// Operations Z x Z -> Z + fn execute_bdd_circuit_2w_to_1w_multi_thread( + &self, + threads: usize, + out: &mut FheUint, + circuit: &C, + a: &FheUintPrepared, + b: &FheUintPrepared, + key: &H, + scratch: &mut Scratch, + ) where + T: UnsignedInteger, C: GetBitCircuitInfo, R: DataMut, A: DataRef, @@ -50,7 +72,7 @@ where let (mut out_bits, scratch_1) = scratch.take_glwe_slice(T::BITS as usize, out); // Evaluates out[i] = circuit[i](a, b) - self.execute_bdd_circuit(&mut out_bits, &helper, circuit, scratch_1); + self.execute_bdd_circuit_multi_thread(threads, &mut out_bits, &helper, circuit, scratch_1); // Repacks the bits out.pack(self, out_bits, key, scratch_1); @@ -100,22 +122,43 @@ where #[macro_export] macro_rules! define_bdd_2w_to_1w_trait { ($(#[$meta:meta])* $vis:vis $trait_name:ident, $method_name:ident) => { - $(#[$meta])* - $vis trait $trait_name { - fn $method_name( - &mut self, - module: &M, - a: &FheUintPrepared, - b: &FheUintPrepared, - key: &H, - scratch: &mut Scratch, - ) where - M: ExecuteBDDCircuit2WTo1W, - A: DataRef, - B: DataRef, - K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, - H: GLWEAutomorphismKeyHelper, - Scratch: ScratchTakeCore; + paste::paste! { + $(#[$meta])* + $vis trait $trait_name { + + /// Single-threaded version + fn $method_name( + &mut self, + module: &M, + a: &FheUintPrepared, + b: &FheUintPrepared, + key: &H, + scratch: &mut Scratch, + ) where + M: ExecuteBDDCircuit2WTo1W, + A: DataRef, + B: DataRef, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + H: GLWEAutomorphismKeyHelper, + Scratch: ScratchTakeCore; + + /// Multithreaded version – same vis, method_name + "_multi_thread" + fn [<$method_name _multi_thread>]( + &mut self, + threads: usize, + module: &M, + a: &FheUintPrepared, + b: &FheUintPrepared, + key: &H, + scratch: &mut Scratch, + ) where + M: ExecuteBDDCircuit2WTo1W, + A: DataRef, + B: DataRef, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + H: GLWEAutomorphismKeyHelper, + Scratch: ScratchTakeCore; + } } }; } @@ -123,23 +166,45 @@ macro_rules! define_bdd_2w_to_1w_trait { #[macro_export] macro_rules! impl_bdd_2w_to_1w_trait { ($trait_name:ident, $method_name:ident, $ty:ty, $circuit_ty:ty, $output_circuits:path) => { - impl $trait_name<$ty, BE> for FheUint { - fn $method_name( - &mut self, - module: &M, - a: &FheUintPrepared, - b: &FheUintPrepared, - key: &H, - scratch: &mut Scratch, - ) where - M: ExecuteBDDCircuit2WTo1W<$ty, BE>, - A: DataRef, - B: DataRef, - K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, - H: GLWEAutomorphismKeyHelper, - Scratch: ScratchTakeCore, - { - module.execute_bdd_circuit_2w_to_1w(self, &$output_circuits, a, b, key, scratch) + paste::paste! { + impl $trait_name<$ty, BE> for FheUint { + + fn $method_name( + &mut self, + module: &M, + a: &FheUintPrepared, + b: &FheUintPrepared, + key: &H, + scratch: &mut Scratch, + ) where + M: ExecuteBDDCircuit2WTo1W, + A: DataRef, + B: DataRef, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + H: GLWEAutomorphismKeyHelper, + Scratch: ScratchTakeCore, + { + module.execute_bdd_circuit_2w_to_1w(self, &$output_circuits, a, b, key, scratch) + } + + fn [<$method_name _multi_thread>]( + &mut self, + threads: usize, + module: &M, + a: &FheUintPrepared, + b: &FheUintPrepared, + key: &H, + scratch: &mut Scratch, + ) where + M: ExecuteBDDCircuit2WTo1W, + A: DataRef, + B: DataRef, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + H: GLWEAutomorphismKeyHelper, + Scratch: ScratchTakeCore, + { + module.execute_bdd_circuit_2w_to_1w_multi_thread(threads, self, &$output_circuits, a, b, key, scratch) + } } } }; diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs index 07d6c46..c4a232f 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs @@ -33,7 +33,7 @@ pub struct FheUintPrepared { impl FheUintPreparedFactory for Module where Self: Sized + GGSWPreparedFactory {} -pub trait GetGGSWBit { +pub trait GetGGSWBit: Sync { fn get_bit(&self, bit: usize) -> GGSWPrepared<&[u8], BE>; } @@ -222,7 +222,14 @@ impl BDDKeyPrepared } pub trait FheUintPrepare { - fn fhe_uint_prepare_tmp_bytes(&self, block_size: usize, extension_factor: usize, res_infos: &R, bits_infos: &A, bdd_infos: &B) -> usize + fn fhe_uint_prepare_tmp_bytes( + &self, + block_size: usize, + extension_factor: usize, + res_infos: &R, + bits_infos: &A, + bdd_infos: &B, + ) -> usize where R: GGSWInfos, A: GLWEInfos, @@ -258,6 +265,7 @@ pub trait FheUintPrepare { { self.fhe_uint_prepare_custom_multi_thread(1, res, bits, bit_start, bit_count, key, scratch) } + #[allow(clippy::too_many_arguments)] fn fhe_uint_prepare_custom_multi_thread( &self, threads: usize, @@ -279,7 +287,14 @@ where Self: LWEFromGLWE + CirtuitBootstrappingExecute + GGSWPreparedFactory, Scratch: ScratchTakeCore, { - fn fhe_uint_prepare_tmp_bytes(&self, block_size: usize, extension_factor: usize, res_infos: &R, bits_infos: &A, bdd_infos: &B) -> usize + fn fhe_uint_prepare_tmp_bytes( + &self, + block_size: usize, + extension_factor: usize, + res_infos: &R, + bits_infos: &A, + bdd_infos: &B, + ) -> usize where R: GGSWInfos, A: GLWEInfos, @@ -302,7 +317,7 @@ where bit_start: usize, bit_count: usize, key: &K, - mut scratch: &mut Scratch, + scratch: &mut Scratch, ) where DM: DataMut, DB: DataRef, @@ -318,16 +333,9 @@ where assert!(scratch.available() >= threads * scratch_thread_size); - // How many bits we need to process - let chunk_size: usize = bit_count.div_ceil(threads); // ceil division + let chunk_size: usize = bit_count.div_ceil(threads); - let mut scratches = Vec::new(); - for _ in 0..(threads - 1) { - let (tmp, scratch_new) = scratch.split_at_mut(scratch_thread_size); - scratch = scratch_new; - scratches.push(tmp); - } - scratches.push(scratch); + let (mut scratches, _) = scratch.split_mut(threads, scratch_thread_size); let ggsw_infos: &GGSWLayout = &res.ggsw_layout(); @@ -392,6 +400,7 @@ impl FheUintPrepared { module.fhe_uint_prepare_custom(self, other, bit_start, bit_end, key, scratch); } + #[allow(clippy::too_many_arguments)] pub fn prepare_custom_multi_thread( &mut self, threads: usize, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared_debug.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared_debug.rs index 182365b..d536d57 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared_debug.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared_debug.rs @@ -124,13 +124,13 @@ where DM: DataMut, DR0: DataRef, DR1: DataRef, - { - + { let (_, scratch_1) = scratch.take_ggsw(res); let (mut tmp_lwe, scratch_2) = scratch_1.take_lwe(bits); for (bit, dst) in res.bits.iter_mut().enumerate() { bits.get_bit_lwe(self, bit, &mut tmp_lwe, &key.ks, scratch_2); - key.cbt.execute_to_constant(self, dst, &tmp_lwe, 1, 1, scratch_2); + key.cbt + .execute_to_constant(self, dst, &tmp_lwe, 1, 1, scratch_2); } } } diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs index 618304f..af8c59b 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs @@ -1,4 +1,5 @@ use core::panic; +use std::thread; use itertools::Itertools; use poulpy_core::{ @@ -6,17 +7,20 @@ use poulpy_core::{ layouts::{GGSWInfos, GGSWPrepared, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef}, }; use poulpy_hal::{ - api::{ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftBytesOf}, + api::{ + ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxDftBytesOf, + }, layouts::{Backend, DataMut, Module, Scratch, VecZnxBig, ZnxZero}, }; use crate::tfhe::bdd_arithmetic::{GetGGSWBit, UnsignedInteger}; -pub trait BitCircuitInfo { +pub trait BitCircuitInfo: Sync { fn info(&self) -> (&[Node], usize); } -pub trait GetBitCircuitInfo { +pub trait GetBitCircuitInfo: Sync { fn input_size(&self) -> usize; fn output_size(&self) -> usize; fn get_circuit(&self, bit: usize) -> (&[Node], usize); @@ -49,9 +53,34 @@ where } } -pub trait ExecuteBDDCircuit { - fn execute_bdd_circuit(&self, out: &mut [GLWE], inputs: &G, circuit: &C, scratch: &mut Scratch) +pub trait ExecuteBDDCircuit { + fn execute_bdd_circuit_tmp_bytes(&self, res_infos: &R, state_size: usize, ggsw_infos: &G) -> usize where + R: GLWEInfos, + G: GGSWInfos; + + fn execute_bdd_circuit( + &self, + out: &mut [GLWE], + inputs: &G, + circuit: &C, + scratch: &mut Scratch, + ) where + G: GetGGSWBit + BitSize, + C: GetBitCircuitInfo, + O: DataMut, + { + self.execute_bdd_circuit_multi_thread(1, out, inputs, circuit, scratch); + } + + fn execute_bdd_circuit_multi_thread( + &self, + threads: usize, + out: &mut [GLWE], + inputs: &G, + circuit: &C, + scratch: &mut Scratch, + ) where G: GetGGSWBit + BitSize, C: GetBitCircuitInfo, O: DataMut; @@ -61,13 +90,27 @@ pub trait BitSize { fn bit_size(&self) -> usize; } -impl ExecuteBDDCircuit for Module +impl ExecuteBDDCircuit for Module where Self: Cmux + GLWECopy, Scratch: ScratchTakeCore, { - fn execute_bdd_circuit(&self, out: &mut [GLWE], inputs: &G, circuit: &C, scratch: &mut Scratch) + fn execute_bdd_circuit_tmp_bytes(&self, res_infos: &R, state_size: usize, ggsw_infos: &G) -> usize where + R: GLWEInfos, + G: GGSWInfos, + { + 2 * state_size * GLWE::bytes_of_from_infos(res_infos) + self.cmux_tmp_bytes(res_infos, res_infos, ggsw_infos) + } + + fn execute_bdd_circuit_multi_thread( + &self, + threads: usize, + out: &mut [GLWE], + inputs: &G, + circuit: &C, + scratch: &mut Scratch, + ) where G: GetGGSWBit + BitSize, C: GetBitCircuitInfo, O: DataMut, @@ -88,66 +131,43 @@ where ); } - for (i, out_i) in out.iter_mut().enumerate().take(circuit.output_size()) { - let (nodes, max_inter_state) = circuit.get_circuit(i); + let mut max_state_size = 0; + for i in 0..circuit.output_size() { + let (_, state_size) = circuit.get_circuit(i); + max_state_size = max_state_size.max(state_size) + } - if max_inter_state == 0 { - out_i.data_mut().zero(); - } else { - assert!(nodes.len().is_multiple_of(max_inter_state)); + let scratch_thread_size: usize = self.execute_bdd_circuit_tmp_bytes(&out[0], max_state_size, &inputs.get_bit(0)); - let (mut level, scratch_1) = scratch.take_glwe_slice(max_inter_state * 2, out_i); + assert!( + scratch.available() >= threads * scratch_thread_size, + "scratch.available(): {} < threads:{threads} * scratch_thread_size: {scratch_thread_size}", + scratch.available() + ); - level.iter_mut().for_each(|ct| ct.data_mut().zero()); + let (mut scratches, _) = scratch.split_mut(threads, scratch_thread_size); - // TODO: implement API on GLWE - level[1] - .data_mut() - .encode_coeff_i64(out_i.base2k().into(), 0, 2, 0, 1); + let chunk_size: usize = circuit.output_size().div_ceil(threads); - let mut level_ref = level.iter_mut().collect_vec(); - let (mut prev_level, mut next_level) = level_ref.split_at_mut(max_inter_state); + thread::scope(|scope| { + for (scratch_thread, out_chunk) in scratches + .iter_mut() + .zip(out[..circuit.output_size()].chunks_mut(chunk_size)) + { + // Capture chunk + thread scratch by move + scope.spawn(move || { + for (idx, out_i) in out_chunk.iter_mut().enumerate() { + let (nodes, state_size) = circuit.get_circuit(idx); - let (all_but_last, last) = nodes.split_at(nodes.len() - max_inter_state); - - for nodes_lvl in all_but_last.chunks_exact(max_inter_state) { - for (j, node) in nodes_lvl.iter().enumerate() { - match node { - Node::Cmux(in_idx, hi_idx, lo_idx) => { - self.cmux( - next_level[j], - prev_level[*hi_idx], - prev_level[*lo_idx], - &inputs.get_bit(*in_idx), - scratch_1, - ); - } - Node::Copy => self.glwe_copy(next_level[j], prev_level[j]), /* Update BDD circuits to order Cmux -> Copy -> None so that mem swap can be used */ - Node::None => {} + if state_size == 0 { + out_i.data_mut().zero(); + } else { + eval_level(self, out_i, inputs, nodes, state_size, *scratch_thread); } } - - (prev_level, next_level) = (next_level, prev_level); - } - - // Last chunck of max_inter_state Nodes is always structured as - // [CMUX, NONE, NONE, ..., NONE] - match &last[0] { - Node::Cmux(in_idx, hi_idx, lo_idx) => { - self.cmux( - out_i, - prev_level[*hi_idx], - prev_level[*lo_idx], - &inputs.get_bit(*in_idx), - scratch_1, - ); - } - _ => { - panic!("invalid last node, should be CMUX") - } - } + }); } - } + }); for out_i in out.iter_mut().skip(circuit.output_size()) { out_i.data_mut().zero(); @@ -155,6 +175,74 @@ where } } +fn eval_level( + module: &M, + res: &mut R, + inputs: &G, + nodes: &[Node], + state_size: usize, + scratch: &mut Scratch, +) where + M: Cmux + GLWECopy, + R: GLWEToMut, + G: GetGGSWBit + BitSize, + Scratch: ScratchTakeCore, +{ + assert!(nodes.len().is_multiple_of(state_size)); + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + + let (mut level, scratch_1) = scratch.take_glwe_slice(state_size * 2, res); + + level.iter_mut().for_each(|ct| ct.data_mut().zero()); + + // TODO: implement API on GLWE + level[1] + .data_mut() + .encode_coeff_i64(res.base2k().into(), 0, 2, 0, 1); + + let mut level_ref: Vec<&mut GLWE<&mut [u8]>> = level.iter_mut().collect_vec(); + let (mut prev_level, mut next_level) = level_ref.split_at_mut(state_size); + + let (all_but_last, last) = nodes.split_at(nodes.len() - state_size); + + for nodes_lvl in all_but_last.chunks_exact(state_size) { + for (j, node) in nodes_lvl.iter().enumerate() { + match node { + Node::Cmux(in_idx, hi_idx, lo_idx) => { + module.cmux( + next_level[j], + prev_level[*hi_idx], + prev_level[*lo_idx], + &inputs.get_bit(*in_idx), + scratch_1, + ); + } + Node::Copy => module.glwe_copy(next_level[j], prev_level[j]), /* Update BDD circuits to order Cmux -> Copy -> None so that mem swap can be used */ + Node::None => {} + } + } + + (prev_level, next_level) = (next_level, prev_level); + } + + // Last chunck of max_inter_state Nodes is always structured as + // [CMUX, NONE, NONE, ..., NONE] + match &last[0] { + Node::Cmux(in_idx, hi_idx, lo_idx) => { + module.cmux( + res, + prev_level[*hi_idx], + prev_level[*lo_idx], + &inputs.get_bit(*in_idx), + scratch_1, + ); + } + _ => { + panic!("invalid last node, should be CMUX") + } + } +} + impl BitCircuit { pub const fn new(nodes: [Node; N], max_inter_state: usize) -> Self { Self { diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs index dbd52cf..e3555f8 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs @@ -136,17 +136,21 @@ where pub(crate) ks: GLWEToLWEKeyPrepared, } -impl BDDKeyInfos for BDDKeyPrepared{ +impl BDDKeyInfos for BDDKeyPrepared { fn cbt_infos(&self) -> CircuitBootstrappingKeyLayout { - CircuitBootstrappingKeyLayout { layout_brk: self.cbt.brk_infos(), layout_atk: self.cbt.atk_infos(), layout_tsk: self.cbt.tsk_infos() } + CircuitBootstrappingKeyLayout { + layout_brk: self.cbt.brk_infos(), + layout_atk: self.cbt.atk_infos(), + layout_tsk: self.cbt.tsk_infos(), + } } fn ks_infos(&self) -> GLWEToLWEKeyLayout { - GLWEToLWEKeyLayout{ + GLWEToLWEKeyLayout { n: self.ks.n(), base2k: self.ks.base2k(), k: self.ks.k(), rank_in: self.ks.rank_in(), - dnum: self.ks.dnum() + dnum: self.ks.dnum(), } } } diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/add.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/add.rs index 7810dd3..d41f18b 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/add.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/add.rs @@ -31,7 +31,7 @@ where + BDDKeyPreparedFactory + GGSWNoise + FheUintPrepare - + ExecuteBDDCircuit2WTo1W + + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/and.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/and.rs index ee468f6..0c74fe9 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/and.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/and.rs @@ -31,7 +31,7 @@ where + BDDKeyPreparedFactory + GGSWNoise + FheUintPrepare - + ExecuteBDDCircuit2WTo1W + + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs index 488967a..08de35a 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs @@ -133,8 +133,8 @@ impl TestContext { } } -pub(crate) const TEST_N_GLWE: u32 = 1024; -pub(crate) const TEST_N_LWE: u32 = 574; +pub(crate) const TEST_N_GLWE: u32 = 256; +pub(crate) const TEST_N_LWE: u32 = 77; pub(crate) const TEST_BASE2K: u32 = 13; pub(crate) const TEST_K_GLWE: u32 = 26; pub(crate) const TEST_K_GGSW: u32 = 39; diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/or.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/or.rs index 91ce50f..2564ae7 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/or.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/or.rs @@ -31,7 +31,7 @@ where + BDDKeyPreparedFactory + GGSWNoise + FheUintPrepare - + ExecuteBDDCircuit2WTo1W + + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/prepare.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/prepare.rs index b7ed299..c99dd2a 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/prepare.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/prepare.rs @@ -31,7 +31,7 @@ where + BDDKeyPreparedFactory + GGSWNoise + FheUintPrepare - + ExecuteBDDCircuit2WTo1W + + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sll.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sll.rs index b37d816..b20ad5a 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sll.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sll.rs @@ -31,7 +31,7 @@ where + BDDKeyPreparedFactory + GGSWNoise + FheUintPrepare - + ExecuteBDDCircuit2WTo1W + + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/slt.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/slt.rs index dca9f6f..6e55738 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/slt.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/slt.rs @@ -31,7 +31,7 @@ where + BDDKeyPreparedFactory + GGSWNoise + FheUintPrepare - + ExecuteBDDCircuit2WTo1W + + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sltu.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sltu.rs index aa7317d..4195fbe 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sltu.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sltu.rs @@ -31,7 +31,7 @@ where + BDDKeyPreparedFactory + GGSWNoise + FheUintPrepare - + ExecuteBDDCircuit2WTo1W + + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sra.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sra.rs index 730c539..364c880 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sra.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sra.rs @@ -31,7 +31,7 @@ where + BDDKeyPreparedFactory + GGSWNoise + FheUintPrepare - + ExecuteBDDCircuit2WTo1W + + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/srl.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/srl.rs index 61c3fb6..ce9a65d 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/srl.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/srl.rs @@ -31,7 +31,7 @@ where + BDDKeyPreparedFactory + GGSWNoise + FheUintPrepare - + ExecuteBDDCircuit2WTo1W + + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sub.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sub.rs index 3a73fbf..445279d 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sub.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/sub.rs @@ -31,7 +31,7 @@ where + BDDKeyPreparedFactory + GGSWNoise + FheUintPrepare - + ExecuteBDDCircuit2WTo1W + + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/xor.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/xor.rs index 40c05b8..8db9fcd 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/xor.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/xor.rs @@ -31,7 +31,7 @@ where + BDDKeyPreparedFactory + GGSWNoise + FheUintPrepare - + ExecuteBDDCircuit2WTo1W + + ExecuteBDDCircuit2WTo1W + GLWEEncryptSk, BlindRotationKey, BRA>: BlindRotationKeyFactory, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs index 17371b1..62b706a 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs @@ -8,7 +8,8 @@ use poulpy_hal::{ use poulpy_core::{ GGSWFromGGLWE, GLWEDecrypt, GLWEPacking, GLWERotate, GLWETrace, ScratchTakeCore, layouts::{ - Dsize, GGLWE, GGLWEInfos, GGLWELayout, GGLWEPreparedToRef, GGSWInfos, GGSWToMut, GLWEAutomorphismKeyHelper, GLWEInfos, GLWESecretPreparedFactory, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos, LWEToRef, Rank + Dsize, GGLWE, GGLWEInfos, GGLWELayout, GGLWEPreparedToRef, GGSWInfos, GGSWToMut, GLWEAutomorphismKeyHelper, GLWEInfos, + GLWESecretPreparedFactory, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos, LWEToRef, Rank, }, }; @@ -131,14 +132,13 @@ where R: GGSWInfos, A: CircuitBootstrappingKeyInfos, { - let gglwe_infos: GGLWELayout = GGLWELayout { n: res_infos.n(), base2k: res_infos.base2k(), k: res_infos.k(), dnum: res_infos.dnum(), dsize: Dsize(1), - rank_in: res_infos.rank().max(Rank(1)).into(), + rank_in: res_infos.rank().max(Rank(1)), rank_out: res_infos.rank(), }; @@ -149,7 +149,9 @@ where &cbt_infos.brk_infos(), ) .max(self.glwe_trace_tmp_bytes(res_infos, res_infos, &cbt_infos.atk_infos())) - .max(self.ggsw_from_gglwe_tmp_bytes(res_infos, &cbt_infos.tsk_infos())) + GLWE::bytes_of_from_infos(res_infos) + GGLWE::bytes_of_from_infos(&gglwe_infos) + .max(self.ggsw_from_gglwe_tmp_bytes(res_infos, &cbt_infos.tsk_infos())) + + GLWE::bytes_of_from_infos(res_infos) + + GGLWE::bytes_of_from_infos(&gglwe_infos) } fn circuit_bootstrapping_execute_to_constant( @@ -164,9 +166,10 @@ where R: GGSWToMut + GGSWInfos, L: LWEToRef + LWEInfos, D: DataRef, - { - - assert!(scratch.available() >= self.circuit_bootstrapping_execute_tmp_bytes(key.block_size(), extension_factor, res, key)); + { + assert!( + scratch.available() >= self.circuit_bootstrapping_execute_tmp_bytes(key.block_size(), extension_factor, res, key) + ); circuit_bootstrap_core( false, @@ -194,9 +197,10 @@ where R: GGSWToMut + GGSWInfos, L: LWEToRef + LWEInfos, D: DataRef, - { - - assert!(scratch.available() >= self.circuit_bootstrapping_execute_tmp_bytes(key.block_size(), extension_factor, res, key)); + { + assert!( + scratch.available() >= self.circuit_bootstrapping_execute_tmp_bytes(key.block_size(), extension_factor, res, key) + ); circuit_bootstrap_core( true, @@ -239,7 +243,7 @@ pub fn circuit_bootstrap_core( + ModuleLogN, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchTakeCore, -{ +{ let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); let lwe: &LWE<&[u8]> = &lwe.to_ref();