diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs index 31a7f3f..cc3176f 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs @@ -210,16 +210,17 @@ pub trait FheUintBlocksPrepare( + fn fhe_uint_prepare( &self, res: &mut FheUintPrepared, - bits: &FheUint, - key: &BDDKeyPrepared, + bits: &FheUint, + key: &K, scratch: &mut Scratch, ) where DM: DataMut, - DR0: DataRef, - DR1: DataRef; + DB: DataRef, + DK: DataRef, + K: BDDKeyHelper; } impl FheUintBlocksPrepare for Module @@ -240,39 +241,46 @@ where ) } - fn fhe_uint_prepare( + fn fhe_uint_prepare( &self, res: &mut FheUintPrepared, - bits: &FheUint, - key: &BDDKeyPrepared, + bits: &FheUint, + key: &K, scratch: &mut Scratch, ) where DM: DataMut, - DR0: DataRef, - DR1: DataRef, + DB: DataRef, + DK: DataRef, + K: BDDKeyHelper, { + let (cbt, ks) = key.get_cbt_key(); + let mut lwe: LWE> = LWE::alloc_from_infos(bits); //TODO: add TakeLWE let (mut tmp_ggsw, scratch_1) = scratch.take_ggsw(res); for (bit, dst) in res.bits.iter_mut().enumerate() { - bits.get_bit(self, bit, &mut lwe, &key.ks, scratch_1); - key.cbt - .execute_to_constant(self, &mut tmp_ggsw, &lwe, 1, 1, scratch_1); + bits.get_bit(self, bit, &mut lwe, ks, scratch_1); + cbt.execute_to_constant(self, &mut tmp_ggsw, &lwe, 1, 1, scratch_1); dst.prepare(self, &tmp_ggsw, scratch_1); } } } +pub trait BDDKeyHelper { + fn get_cbt_key( + &self, + ) -> ( + &CircuitBootstrappingKeyPrepared, + &GLWEToLWEKeyPrepared, + ); +} + impl FheUintPrepared { - pub fn prepare( - &mut self, - module: &M, - other: &FheUint, - key: &BDDKeyPrepared, - scratch: &mut Scratch, - ) where + pub fn prepare(&mut self, module: &M, other: &FheUint, key: &K, scratch: &mut Scratch) + where BRA: BlindRotationAlgo, O: DataRef, - K: DataRef, + DK: DataRef, + K: BDDKeyHelper, M: FheUintBlocksPrepare, Scratch: ScratchTakeCore, {