mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
Merge pull request #112 from phantomzone-org/bdd_multi_thread
Bdd multi thread
This commit is contained in:
7
Cargo.lock
generated
7
Cargo.lock
generated
@@ -323,6 +323,12 @@ version = "11.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9"
|
||||
|
||||
[[package]]
|
||||
name = "paste"
|
||||
version = "1.0.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
|
||||
|
||||
[[package]]
|
||||
name = "plotters"
|
||||
version = "0.3.7"
|
||||
@@ -406,6 +412,7 @@ dependencies = [
|
||||
"byteorder",
|
||||
"criterion",
|
||||
"itertools 0.14.0",
|
||||
"paste",
|
||||
"poulpy-backend",
|
||||
"poulpy-core",
|
||||
"poulpy-hal",
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use poulpy_hal::{
|
||||
api::{ModuleN, ScratchAvailable, ScratchTakeBasic, SvpPPolBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf},
|
||||
api::{ModuleN, ScratchAvailable, ScratchFromBytes, ScratchTakeBasic, SvpPPolBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf},
|
||||
layouts::{Backend, Scratch},
|
||||
};
|
||||
|
||||
@@ -7,7 +7,7 @@ use crate::{
|
||||
dist::Distribution,
|
||||
layouts::{
|
||||
Degree, GGLWE, GGLWEInfos, GGLWELayout, GGSW, GGSWInfos, GLWE, GLWEAutomorphismKey, GLWEInfos, GLWEPlaintext,
|
||||
GLWEPrepared, GLWEPublicKey, GLWESecret, GLWESecretTensor, GLWESwitchingKey, GLWETensorKey, Rank,
|
||||
GLWEPrepared, GLWEPublicKey, GLWESecret, GLWESecretTensor, GLWESwitchingKey, GLWETensorKey, LWE, LWEInfos, Rank,
|
||||
prepared::{
|
||||
GGLWEPrepared, GGSWPrepared, GLWEAutomorphismKeyPrepared, GLWEPublicKeyPrepared, GLWESecretPrepared,
|
||||
GLWESwitchingKeyPrepared, GLWETensorKeyPrepared,
|
||||
@@ -17,8 +17,23 @@ use crate::{
|
||||
|
||||
pub trait ScratchTakeCore<B: Backend>
|
||||
where
|
||||
Self: ScratchTakeBasic + ScratchAvailable,
|
||||
Self: ScratchTakeBasic + ScratchAvailable + ScratchFromBytes<B>,
|
||||
{
|
||||
fn take_lwe<A>(&mut self, infos: &A) -> (LWE<&mut [u8]>, &mut Self)
|
||||
where
|
||||
A: LWEInfos,
|
||||
{
|
||||
let (data, scratch) = self.take_zn(infos.n().into(), 1, infos.size());
|
||||
(
|
||||
LWE {
|
||||
k: infos.k(),
|
||||
base2k: infos.base2k(),
|
||||
data,
|
||||
},
|
||||
scratch,
|
||||
)
|
||||
}
|
||||
|
||||
fn take_glwe<A>(&mut self, infos: &A) -> (GLWE<&mut [u8]>, &mut Self)
|
||||
where
|
||||
A: GLWEInfos,
|
||||
@@ -367,4 +382,4 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ScratchTakeCore<B> for Scratch<B> where Self: ScratchTakeBasic + ScratchAvailable {}
|
||||
impl<B: Backend> ScratchTakeCore<B> for Scratch<B> where Self: ScratchTakeBasic + ScratchAvailable + ScratchFromBytes<B> {}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::{
|
||||
api::{ModuleN, SvpPPolBytesOf, VecZnxBigBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf},
|
||||
layouts::{Backend, MatZnx, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
|
||||
layouts::{Backend, MatZnx, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, Zn},
|
||||
};
|
||||
|
||||
/// Allocates a new [crate::layouts::ScratchOwned] of `size` aligned bytes.
|
||||
@@ -28,7 +28,29 @@ pub trait TakeSlice {
|
||||
fn take_slice<T>(&mut self, len: usize) -> (&mut [T], &mut Self);
|
||||
}
|
||||
|
||||
impl<B: Backend> ScratchTakeBasic for Scratch<B> where Self: TakeSlice {}
|
||||
impl<BE: Backend> Scratch<BE>
|
||||
where
|
||||
Self: TakeSlice + ScratchAvailable + ScratchFromBytes<BE>,
|
||||
{
|
||||
pub fn split_at_mut(&mut self, len: usize) -> (&mut Scratch<BE>, &mut Self) {
|
||||
let (take_slice, rem_slice) = self.take_slice(len);
|
||||
(Self::from_bytes(take_slice), rem_slice)
|
||||
}
|
||||
|
||||
pub fn split_mut(&mut self, n: usize, len: usize) -> (Vec<&mut Scratch<BE>>, &mut Self) {
|
||||
assert!(self.available() >= n * len);
|
||||
let mut scratches: Vec<&mut Scratch<BE>> = Vec::with_capacity(n);
|
||||
let mut scratch: &mut Scratch<BE> = self;
|
||||
for _ in 0..n {
|
||||
let (tmp, scratch_new) = scratch.split_at_mut(len);
|
||||
scratch = scratch_new;
|
||||
scratches.push(tmp);
|
||||
}
|
||||
(scratches, scratch)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ScratchTakeBasic for Scratch<B> where Self: TakeSlice + ScratchFromBytes<B> {}
|
||||
|
||||
pub trait ScratchTakeBasic
|
||||
where
|
||||
@@ -47,6 +69,11 @@ where
|
||||
(SvpPPol::from_data(take_slice, module.n(), cols), rem_slice)
|
||||
}
|
||||
|
||||
fn take_zn(&mut self, n: usize, cols: usize, size: usize) -> (Zn<&mut [u8]>, &mut Self) {
|
||||
let (take_slice, rem_slice) = self.take_slice(Zn::bytes_of(n, cols, size));
|
||||
(Zn::from_data(take_slice, n, cols, size), rem_slice)
|
||||
}
|
||||
|
||||
fn take_vec_znx(&mut self, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) {
|
||||
let (take_slice, rem_slice) = self.take_slice(VecZnx::bytes_of(n, cols, size));
|
||||
(VecZnx::from_data(take_slice, n, cols, size), rem_slice)
|
||||
|
||||
@@ -28,8 +28,8 @@ pub use zn::*;
|
||||
pub use znx_base::*;
|
||||
|
||||
pub trait Data = PartialEq + Eq + Sized + Default;
|
||||
pub trait DataRef = Data + AsRef<[u8]>;
|
||||
pub trait DataMut = DataRef + AsMut<[u8]>;
|
||||
pub trait DataRef = Data + AsRef<[u8]> + Sync;
|
||||
pub trait DataMut = DataRef + AsMut<[u8]> + Send;
|
||||
|
||||
pub trait ToOwnedDeep {
|
||||
type Owned;
|
||||
|
||||
@@ -13,7 +13,7 @@ use crate::{
|
||||
};
|
||||
|
||||
#[allow(clippy::missing_safety_doc)]
|
||||
pub trait Backend: Sized {
|
||||
pub trait Backend: Sized + Sync + Send {
|
||||
type ScalarBig: Copy + Zero + Display + Debug + Pod;
|
||||
type ScalarPrep: Copy + Zero + Display + Debug + Pod;
|
||||
type Handle: 'static;
|
||||
|
||||
@@ -17,7 +17,7 @@ criterion = {workspace = true}
|
||||
itertools = "0.14.0"
|
||||
byteorder = "1.5.0"
|
||||
rand = "0.9.2"
|
||||
|
||||
paste = "1.0.15"
|
||||
|
||||
[[bench]]
|
||||
name = "circuit_bootstrapping"
|
||||
|
||||
@@ -13,17 +13,14 @@ use crate::tfhe::bdd_arithmetic::{
|
||||
BitSize, ExecuteBDDCircuit, FheUint, FheUintPrepared, GetBitCircuitInfo, GetGGSWBit, UnsignedInteger, circuits,
|
||||
};
|
||||
|
||||
impl<T: UnsignedInteger, BE: Backend> ExecuteBDDCircuit2WTo1W<T, BE> for Module<BE> where
|
||||
Self: Sized + ExecuteBDDCircuit<T, BE> + GLWEPacking<BE> + GLWECopy
|
||||
{
|
||||
}
|
||||
impl<BE: Backend> ExecuteBDDCircuit2WTo1W<BE> for Module<BE> where Self: Sized + ExecuteBDDCircuit<BE> + GLWEPacking<BE> + GLWECopy
|
||||
{}
|
||||
|
||||
pub trait ExecuteBDDCircuit2WTo1W<T: UnsignedInteger, BE: Backend>
|
||||
pub trait ExecuteBDDCircuit2WTo1W<BE: Backend>
|
||||
where
|
||||
Self: Sized + ModuleLogN + ExecuteBDDCircuit<T, BE> + GLWEPacking<BE> + GLWECopy,
|
||||
Self: Sized + ModuleLogN + ExecuteBDDCircuit<BE> + GLWEPacking<BE> + GLWECopy,
|
||||
{
|
||||
/// Operations Z x Z -> Z
|
||||
fn execute_bdd_circuit_2w_to_1w<R, C, A, B, K, H>(
|
||||
fn execute_bdd_circuit_2w_to_1w<R, C, A, B, K, H, T>(
|
||||
&self,
|
||||
out: &mut FheUint<R, T>,
|
||||
circuit: &C,
|
||||
@@ -32,7 +29,32 @@ where
|
||||
key: &H,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
C: GetBitCircuitInfo<T>,
|
||||
T: UnsignedInteger,
|
||||
C: GetBitCircuitInfo,
|
||||
R: DataMut,
|
||||
A: DataRef,
|
||||
B: DataRef,
|
||||
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
|
||||
H: GLWEAutomorphismKeyHelper<K, BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
self.execute_bdd_circuit_2w_to_1w_multi_thread(1, out, circuit, a, b, key, scratch);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// Operations Z x Z -> Z
|
||||
fn execute_bdd_circuit_2w_to_1w_multi_thread<R, C, A, B, K, H, T>(
|
||||
&self,
|
||||
threads: usize,
|
||||
out: &mut FheUint<R, T>,
|
||||
circuit: &C,
|
||||
a: &FheUintPrepared<A, T, BE>,
|
||||
b: &FheUintPrepared<B, T, BE>,
|
||||
key: &H,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
T: UnsignedInteger,
|
||||
C: GetBitCircuitInfo,
|
||||
R: DataMut,
|
||||
A: DataRef,
|
||||
B: DataRef,
|
||||
@@ -50,7 +72,7 @@ where
|
||||
let (mut out_bits, scratch_1) = scratch.take_glwe_slice(T::BITS as usize, out);
|
||||
|
||||
// Evaluates out[i] = circuit[i](a, b)
|
||||
self.execute_bdd_circuit(&mut out_bits, &helper, circuit, scratch_1);
|
||||
self.execute_bdd_circuit_multi_thread(threads, &mut out_bits, &helper, circuit, scratch_1);
|
||||
|
||||
// Repacks the bits
|
||||
out.pack(self, out_bits, key, scratch_1);
|
||||
@@ -100,22 +122,43 @@ where
|
||||
#[macro_export]
|
||||
macro_rules! define_bdd_2w_to_1w_trait {
|
||||
($(#[$meta:meta])* $vis:vis $trait_name:ident, $method_name:ident) => {
|
||||
$(#[$meta])*
|
||||
$vis trait $trait_name<T: UnsignedInteger, BE: Backend> {
|
||||
fn $method_name<A, M, K, H, B>(
|
||||
&mut self,
|
||||
module: &M,
|
||||
a: &FheUintPrepared<A, T, BE>,
|
||||
b: &FheUintPrepared<B, T, BE>,
|
||||
key: &H,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
M: ExecuteBDDCircuit2WTo1W<T, BE>,
|
||||
A: DataRef,
|
||||
B: DataRef,
|
||||
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
|
||||
H: GLWEAutomorphismKeyHelper<K, BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>;
|
||||
paste::paste! {
|
||||
$(#[$meta])*
|
||||
$vis trait $trait_name<T: UnsignedInteger, BE: Backend> {
|
||||
|
||||
/// Single-threaded version
|
||||
fn $method_name<A, M, K, H, B>(
|
||||
&mut self,
|
||||
module: &M,
|
||||
a: &FheUintPrepared<A, T, BE>,
|
||||
b: &FheUintPrepared<B, T, BE>,
|
||||
key: &H,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
M: ExecuteBDDCircuit2WTo1W<BE>,
|
||||
A: DataRef,
|
||||
B: DataRef,
|
||||
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
|
||||
H: GLWEAutomorphismKeyHelper<K, BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>;
|
||||
|
||||
/// Multithreaded version – same vis, method_name + "_multi_thread"
|
||||
fn [<$method_name _multi_thread>]<A, M, K, H, B>(
|
||||
&mut self,
|
||||
threads: usize,
|
||||
module: &M,
|
||||
a: &FheUintPrepared<A, T, BE>,
|
||||
b: &FheUintPrepared<B, T, BE>,
|
||||
key: &H,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
M: ExecuteBDDCircuit2WTo1W<BE>,
|
||||
A: DataRef,
|
||||
B: DataRef,
|
||||
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
|
||||
H: GLWEAutomorphismKeyHelper<K, BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>;
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
@@ -123,23 +166,45 @@ macro_rules! define_bdd_2w_to_1w_trait {
|
||||
#[macro_export]
|
||||
macro_rules! impl_bdd_2w_to_1w_trait {
|
||||
($trait_name:ident, $method_name:ident, $ty:ty, $circuit_ty:ty, $output_circuits:path) => {
|
||||
impl<D: DataMut, BE: Backend> $trait_name<$ty, BE> for FheUint<D, $ty> {
|
||||
fn $method_name<A, M, K, H, B>(
|
||||
&mut self,
|
||||
module: &M,
|
||||
a: &FheUintPrepared<A, $ty, BE>,
|
||||
b: &FheUintPrepared<B, $ty, BE>,
|
||||
key: &H,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
M: ExecuteBDDCircuit2WTo1W<$ty, BE>,
|
||||
A: DataRef,
|
||||
B: DataRef,
|
||||
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
|
||||
H: GLWEAutomorphismKeyHelper<K, BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
module.execute_bdd_circuit_2w_to_1w(self, &$output_circuits, a, b, key, scratch)
|
||||
paste::paste! {
|
||||
impl<D: DataMut, BE: Backend> $trait_name<$ty, BE> for FheUint<D, $ty> {
|
||||
|
||||
fn $method_name<A, M, K, H, B>(
|
||||
&mut self,
|
||||
module: &M,
|
||||
a: &FheUintPrepared<A, $ty, BE>,
|
||||
b: &FheUintPrepared<B, $ty, BE>,
|
||||
key: &H,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
M: ExecuteBDDCircuit2WTo1W<BE>,
|
||||
A: DataRef,
|
||||
B: DataRef,
|
||||
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
|
||||
H: GLWEAutomorphismKeyHelper<K, BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
module.execute_bdd_circuit_2w_to_1w(self, &$output_circuits, a, b, key, scratch)
|
||||
}
|
||||
|
||||
fn [<$method_name _multi_thread>]<A, M, K, H, B>(
|
||||
&mut self,
|
||||
threads: usize,
|
||||
module: &M,
|
||||
a: &FheUintPrepared<A, $ty, BE>,
|
||||
b: &FheUintPrepared<B, $ty, BE>,
|
||||
key: &H,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
M: ExecuteBDDCircuit2WTo1W<BE>,
|
||||
A: DataRef,
|
||||
B: DataRef,
|
||||
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
|
||||
H: GLWEAutomorphismKeyHelper<K, BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
module.execute_bdd_circuit_2w_to_1w_multi_thread(threads, self, &$output_circuits, a, b, key, scratch)
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
use std::marker::PhantomData;
|
||||
use std::thread;
|
||||
|
||||
use poulpy_core::layouts::{
|
||||
Base2K, Dnum, Dsize, GGSWInfos, GGSWPreparedFactory, GLWEInfos, LWEInfos, Rank, TorusPrecision, prepared::GGSWPrepared,
|
||||
};
|
||||
use poulpy_core::layouts::{
|
||||
GGLWEInfos, GGLWEPreparedToRef, GGSWPreparedToMut, GGSWPreparedToRef, GLWEAutomorphismKeyHelper, GetGaloisElement, LWE,
|
||||
GGLWEInfos, GGLWEPreparedToRef, GGSW, GGSWLayout, GGSWPreparedToMut, GGSWPreparedToRef, GLWEAutomorphismKeyHelper,
|
||||
GetGaloisElement, LWE,
|
||||
};
|
||||
use poulpy_core::{GLWECopy, GLWEDecrypt, GLWEPacking, LWEFromGLWE};
|
||||
|
||||
use poulpy_core::{GGSWEncryptSk, ScratchTakeCore, layouts::GLWESecretPreparedToRef};
|
||||
use poulpy_hal::api::ModuleLogN;
|
||||
use poulpy_hal::api::{ModuleLogN, ScratchAvailable, ScratchFromBytes};
|
||||
use poulpy_hal::layouts::{Backend, Data, DataRef, Module};
|
||||
|
||||
use poulpy_hal::{
|
||||
@@ -21,7 +23,7 @@ use poulpy_hal::{
|
||||
use crate::tfhe::bdd_arithmetic::{BDDKey, BDDKeyHelper, BDDKeyInfos, BDDKeyPrepared, BDDKeyPreparedFactory, FheUint, ToBits};
|
||||
use crate::tfhe::bdd_arithmetic::{Cmux, FromBits, ScratchTakeBDD, UnsignedInteger};
|
||||
use crate::tfhe::blind_rotation::BlindRotationAlgo;
|
||||
use crate::tfhe::circuit_bootstrapping::CirtuitBootstrappingExecute;
|
||||
use crate::tfhe::circuit_bootstrapping::{CircuitBootstrappingKeyInfos, CirtuitBootstrappingExecute};
|
||||
|
||||
/// A prepared FHE ciphertext encrypting the bits of an [UnsignedInteger].
|
||||
pub struct FheUintPrepared<D: Data, T: UnsignedInteger, B: Backend> {
|
||||
@@ -31,7 +33,7 @@ pub struct FheUintPrepared<D: Data, T: UnsignedInteger, B: Backend> {
|
||||
|
||||
impl<T: UnsignedInteger, BE: Backend> FheUintPreparedFactory<T, BE> for Module<BE> where Self: Sized + GGSWPreparedFactory<BE> {}
|
||||
|
||||
pub trait GetGGSWBit<BE: Backend> {
|
||||
pub trait GetGGSWBit<BE: Backend>: Sync {
|
||||
fn get_bit(&self, bit: usize) -> GGSWPrepared<&[u8], BE>;
|
||||
}
|
||||
|
||||
@@ -219,12 +221,20 @@ impl<D: DataMut, BRA: BlindRotationAlgo, BE: Backend> BDDKeyPrepared<D, BRA, BE>
|
||||
}
|
||||
}
|
||||
|
||||
pub trait FheUintPrepare<BRA: BlindRotationAlgo, T: UnsignedInteger, BE: Backend> {
|
||||
fn fhe_uint_prepare_tmp_bytes<R, A>(&self, block_size: usize, extension_factor: usize, res_infos: &R, infos: &A) -> usize
|
||||
pub trait FheUintPrepare<BRA: BlindRotationAlgo, BE: Backend> {
|
||||
fn fhe_uint_prepare_tmp_bytes<R, A, B>(
|
||||
&self,
|
||||
block_size: usize,
|
||||
extension_factor: usize,
|
||||
res_infos: &R,
|
||||
bits_infos: &A,
|
||||
bdd_infos: &B,
|
||||
) -> usize
|
||||
where
|
||||
R: GGSWInfos,
|
||||
A: BDDKeyInfos;
|
||||
fn fhe_uint_prepare<DM, DB, DK, K>(
|
||||
A: GLWEInfos,
|
||||
B: BDDKeyInfos;
|
||||
fn fhe_uint_prepare<DM, DB, DK, K, T: UnsignedInteger>(
|
||||
&self,
|
||||
res: &mut FheUintPrepared<DM, T, BE>,
|
||||
bits: &FheUint<DB, T>,
|
||||
@@ -234,79 +244,120 @@ pub trait FheUintPrepare<BRA: BlindRotationAlgo, T: UnsignedInteger, BE: Backend
|
||||
DM: DataMut,
|
||||
DB: DataRef,
|
||||
DK: DataRef,
|
||||
K: BDDKeyHelper<DK, BRA, BE>;
|
||||
fn fhe_uint_prepare_custom<DM, DB, DK, K>(
|
||||
K: BDDKeyHelper<DK, BRA, BE> + BDDKeyInfos,
|
||||
Scratch<BE>: ScratchFromBytes<BE>,
|
||||
{
|
||||
self.fhe_uint_prepare_custom(res, bits, 0, T::BITS as usize, key, scratch);
|
||||
}
|
||||
fn fhe_uint_prepare_custom<DM, DB, DK, K, T: UnsignedInteger>(
|
||||
&self,
|
||||
res: &mut FheUintPrepared<DM, T, BE>,
|
||||
bits: &FheUint<DB, T>,
|
||||
bit_start: usize,
|
||||
bit_end: usize,
|
||||
bit_count: usize,
|
||||
key: &K,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
DM: DataMut,
|
||||
DB: DataRef,
|
||||
DK: DataRef,
|
||||
K: BDDKeyHelper<DK, BRA, BE>;
|
||||
K: BDDKeyHelper<DK, BRA, BE> + BDDKeyInfos,
|
||||
{
|
||||
self.fhe_uint_prepare_custom_multi_thread(1, res, bits, bit_start, bit_count, key, scratch)
|
||||
}
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn fhe_uint_prepare_custom_multi_thread<DM, DB, DK, K, T: UnsignedInteger>(
|
||||
&self,
|
||||
threads: usize,
|
||||
res: &mut FheUintPrepared<DM, T, BE>,
|
||||
bits: &FheUint<DB, T>,
|
||||
bit_start: usize,
|
||||
bit_count: usize,
|
||||
key: &K,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
DM: DataMut,
|
||||
DB: DataRef,
|
||||
DK: DataRef,
|
||||
K: BDDKeyHelper<DK, BRA, BE> + BDDKeyInfos;
|
||||
}
|
||||
|
||||
impl<BRA: BlindRotationAlgo, BE: Backend, T: UnsignedInteger> FheUintPrepare<BRA, T, BE> for Module<BE>
|
||||
impl<BRA: BlindRotationAlgo, BE: Backend> FheUintPrepare<BRA, BE> for Module<BE>
|
||||
where
|
||||
Self: LWEFromGLWE<BE> + CirtuitBootstrappingExecute<BRA, BE> + GGSWPreparedFactory<BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
fn fhe_uint_prepare_tmp_bytes<R, A>(&self, block_size: usize, extension_factor: usize, res_infos: &R, bdd_infos: &A) -> usize
|
||||
fn fhe_uint_prepare_tmp_bytes<R, A, B>(
|
||||
&self,
|
||||
block_size: usize,
|
||||
extension_factor: usize,
|
||||
res_infos: &R,
|
||||
bits_infos: &A,
|
||||
bdd_infos: &B,
|
||||
) -> usize
|
||||
where
|
||||
R: GGSWInfos,
|
||||
A: BDDKeyInfos,
|
||||
A: GLWEInfos,
|
||||
B: BDDKeyInfos,
|
||||
{
|
||||
self.circuit_bootstrapping_execute_tmp_bytes(
|
||||
block_size,
|
||||
extension_factor,
|
||||
res_infos,
|
||||
&bdd_infos.cbt_infos(),
|
||||
)
|
||||
) + GGSW::bytes_of_from_infos(res_infos)
|
||||
+ LWE::bytes_of_from_infos(bits_infos)
|
||||
}
|
||||
|
||||
fn fhe_uint_prepare<DM, DB, DK, K>(
|
||||
&self,
|
||||
res: &mut FheUintPrepared<DM, T, BE>,
|
||||
bits: &FheUint<DB, T>,
|
||||
key: &K,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
DM: DataMut,
|
||||
DB: DataRef,
|
||||
DK: DataRef,
|
||||
K: BDDKeyHelper<DK, BRA, BE>,
|
||||
{
|
||||
self.fhe_uint_prepare_custom(res, bits, 0, T::BITS as usize, key, scratch);
|
||||
}
|
||||
|
||||
fn fhe_uint_prepare_custom<DM, DB, DK, K>(
|
||||
fn fhe_uint_prepare_custom_multi_thread<DM, DB, DK, K, T: UnsignedInteger>(
|
||||
&self,
|
||||
threads: usize,
|
||||
res: &mut FheUintPrepared<DM, T, BE>,
|
||||
bits: &FheUint<DB, T>,
|
||||
bit_start: usize,
|
||||
bit_end: usize,
|
||||
bit_count: usize,
|
||||
key: &K,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
DM: DataMut,
|
||||
DB: DataRef,
|
||||
DK: DataRef,
|
||||
K: BDDKeyHelper<DK, BRA, BE>,
|
||||
K: BDDKeyHelper<DK, BRA, BE> + BDDKeyInfos,
|
||||
{
|
||||
let bit_end = bit_start + bit_count;
|
||||
let (cbt, ks) = key.get_cbt_key();
|
||||
|
||||
let mut lwe: LWE<Vec<u8>> = LWE::alloc_from_infos(bits); //TODO: add TakeLWE
|
||||
let (mut tmp_ggsw, scratch_1) = scratch.take_ggsw(res);
|
||||
for (bit, dst) in res.bits[bit_start..bit_end].iter_mut().enumerate() {
|
||||
// TODO: set the rest of the bits to a prepared zero GGSW
|
||||
bits.get_bit_lwe(self, bit, &mut lwe, ks, scratch_1);
|
||||
cbt.execute_to_constant(self, &mut tmp_ggsw, &lwe, 1, 1, scratch_1);
|
||||
dst.prepare(self, &tmp_ggsw, scratch_1);
|
||||
}
|
||||
assert!(bit_end <= T::BITS as usize);
|
||||
|
||||
let scratch_thread_size = self.fhe_uint_prepare_tmp_bytes(cbt.block_size(), 1, res, bits, key);
|
||||
|
||||
assert!(scratch.available() >= threads * scratch_thread_size);
|
||||
|
||||
let chunk_size: usize = bit_count.div_ceil(threads);
|
||||
|
||||
let (mut scratches, _) = scratch.split_mut(threads, scratch_thread_size);
|
||||
|
||||
let ggsw_infos: &GGSWLayout = &res.ggsw_layout();
|
||||
|
||||
thread::scope(|scope| {
|
||||
for (thread_index, (scratch_thread, res_bits_chunk)) in scratches
|
||||
.iter_mut()
|
||||
.zip(res.bits[bit_start..bit_end].chunks_mut(chunk_size))
|
||||
.enumerate()
|
||||
{
|
||||
let start: usize = bit_start + thread_index * chunk_size;
|
||||
|
||||
scope.spawn(move || {
|
||||
let (mut tmp_ggsw, scratch_1) = scratch_thread.take_ggsw(ggsw_infos);
|
||||
let (mut tmp_lwe, scratch_2) = scratch_1.take_lwe(bits);
|
||||
for (local_bit, dst) in res_bits_chunk.iter_mut().enumerate() {
|
||||
bits.get_bit_lwe(self, start + local_bit, &mut tmp_lwe, ks, scratch_2);
|
||||
cbt.execute_to_constant(self, &mut tmp_ggsw, &tmp_lwe, 1, 1, scratch_2);
|
||||
dst.prepare(self, &tmp_ggsw, scratch_2);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
for i in 0..bit_start {
|
||||
res.bits[i].zero(self);
|
||||
@@ -324,8 +375,8 @@ impl<D: DataMut, T: UnsignedInteger, BE: Backend> FheUintPrepared<D, T, BE> {
|
||||
BRA: BlindRotationAlgo,
|
||||
O: DataRef,
|
||||
DK: DataRef,
|
||||
K: BDDKeyHelper<DK, BRA, BE>,
|
||||
M: FheUintPrepare<BRA, T, BE>,
|
||||
K: BDDKeyHelper<DK, BRA, BE> + BDDKeyInfos,
|
||||
M: FheUintPrepare<BRA, BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
module.fhe_uint_prepare(self, other, key, scratch);
|
||||
@@ -342,10 +393,31 @@ impl<D: DataMut, T: UnsignedInteger, BE: Backend> FheUintPrepared<D, T, BE> {
|
||||
BRA: BlindRotationAlgo,
|
||||
O: DataRef,
|
||||
DK: DataRef,
|
||||
K: BDDKeyHelper<DK, BRA, BE>,
|
||||
M: FheUintPrepare<BRA, T, BE>,
|
||||
K: BDDKeyHelper<DK, BRA, BE> + BDDKeyInfos,
|
||||
M: FheUintPrepare<BRA, BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
module.fhe_uint_prepare_custom(self, other, bit_start, bit_end, key, scratch);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn prepare_custom_multi_thread<BRA, M, O, K, DK>(
|
||||
&mut self,
|
||||
threads: usize,
|
||||
module: &M,
|
||||
other: &FheUint<O, T>,
|
||||
bit_start: usize,
|
||||
bit_end: usize,
|
||||
key: &K,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
BRA: BlindRotationAlgo,
|
||||
O: DataRef,
|
||||
DK: DataRef,
|
||||
K: BDDKeyHelper<DK, BRA, BE> + BDDKeyInfos,
|
||||
M: FheUintPrepare<BRA, BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
module.fhe_uint_prepare_custom_multi_thread(threads, self, other, bit_start, bit_end, key, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ use poulpy_core::layouts::{Base2K, Dnum, Dsize, Rank, TorusPrecision};
|
||||
use poulpy_core::layouts::{GGSW, GLWESecretPreparedToRef};
|
||||
use poulpy_core::{
|
||||
LWEFromGLWE, ScratchTakeCore,
|
||||
layouts::{GGSWInfos, GGSWPreparedFactory, GLWEInfos, LWE, LWEInfos},
|
||||
layouts::{GGSWInfos, GGSWPreparedFactory, GLWEInfos, LWEInfos},
|
||||
};
|
||||
|
||||
use poulpy_hal::api::ModuleN;
|
||||
@@ -125,10 +125,12 @@ where
|
||||
DR0: DataRef,
|
||||
DR1: DataRef,
|
||||
{
|
||||
let mut lwe: LWE<Vec<u8>> = LWE::alloc_from_infos(bits); //TODO: add TakeLWE
|
||||
let (_, scratch_1) = scratch.take_ggsw(res);
|
||||
let (mut tmp_lwe, scratch_2) = scratch_1.take_lwe(bits);
|
||||
for (bit, dst) in res.bits.iter_mut().enumerate() {
|
||||
bits.get_bit_lwe(self, bit, &mut lwe, &key.ks, scratch);
|
||||
key.cbt.execute_to_constant(self, dst, &lwe, 1, 1, scratch);
|
||||
bits.get_bit_lwe(self, bit, &mut tmp_lwe, &key.ks, scratch_2);
|
||||
key.cbt
|
||||
.execute_to_constant(self, dst, &tmp_lwe, 1, 1, scratch_2);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use core::panic;
|
||||
use std::thread;
|
||||
|
||||
use itertools::Itertools;
|
||||
use poulpy_core::{
|
||||
@@ -6,17 +7,20 @@ use poulpy_core::{
|
||||
layouts::{GGSWInfos, GGSWPrepared, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef},
|
||||
};
|
||||
use poulpy_hal::{
|
||||
api::{ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftBytesOf},
|
||||
api::{
|
||||
ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
|
||||
VecZnxDftBytesOf,
|
||||
},
|
||||
layouts::{Backend, DataMut, Module, Scratch, VecZnxBig, ZnxZero},
|
||||
};
|
||||
|
||||
use crate::tfhe::bdd_arithmetic::{GetGGSWBit, UnsignedInteger};
|
||||
use crate::tfhe::bdd_arithmetic::GetGGSWBit;
|
||||
|
||||
pub trait BitCircuitInfo {
|
||||
pub trait BitCircuitInfo: Sync {
|
||||
fn info(&self) -> (&[Node], usize);
|
||||
}
|
||||
|
||||
pub trait GetBitCircuitInfo<T: UnsignedInteger> {
|
||||
pub trait GetBitCircuitInfo: Sync {
|
||||
fn input_size(&self) -> usize;
|
||||
fn output_size(&self) -> usize;
|
||||
fn get_circuit(&self, bit: usize) -> (&[Node], usize);
|
||||
@@ -34,7 +38,7 @@ pub trait BitCircuitFamily {
|
||||
|
||||
pub struct Circuit<C: BitCircuitInfo, const N: usize>(pub [C; N]);
|
||||
|
||||
impl<C, T: UnsignedInteger, const N: usize> GetBitCircuitInfo<T> for Circuit<C, N>
|
||||
impl<C, const N: usize> GetBitCircuitInfo for Circuit<C, N>
|
||||
where
|
||||
C: BitCircuitInfo + BitCircuitFamily,
|
||||
{
|
||||
@@ -49,11 +53,31 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ExecuteBDDCircuit<T: UnsignedInteger, BE: Backend> {
|
||||
pub trait ExecuteBDDCircuit<BE: Backend> {
|
||||
fn execute_bdd_circuit_tmp_bytes<R, G>(&self, res_infos: &R, state_size: usize, ggsw_infos: &G) -> usize
|
||||
where
|
||||
R: GLWEInfos,
|
||||
G: GGSWInfos;
|
||||
|
||||
fn execute_bdd_circuit<C, G, O>(&self, out: &mut [GLWE<O>], inputs: &G, circuit: &C, scratch: &mut Scratch<BE>)
|
||||
where
|
||||
G: GetGGSWBit<BE> + BitSize,
|
||||
C: GetBitCircuitInfo<T>,
|
||||
C: GetBitCircuitInfo,
|
||||
O: DataMut,
|
||||
{
|
||||
self.execute_bdd_circuit_multi_thread(1, out, inputs, circuit, scratch);
|
||||
}
|
||||
|
||||
fn execute_bdd_circuit_multi_thread<C, G, O>(
|
||||
&self,
|
||||
threads: usize,
|
||||
out: &mut [GLWE<O>],
|
||||
inputs: &G,
|
||||
circuit: &C,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
G: GetGGSWBit<BE> + BitSize,
|
||||
C: GetBitCircuitInfo,
|
||||
O: DataMut;
|
||||
}
|
||||
|
||||
@@ -61,15 +85,29 @@ pub trait BitSize {
|
||||
fn bit_size(&self) -> usize;
|
||||
}
|
||||
|
||||
impl<T: UnsignedInteger, BE: Backend> ExecuteBDDCircuit<T, BE> for Module<BE>
|
||||
impl<BE: Backend> ExecuteBDDCircuit<BE> for Module<BE>
|
||||
where
|
||||
Self: Cmux<BE> + GLWECopy,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
fn execute_bdd_circuit<C, G, O>(&self, out: &mut [GLWE<O>], inputs: &G, circuit: &C, scratch: &mut Scratch<BE>)
|
||||
fn execute_bdd_circuit_tmp_bytes<R, G>(&self, res_infos: &R, state_size: usize, ggsw_infos: &G) -> usize
|
||||
where
|
||||
R: GLWEInfos,
|
||||
G: GGSWInfos,
|
||||
{
|
||||
2 * state_size * GLWE::bytes_of_from_infos(res_infos) + self.cmux_tmp_bytes(res_infos, res_infos, ggsw_infos)
|
||||
}
|
||||
|
||||
fn execute_bdd_circuit_multi_thread<C, G, O>(
|
||||
&self,
|
||||
threads: usize,
|
||||
out: &mut [GLWE<O>],
|
||||
inputs: &G,
|
||||
circuit: &C,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
G: GetGGSWBit<BE> + BitSize,
|
||||
C: GetBitCircuitInfo<T>,
|
||||
C: GetBitCircuitInfo,
|
||||
O: DataMut,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
@@ -88,66 +126,43 @@ where
|
||||
);
|
||||
}
|
||||
|
||||
for (i, out_i) in out.iter_mut().enumerate().take(circuit.output_size()) {
|
||||
let (nodes, max_inter_state) = circuit.get_circuit(i);
|
||||
let mut max_state_size = 0;
|
||||
for i in 0..circuit.output_size() {
|
||||
let (_, state_size) = circuit.get_circuit(i);
|
||||
max_state_size = max_state_size.max(state_size)
|
||||
}
|
||||
|
||||
if max_inter_state == 0 {
|
||||
out_i.data_mut().zero();
|
||||
} else {
|
||||
assert!(nodes.len().is_multiple_of(max_inter_state));
|
||||
let scratch_thread_size: usize = self.execute_bdd_circuit_tmp_bytes(&out[0], max_state_size, &inputs.get_bit(0));
|
||||
|
||||
let (mut level, scratch_1) = scratch.take_glwe_slice(max_inter_state * 2, out_i);
|
||||
assert!(
|
||||
scratch.available() >= threads * scratch_thread_size,
|
||||
"scratch.available(): {} < threads:{threads} * scratch_thread_size: {scratch_thread_size}",
|
||||
scratch.available()
|
||||
);
|
||||
|
||||
level.iter_mut().for_each(|ct| ct.data_mut().zero());
|
||||
let (mut scratches, _) = scratch.split_mut(threads, scratch_thread_size);
|
||||
|
||||
// TODO: implement API on GLWE
|
||||
level[1]
|
||||
.data_mut()
|
||||
.encode_coeff_i64(out_i.base2k().into(), 0, 2, 0, 1);
|
||||
let chunk_size: usize = circuit.output_size().div_ceil(threads);
|
||||
|
||||
let mut level_ref = level.iter_mut().collect_vec();
|
||||
let (mut prev_level, mut next_level) = level_ref.split_at_mut(max_inter_state);
|
||||
thread::scope(|scope| {
|
||||
for (scratch_thread, out_chunk) in scratches
|
||||
.iter_mut()
|
||||
.zip(out[..circuit.output_size()].chunks_mut(chunk_size))
|
||||
{
|
||||
// Capture chunk + thread scratch by move
|
||||
scope.spawn(move || {
|
||||
for (idx, out_i) in out_chunk.iter_mut().enumerate() {
|
||||
let (nodes, state_size) = circuit.get_circuit(idx);
|
||||
|
||||
let (all_but_last, last) = nodes.split_at(nodes.len() - max_inter_state);
|
||||
|
||||
for nodes_lvl in all_but_last.chunks_exact(max_inter_state) {
|
||||
for (j, node) in nodes_lvl.iter().enumerate() {
|
||||
match node {
|
||||
Node::Cmux(in_idx, hi_idx, lo_idx) => {
|
||||
self.cmux(
|
||||
next_level[j],
|
||||
prev_level[*hi_idx],
|
||||
prev_level[*lo_idx],
|
||||
&inputs.get_bit(*in_idx),
|
||||
scratch_1,
|
||||
);
|
||||
}
|
||||
Node::Copy => self.glwe_copy(next_level[j], prev_level[j]), /* Update BDD circuits to order Cmux -> Copy -> None so that mem swap can be used */
|
||||
Node::None => {}
|
||||
if state_size == 0 {
|
||||
out_i.data_mut().zero();
|
||||
} else {
|
||||
eval_level(self, out_i, inputs, nodes, state_size, *scratch_thread);
|
||||
}
|
||||
}
|
||||
|
||||
(prev_level, next_level) = (next_level, prev_level);
|
||||
}
|
||||
|
||||
// Last chunk of max_inter_state Nodes is always structured as
|
||||
// [CMUX, NONE, NONE, ..., NONE]
|
||||
match &last[0] {
|
||||
Node::Cmux(in_idx, hi_idx, lo_idx) => {
|
||||
self.cmux(
|
||||
out_i,
|
||||
prev_level[*hi_idx],
|
||||
prev_level[*lo_idx],
|
||||
&inputs.get_bit(*in_idx),
|
||||
scratch_1,
|
||||
);
|
||||
}
|
||||
_ => {
|
||||
panic!("invalid last node, should be CMUX")
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
for out_i in out.iter_mut().skip(circuit.output_size()) {
|
||||
out_i.data_mut().zero();
|
||||
@@ -155,6 +170,74 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
fn eval_level<M, R, G, BE: Backend>(
|
||||
module: &M,
|
||||
res: &mut R,
|
||||
inputs: &G,
|
||||
nodes: &[Node],
|
||||
state_size: usize,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
M: Cmux<BE> + GLWECopy,
|
||||
R: GLWEToMut,
|
||||
G: GetGGSWBit<BE> + BitSize,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
assert!(nodes.len().is_multiple_of(state_size));
|
||||
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
|
||||
|
||||
let (mut level, scratch_1) = scratch.take_glwe_slice(state_size * 2, res);
|
||||
|
||||
level.iter_mut().for_each(|ct| ct.data_mut().zero());
|
||||
|
||||
// TODO: implement API on GLWE
|
||||
level[1]
|
||||
.data_mut()
|
||||
.encode_coeff_i64(res.base2k().into(), 0, 2, 0, 1);
|
||||
|
||||
let mut level_ref: Vec<&mut GLWE<&mut [u8]>> = level.iter_mut().collect_vec();
|
||||
let (mut prev_level, mut next_level) = level_ref.split_at_mut(state_size);
|
||||
|
||||
let (all_but_last, last) = nodes.split_at(nodes.len() - state_size);
|
||||
|
||||
for nodes_lvl in all_but_last.chunks_exact(state_size) {
|
||||
for (j, node) in nodes_lvl.iter().enumerate() {
|
||||
match node {
|
||||
Node::Cmux(in_idx, hi_idx, lo_idx) => {
|
||||
module.cmux(
|
||||
next_level[j],
|
||||
prev_level[*hi_idx],
|
||||
prev_level[*lo_idx],
|
||||
&inputs.get_bit(*in_idx),
|
||||
scratch_1,
|
||||
);
|
||||
}
|
||||
Node::Copy => module.glwe_copy(next_level[j], prev_level[j]), /* Update BDD circuits to order Cmux -> Copy -> None so that mem swap can be used */
|
||||
Node::None => {}
|
||||
}
|
||||
}
|
||||
|
||||
(prev_level, next_level) = (next_level, prev_level);
|
||||
}
|
||||
|
||||
// Last chunck of max_inter_state Nodes is always structured as
|
||||
// [CMUX, NONE, NONE, ..., NONE]
|
||||
match &last[0] {
|
||||
Node::Cmux(in_idx, hi_idx, lo_idx) => {
|
||||
module.cmux(
|
||||
res,
|
||||
prev_level[*hi_idx],
|
||||
prev_level[*lo_idx],
|
||||
&inputs.get_bit(*in_idx),
|
||||
scratch_1,
|
||||
);
|
||||
}
|
||||
_ => {
|
||||
panic!("invalid last node, should be CMUX")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const N: usize> BitCircuit<N> {
|
||||
pub const fn new(nodes: [Node; N], max_inter_state: usize) -> Self {
|
||||
Self {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use crate::tfhe::bdd_arithmetic::FheUintPreparedDebug;
|
||||
use crate::tfhe::circuit_bootstrapping::CircuitBootstrappingKeyInfos;
|
||||
use crate::tfhe::{
|
||||
bdd_arithmetic::{FheUint, UnsignedInteger},
|
||||
blind_rotation::{BlindRotationAlgo, BlindRotationKey, BlindRotationKeyFactory},
|
||||
@@ -8,7 +9,7 @@ use crate::tfhe::{
|
||||
},
|
||||
};
|
||||
|
||||
use poulpy_core::layouts::{GLWEAutomorphismKeyHelper, GLWEAutomorphismKeyPrepared};
|
||||
use poulpy_core::layouts::{GGLWEInfos, GLWEAutomorphismKeyHelper, GLWEAutomorphismKeyPrepared};
|
||||
use poulpy_core::{
|
||||
GLWEToLWESwitchingKeyEncryptSk, GetDistribution, ScratchTakeCore,
|
||||
layouts::{
|
||||
@@ -135,6 +136,25 @@ where
|
||||
pub(crate) ks: GLWEToLWEKeyPrepared<D, BE>,
|
||||
}
|
||||
|
||||
impl<D: DataRef, BRA: BlindRotationAlgo, BE: Backend> BDDKeyInfos for BDDKeyPrepared<D, BRA, BE> {
|
||||
fn cbt_infos(&self) -> CircuitBootstrappingKeyLayout {
|
||||
CircuitBootstrappingKeyLayout {
|
||||
layout_brk: self.cbt.brk_infos(),
|
||||
layout_atk: self.cbt.atk_infos(),
|
||||
layout_tsk: self.cbt.tsk_infos(),
|
||||
}
|
||||
}
|
||||
fn ks_infos(&self) -> GLWEToLWEKeyLayout {
|
||||
GLWEToLWEKeyLayout {
|
||||
n: self.ks.n(),
|
||||
base2k: self.ks.base2k(),
|
||||
k: self.ks.k(),
|
||||
rank_in: self.ks.rank_in(),
|
||||
dnum: self.ks.dnum(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef, BRA: BlindRotationAlgo, BE: Backend> GLWEAutomorphismKeyHelper<GLWEAutomorphismKeyPrepared<D, BE>, BE>
|
||||
for BDDKeyPrepared<D, BRA, BE>
|
||||
{
|
||||
|
||||
@@ -16,7 +16,7 @@ pub use key::*;
|
||||
|
||||
pub mod tests;
|
||||
|
||||
pub trait UnsignedInteger: Copy + 'static {
|
||||
pub trait UnsignedInteger: Copy + Sync + Send + 'static {
|
||||
const BITS: u32;
|
||||
const LOG_BITS: u32;
|
||||
const LOG_BYTES: u32;
|
||||
|
||||
@@ -30,8 +30,8 @@ where
|
||||
+ BDDKeyEncryptSk<BRA, BE>
|
||||
+ BDDKeyPreparedFactory<BRA, BE>
|
||||
+ GGSWNoise<BE>
|
||||
+ FheUintPrepare<BRA, u32, BE>
|
||||
+ ExecuteBDDCircuit2WTo1W<u32, BE>
|
||||
+ FheUintPrepare<BRA, BE>
|
||||
+ ExecuteBDDCircuit2WTo1W<BE>
|
||||
+ GLWEEncryptSk<BE>,
|
||||
BlindRotationKey<Vec<u8>, BRA>: BlindRotationKeyFactory<BRA>,
|
||||
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
|
||||
|
||||
@@ -30,8 +30,8 @@ where
|
||||
+ BDDKeyEncryptSk<BRA, BE>
|
||||
+ BDDKeyPreparedFactory<BRA, BE>
|
||||
+ GGSWNoise<BE>
|
||||
+ FheUintPrepare<BRA, u32, BE>
|
||||
+ ExecuteBDDCircuit2WTo1W<u32, BE>
|
||||
+ FheUintPrepare<BRA, BE>
|
||||
+ ExecuteBDDCircuit2WTo1W<BE>
|
||||
+ GLWEEncryptSk<BE>,
|
||||
BlindRotationKey<Vec<u8>, BRA>: BlindRotationKeyFactory<BRA>,
|
||||
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
|
||||
|
||||
@@ -75,6 +75,14 @@ where
|
||||
}
|
||||
|
||||
impl<BRA: BlindRotationAlgo, BE: Backend> TestContext<BRA, BE> {
|
||||
/// Returns the GLWE layout used by the test context, i.e. the `TEST_GLWE_INFOS` constant.
pub fn glwe_infos(&self) -> GLWELayout {
|
||||
TEST_GLWE_INFOS
|
||||
}
|
||||
|
||||
/// Returns the GGSW layout used by the test context, i.e. the `TEST_GGSW_INFOS` constant.
pub fn ggsw_infos(&self) -> GGSWLayout {
|
||||
TEST_GGSW_INFOS
|
||||
}
|
||||
|
||||
pub fn new() -> Self
|
||||
where
|
||||
Module<BE>: ModuleNew<BE>
|
||||
|
||||
@@ -30,8 +30,8 @@ where
|
||||
+ BDDKeyEncryptSk<BRA, BE>
|
||||
+ BDDKeyPreparedFactory<BRA, BE>
|
||||
+ GGSWNoise<BE>
|
||||
+ FheUintPrepare<BRA, u32, BE>
|
||||
+ ExecuteBDDCircuit2WTo1W<u32, BE>
|
||||
+ FheUintPrepare<BRA, BE>
|
||||
+ ExecuteBDDCircuit2WTo1W<BE>
|
||||
+ GLWEEncryptSk<BE>,
|
||||
BlindRotationKey<Vec<u8>, BRA>: BlindRotationKeyFactory<BRA>,
|
||||
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
|
||||
|
||||
@@ -30,8 +30,8 @@ where
|
||||
+ BDDKeyEncryptSk<BRA, BE>
|
||||
+ BDDKeyPreparedFactory<BRA, BE>
|
||||
+ GGSWNoise<BE>
|
||||
+ FheUintPrepare<BRA, u32, BE>
|
||||
+ ExecuteBDDCircuit2WTo1W<u32, BE>
|
||||
+ FheUintPrepare<BRA, BE>
|
||||
+ ExecuteBDDCircuit2WTo1W<BE>
|
||||
+ GLWEEncryptSk<BE>,
|
||||
BlindRotationKey<Vec<u8>, BRA>: BlindRotationKeyFactory<BRA>,
|
||||
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
|
||||
@@ -67,8 +67,10 @@ where
|
||||
let mut c_enc_prep_debug: FheUintPreparedDebug<Vec<u8>, u32> =
|
||||
FheUintPreparedDebug::<Vec<u8>, u32>::alloc_from_infos(module, &ggsw_infos);
|
||||
|
||||
let mut scratch_2 = ScratchOwned::alloc(module.fhe_uint_prepare_tmp_bytes(7, 1, &c_enc_prep_debug, &c_enc, bdd_key_prepared));
|
||||
|
||||
// GGSW(value)
|
||||
c_enc_prep_debug.prepare(module, &c_enc, bdd_key_prepared, scratch.borrow());
|
||||
c_enc_prep_debug.prepare(module, &c_enc, bdd_key_prepared, scratch_2.borrow());
|
||||
|
||||
let max_noise = |col_i: usize| {
|
||||
let mut noise: f64 = -(ggsw_infos.size() as f64 * TEST_BASE2K as f64) + SIGMA.log2() + 2.0;
|
||||
|
||||
@@ -30,8 +30,8 @@ where
|
||||
+ BDDKeyEncryptSk<BRA, BE>
|
||||
+ BDDKeyPreparedFactory<BRA, BE>
|
||||
+ GGSWNoise<BE>
|
||||
+ FheUintPrepare<BRA, u32, BE>
|
||||
+ ExecuteBDDCircuit2WTo1W<u32, BE>
|
||||
+ FheUintPrepare<BRA, BE>
|
||||
+ ExecuteBDDCircuit2WTo1W<BE>
|
||||
+ GLWEEncryptSk<BE>,
|
||||
BlindRotationKey<Vec<u8>, BRA>: BlindRotationKeyFactory<BRA>,
|
||||
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
|
||||
|
||||
@@ -30,8 +30,8 @@ where
|
||||
+ BDDKeyEncryptSk<BRA, BE>
|
||||
+ BDDKeyPreparedFactory<BRA, BE>
|
||||
+ GGSWNoise<BE>
|
||||
+ FheUintPrepare<BRA, u32, BE>
|
||||
+ ExecuteBDDCircuit2WTo1W<u32, BE>
|
||||
+ FheUintPrepare<BRA, BE>
|
||||
+ ExecuteBDDCircuit2WTo1W<BE>
|
||||
+ GLWEEncryptSk<BE>,
|
||||
BlindRotationKey<Vec<u8>, BRA>: BlindRotationKeyFactory<BRA>,
|
||||
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
|
||||
|
||||
@@ -30,8 +30,8 @@ where
|
||||
+ BDDKeyEncryptSk<BRA, BE>
|
||||
+ BDDKeyPreparedFactory<BRA, BE>
|
||||
+ GGSWNoise<BE>
|
||||
+ FheUintPrepare<BRA, u32, BE>
|
||||
+ ExecuteBDDCircuit2WTo1W<u32, BE>
|
||||
+ FheUintPrepare<BRA, BE>
|
||||
+ ExecuteBDDCircuit2WTo1W<BE>
|
||||
+ GLWEEncryptSk<BE>,
|
||||
BlindRotationKey<Vec<u8>, BRA>: BlindRotationKeyFactory<BRA>,
|
||||
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
|
||||
|
||||
@@ -30,8 +30,8 @@ where
|
||||
+ BDDKeyEncryptSk<BRA, BE>
|
||||
+ BDDKeyPreparedFactory<BRA, BE>
|
||||
+ GGSWNoise<BE>
|
||||
+ FheUintPrepare<BRA, u32, BE>
|
||||
+ ExecuteBDDCircuit2WTo1W<u32, BE>
|
||||
+ FheUintPrepare<BRA, BE>
|
||||
+ ExecuteBDDCircuit2WTo1W<BE>
|
||||
+ GLWEEncryptSk<BE>,
|
||||
BlindRotationKey<Vec<u8>, BRA>: BlindRotationKeyFactory<BRA>,
|
||||
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
|
||||
|
||||
@@ -30,8 +30,8 @@ where
|
||||
+ BDDKeyEncryptSk<BRA, BE>
|
||||
+ BDDKeyPreparedFactory<BRA, BE>
|
||||
+ GGSWNoise<BE>
|
||||
+ FheUintPrepare<BRA, u32, BE>
|
||||
+ ExecuteBDDCircuit2WTo1W<u32, BE>
|
||||
+ FheUintPrepare<BRA, BE>
|
||||
+ ExecuteBDDCircuit2WTo1W<BE>
|
||||
+ GLWEEncryptSk<BE>,
|
||||
BlindRotationKey<Vec<u8>, BRA>: BlindRotationKeyFactory<BRA>,
|
||||
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
|
||||
|
||||
@@ -30,8 +30,8 @@ where
|
||||
+ BDDKeyEncryptSk<BRA, BE>
|
||||
+ BDDKeyPreparedFactory<BRA, BE>
|
||||
+ GGSWNoise<BE>
|
||||
+ FheUintPrepare<BRA, u32, BE>
|
||||
+ ExecuteBDDCircuit2WTo1W<u32, BE>
|
||||
+ FheUintPrepare<BRA, BE>
|
||||
+ ExecuteBDDCircuit2WTo1W<BE>
|
||||
+ GLWEEncryptSk<BE>,
|
||||
BlindRotationKey<Vec<u8>, BRA>: BlindRotationKeyFactory<BRA>,
|
||||
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
|
||||
|
||||
@@ -30,8 +30,8 @@ where
|
||||
+ BDDKeyEncryptSk<BRA, BE>
|
||||
+ BDDKeyPreparedFactory<BRA, BE>
|
||||
+ GGSWNoise<BE>
|
||||
+ FheUintPrepare<BRA, u32, BE>
|
||||
+ ExecuteBDDCircuit2WTo1W<u32, BE>
|
||||
+ FheUintPrepare<BRA, BE>
|
||||
+ ExecuteBDDCircuit2WTo1W<BE>
|
||||
+ GLWEEncryptSk<BE>,
|
||||
BlindRotationKey<Vec<u8>, BRA>: BlindRotationKeyFactory<BRA>,
|
||||
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
|
||||
|
||||
@@ -11,7 +11,7 @@ use poulpy_hal::layouts::{Backend, DataMut, DataRef, Scratch, ZnxView};
|
||||
|
||||
use crate::tfhe::blind_rotation::{BlindRotationKeyInfos, BlindRotationKeyPrepared, LookUpTableRotationDirection, LookupTable};
|
||||
|
||||
pub trait BlindRotationAlgo {}
|
||||
/// Marker trait for blind-rotation algorithms: it declares no items and requires
/// implementors to be `Sync`.
// NOTE(review): the `Sync` bound presumably exists so keys tagged with the algorithm
// can be shared across threads - confirm against the multi-threaded BDD evaluation.
pub trait BlindRotationAlgo: Sync {}
|
||||
|
||||
pub trait BlindRotationExecute<BRA: BlindRotationAlgo, BE: Backend> {
|
||||
fn blind_rotation_execute_tmp_bytes<G, B>(
|
||||
|
||||
@@ -188,8 +188,7 @@ impl<D: DataRef, BRT: BlindRotationAlgo> BlindRotationKeyInfos for BlindRotation
|
||||
}
|
||||
|
||||
impl<D: DataRef, BRT: BlindRotationAlgo> BlindRotationKey<D, BRT> {
|
||||
#[allow(dead_code)]
|
||||
fn block_size(&self) -> usize {
|
||||
pub fn block_size(&self) -> usize {
|
||||
match self.dist {
|
||||
Distribution::BinaryBlock(value) => value,
|
||||
_ => 1,
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use poulpy_hal::{
|
||||
api::{ModuleLogN, ModuleN, ScratchOwnedAlloc, ScratchOwnedBorrow},
|
||||
api::{ModuleLogN, ModuleN, ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow},
|
||||
layouts::{Backend, DataRef, Module, Scratch, ScratchOwned, ToOwnedDeep},
|
||||
};
|
||||
|
||||
use poulpy_core::{
|
||||
GGSWFromGGLWE, GLWEDecrypt, GLWEPacking, GLWERotate, GLWETrace, ScratchTakeCore,
|
||||
layouts::{
|
||||
Dsize, GGLWEInfos, GGLWELayout, GGLWEPreparedToRef, GGSWInfos, GGSWToMut, GLWEAutomorphismKeyHelper, GLWEInfos,
|
||||
GLWESecretPreparedFactory, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos, LWEToRef,
|
||||
Dsize, GGLWE, GGLWEInfos, GGLWELayout, GGLWEPreparedToRef, GGSWInfos, GGSWToMut, GLWEAutomorphismKeyHelper, GLWEInfos,
|
||||
GLWESecretPreparedFactory, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos, LWEToRef, Rank,
|
||||
},
|
||||
};
|
||||
|
||||
@@ -132,6 +132,16 @@ where
|
||||
R: GGSWInfos,
|
||||
A: CircuitBootstrappingKeyInfos,
|
||||
{
|
||||
let gglwe_infos: GGLWELayout = GGLWELayout {
|
||||
n: res_infos.n(),
|
||||
base2k: res_infos.base2k(),
|
||||
k: res_infos.k(),
|
||||
dnum: res_infos.dnum(),
|
||||
dsize: Dsize(1),
|
||||
rank_in: res_infos.rank().max(Rank(1)),
|
||||
rank_out: res_infos.rank(),
|
||||
};
|
||||
|
||||
self.blind_rotation_execute_tmp_bytes(
|
||||
block_size,
|
||||
extension_factor,
|
||||
@@ -140,6 +150,8 @@ where
|
||||
)
|
||||
.max(self.glwe_trace_tmp_bytes(res_infos, res_infos, &cbt_infos.atk_infos()))
|
||||
.max(self.ggsw_from_gglwe_tmp_bytes(res_infos, &cbt_infos.tsk_infos()))
|
||||
+ GLWE::bytes_of_from_infos(res_infos)
|
||||
+ GGLWE::bytes_of_from_infos(&gglwe_infos)
|
||||
}
|
||||
|
||||
fn circuit_bootstrapping_execute_to_constant<R, L, D>(
|
||||
@@ -155,6 +167,10 @@ where
|
||||
L: LWEToRef + LWEInfos,
|
||||
D: DataRef,
|
||||
{
|
||||
assert!(
|
||||
scratch.available() >= self.circuit_bootstrapping_execute_tmp_bytes(key.block_size(), extension_factor, res, key)
|
||||
);
|
||||
|
||||
circuit_bootstrap_core(
|
||||
false,
|
||||
self,
|
||||
@@ -182,6 +198,10 @@ where
|
||||
L: LWEToRef + LWEInfos,
|
||||
D: DataRef,
|
||||
{
|
||||
assert!(
|
||||
scratch.available() >= self.circuit_bootstrapping_execute_tmp_bytes(key.block_size(), extension_factor, res, key)
|
||||
);
|
||||
|
||||
circuit_bootstrap_core(
|
||||
true,
|
||||
self,
|
||||
|
||||
@@ -19,6 +19,7 @@ use crate::tfhe::blind_rotation::{
|
||||
};
|
||||
|
||||
pub trait CircuitBootstrappingKeyInfos {
|
||||
fn block_size(&self) -> usize;
|
||||
fn brk_infos(&self) -> BlindRotationKeyLayout;
|
||||
fn atk_infos(&self) -> GLWEAutomorphismKeyLayout;
|
||||
fn tsk_infos(&self) -> GGLWEToGGSWKeyLayout;
|
||||
@@ -32,6 +33,10 @@ pub struct CircuitBootstrappingKeyLayout {
|
||||
}
|
||||
|
||||
impl CircuitBootstrappingKeyInfos for CircuitBootstrappingKeyLayout {
|
||||
/// Not implemented for `CircuitBootstrappingKeyLayout`.
///
/// # Panics
/// Always panics via `unimplemented!`; do not call this on a bare layout value.
fn block_size(&self) -> usize {
|
||||
unimplemented!("unimplemented for CircuitBootstrappingKeyLayout")
|
||||
}
|
||||
|
||||
fn atk_infos(&self) -> GLWEAutomorphismKeyLayout {
|
||||
self.layout_atk
|
||||
}
|
||||
@@ -164,6 +169,10 @@ where
|
||||
}
|
||||
|
||||
impl<D: DataRef, BRA: BlindRotationAlgo> CircuitBootstrappingKeyInfos for CircuitBootstrappingKey<D, BRA> {
|
||||
/// Returns the block size reported by the underlying blind-rotation key (`self.brk`).
fn block_size(&self) -> usize {
|
||||
self.brk.block_size()
|
||||
}
|
||||
|
||||
fn atk_infos(&self) -> GLWEAutomorphismKeyLayout {
|
||||
let (_, atk) = self.atk.iter().next().expect("atk is empty");
|
||||
GLWEAutomorphismKeyLayout {
|
||||
|
||||
@@ -122,6 +122,10 @@ impl<D: DataRef, BRA: BlindRotationAlgo, BE: Backend> GLWEAutomorphismKeyHelper<
|
||||
}
|
||||
|
||||
impl<D: DataRef, BRA: BlindRotationAlgo, B: Backend> CircuitBootstrappingKeyInfos for CircuitBootstrappingKeyPrepared<D, BRA, B> {
|
||||
/// Returns the block size reported by the underlying prepared blind-rotation key (`self.brk`).
fn block_size(&self) -> usize {
|
||||
self.brk.block_size()
|
||||
}
|
||||
|
||||
fn atk_infos(&self) -> GLWEAutomorphismKeyLayout {
|
||||
let (_, atk) = self.atk.iter().next().expect("atk is empty");
|
||||
GLWEAutomorphismKeyLayout {
|
||||
|
||||
Reference in New Issue
Block a user