diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/blind_rotation.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/blind_rotation.rs index 77c2155..b50329c 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/blind_rotation.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/blind_rotation.rs @@ -25,16 +25,49 @@ where R: GLWEInfos, K: GGSWInfos, { - self.glwe_to_glwe_blind_rotation_tmp_bytes(res_infos, k_infos) + self.glwe_blind_rotation_tmp_bytes(res_infos, k_infos) + } + + #[allow(clippy::too_many_arguments)] + /// res <- res * X^{((k>>bit_rsh) % 2^bit_mask) << bit_lsh}. + fn ggsw_blind_rotation_inplace( + &self, + res: &mut R, + fhe_uint: &K, + sign: bool, + bit_rsh: usize, + bit_mask: usize, + bit_lsh: usize, + scratch: &mut Scratch, + ) where + R: GGSWToMut, + K: GetGGSWBit, + Scratch: ScratchTakeCore, + { + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + + for col in 0..(res.rank() + 1).into() { + for row in 0..res.dnum().into() { + self.glwe_blind_rotation_inplace( + &mut res.at_mut(row, col), + fhe_uint, + sign, + bit_rsh, + bit_mask, + bit_lsh, + scratch, + ); + } + } } #[allow(clippy::too_many_arguments)] /// res <- a * X^{((k>>bit_rsh) % 2^bit_mask) << bit_lsh}. - fn ggsw_to_ggsw_blind_rotation( + fn ggsw_blind_rotation( &self, res: &mut R, a: &A, - k: &K, + fhe_uint: &K, sign: bool, bit_rsh: usize, bit_mask: usize, @@ -54,10 +87,10 @@ where for col in 0..(res.rank() + 1).into() { for row in 0..res.dnum().into() { - self.glwe_to_glwe_blind_rotation( + self.glwe_blind_rotation( &mut res.at_mut(row, col), &a.at(row, col), - k, + fhe_uint, sign, bit_rsh, bit_mask, @@ -73,7 +106,7 @@ where R: GLWEInfos, K: GGSWInfos, { - self.glwe_to_glwe_blind_rotation_tmp_bytes(res_infos, k_infos) + GLWE::bytes_of_from_infos(res_infos) + self.glwe_blind_rotation_tmp_bytes(res_infos, k_infos) + GLWE::bytes_of_from_infos(res_infos) } #[allow(clippy::too_many_arguments)] @@ -81,7 +114,7 @@ where &self, res: &mut R, test_vector: &S, - k: &K, + fhe_uint: &K, sign: bool, bit_rsh: usize, bit_mask: usize, @@ -113,10 +146,10 @@ where ); self.vec_znx_normalize_inplace(base2k, tmp_glwe.data_mut(), col, scratch_1); - self.glwe_to_glwe_blind_rotation( + self.glwe_blind_rotation( &mut res.at_mut(row, col), &tmp_glwe, - k, + fhe_uint, sign, bit_rsh, bit_mask, @@ -139,7 +172,7 @@ pub trait GLWEBlindRotation where Self: GLWECopy + GLWERotate + Cmux, { - fn glwe_to_glwe_blind_rotation_tmp_bytes(&self, res_infos: &R, k_infos: &K) -> usize + fn glwe_blind_rotation_tmp_bytes(&self, res_infos: &R, k_infos: &K) -> usize where R: GLWEInfos, K: GGSWInfos, @@ -148,12 +181,10 @@ where } #[allow(clippy::too_many_arguments)] - /// res <- a * X^{sign * ((k>>bit_rsh) % 2^bit_mask) << bit_lsh}. - fn glwe_to_glwe_blind_rotation( + fn glwe_blind_rotation_inplace( &self, res: &mut R, - a: &A, - k: &K, + fhe_uint: &K, sign: bool, bit_rsh: usize, bit_mask: usize, @@ -161,31 +192,20 @@ where scratch: &mut Scratch, ) where R: GLWEToMut, - A: GLWEToRef, K: GetGGSWBit, Scratch: ScratchTakeCore, { assert!(bit_rsh + bit_mask <= T::WORD_SIZE); let mut res: GLWE<&mut [u8]> = res.to_mut(); - let a: &GLWE<&[u8]> = &a.to_ref(); let (mut tmp_res, scratch_1) = scratch.take_glwe(&res); - // a <- a ; b <- a * X^{-2^{i + bit_lsh}} - match sign { - true => self.glwe_rotate(1 << bit_lsh, &mut res, a), - false => self.glwe_rotate(-1 << bit_lsh, &mut res, a), - } - - // b <- (b - a) * GGSW(b[i]) + a - self.cmux_inplace(&mut res, a, &k.get_bit(bit_rsh), scratch_1); - // a_is_res = true => (a, b) = (&mut res, &mut tmp_res) // a_is_res = false => (a, b) = (&mut tmp_res, &mut res) let mut a_is_res: bool = true; - for i in 1..bit_mask { + for i in 0..bit_mask { let (a, b) = if a_is_res { (&mut res, &mut tmp_res) } else { @@ -199,7 +219,7 @@ where } // b <- (b - a) * GGSW(b[i]) + a - self.cmux_inplace(b, a, &k.get_bit(i + bit_rsh), scratch_1); + self.cmux_inplace(b, a, &fhe_uint.get_bit(i + bit_rsh), scratch_1); // ping-pong roles for next iter a_is_res = !a_is_res; @@ -210,4 +230,26 @@ where self.glwe_copy(&mut res, &tmp_res); } } + + #[allow(clippy::too_many_arguments)] + /// res <- a * X^{sign * ((k>>bit_rsh) % 2^bit_mask) << bit_lsh}. + fn glwe_blind_rotation( + &self, + res: &mut R, + a: &A, + fhe_uint: &K, + sign: bool, + bit_rsh: usize, + bit_mask: usize, + bit_lsh: usize, + scratch: &mut Scratch, + ) where + R: GLWEToMut, + A: GLWEToRef, + K: GetGGSWBit, + Scratch: ScratchTakeCore, + { + self.glwe_copy(res, a); + self.glwe_blind_rotation_inplace(res, fhe_uint, sign, bit_rsh, bit_mask, bit_lsh, scratch); + } } diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/glwe_blind_rotation.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/glwe_blind_rotation.rs index 6448e4a..78a03de 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/glwe_blind_rotation.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/glwe_blind_rotation.rs @@ -101,7 +101,7 @@ where // How many bits to take let bit_size: usize = (32 - bit_start).min(digit); - module.glwe_to_glwe_blind_rotation( + module.glwe_blind_rotation( &mut res, &test_glwe, &k_enc_prep,