From 98208d5e67d307b03e41ce78f367ed2ac41f0ea3 Mon Sep 17 00:00:00 2001 From: Pro7ech Date: Sat, 25 Oct 2025 17:58:34 +0200 Subject: [PATCH] add test for GLWEBlindRotation --- poulpy-backend/src/cpu_fft64_avx/vec_znx.rs | 13 +- poulpy-backend/src/cpu_fft64_ref/vec_znx.rs | 13 +- .../src/cpu_spqlios/fft64/vec_znx.rs | 13 +- poulpy-core/src/operations/glwe.rs | 15 +- poulpy-hal/src/api/vec_znx.rs | 6 + poulpy-hal/src/delegates/vec_znx.rs | 16 ++- poulpy-hal/src/oep/vec_znx.rs | 10 ++ poulpy-hal/src/reference/vec_znx/mod.rs | 2 + poulpy-hal/src/reference/vec_znx/zero.rs | 16 +++ .../src/tfhe/bdd_arithmetic/bdd_rotation.rs | 65 +++++++-- .../tfhe/bdd_arithmetic/tests/fft64_ref.rs | 7 +- .../tests/test_suite/glwe_blind_rotation.rs | 134 ++++++++++++++++++ .../bdd_arithmetic/tests/test_suite/mod.rs | 2 + 13 files changed, 286 insertions(+), 26 deletions(-) create mode 100644 poulpy-hal/src/reference/vec_znx/zero.rs create mode 100644 poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/glwe_blind_rotation.rs diff --git a/poulpy-backend/src/cpu_fft64_avx/vec_znx.rs b/poulpy-backend/src/cpu_fft64_avx/vec_znx.rs index 33325a7..9e286b7 100644 --- a/poulpy-backend/src/cpu_fft64_avx/vec_znx.rs +++ b/poulpy-backend/src/cpu_fft64_avx/vec_znx.rs @@ -14,7 +14,7 @@ use poulpy_hal::{ VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl, VecZnxSplitRingTmpBytesImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl, VecZnxSubScalarImpl, - VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, + VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, VecZnxZeroImpl, }, reference::vec_znx::{ vec_znx_add, vec_znx_add_inplace, vec_znx_add_normal_ref, vec_znx_add_scalar, vec_znx_add_scalar_inplace, @@ -25,13 +25,22 @@ use poulpy_hal::{ vec_znx_normalize_inplace, vec_znx_normalize_tmp_bytes, vec_znx_rotate, 
vec_znx_rotate_inplace, vec_znx_rotate_inplace_tmp_bytes, vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring, vec_znx_split_ring_tmp_bytes, vec_znx_sub, vec_znx_sub_inplace, vec_znx_sub_negate_inplace, vec_znx_sub_scalar, - vec_znx_sub_scalar_inplace, vec_znx_switch_ring, + vec_znx_sub_scalar_inplace, vec_znx_switch_ring, vec_znx_zero, }, source::Source, }; use crate::cpu_fft64_avx::FFT64Avx; +unsafe impl VecZnxZeroImpl for FFT64Avx { + fn vec_znx_zero_impl(_module: &Module, res: &mut R, res_col: usize) + where + R: VecZnxToMut, + { + vec_znx_zero::<_, FFT64Avx>(res, res_col); + } +} + unsafe impl VecZnxNormalizeTmpBytesImpl for FFT64Avx { fn vec_znx_normalize_tmp_bytes_impl(module: &Module) -> usize { vec_znx_normalize_tmp_bytes(module.n()) diff --git a/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs b/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs index fa88aaa..a2a2086 100644 --- a/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs +++ b/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs @@ -14,7 +14,7 @@ use poulpy_hal::{ VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl, VecZnxSplitRingTmpBytesImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl, VecZnxSubScalarImpl, - VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, + VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, VecZnxZeroImpl, }, reference::vec_znx::{ vec_znx_add, vec_znx_add_inplace, vec_znx_add_normal_ref, vec_znx_add_scalar, vec_znx_add_scalar_inplace, @@ -25,13 +25,22 @@ use poulpy_hal::{ vec_znx_normalize_inplace, vec_znx_normalize_tmp_bytes, vec_znx_rotate, vec_znx_rotate_inplace, vec_znx_rotate_inplace_tmp_bytes, vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring, vec_znx_split_ring_tmp_bytes, vec_znx_sub, vec_znx_sub_inplace, vec_znx_sub_negate_inplace, vec_znx_sub_scalar, - 
vec_znx_sub_scalar_inplace, vec_znx_switch_ring, + vec_znx_sub_scalar_inplace, vec_znx_switch_ring, vec_znx_zero, }, source::Source, }; use crate::cpu_fft64_ref::FFT64Ref; +unsafe impl VecZnxZeroImpl for FFT64Ref { + fn vec_znx_zero_impl(_module: &Module, res: &mut R, res_col: usize) + where + R: VecZnxToMut, + { + vec_znx_zero::<_, FFT64Ref>(res, res_col); + } +} + unsafe impl VecZnxNormalizeTmpBytesImpl for FFT64Ref { fn vec_znx_normalize_tmp_bytes_impl(module: &Module) -> usize { vec_znx_normalize_tmp_bytes(module.n()) diff --git a/poulpy-backend/src/cpu_spqlios/fft64/vec_znx.rs b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx.rs index c3a110a..2c35145 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/vec_znx.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx.rs @@ -15,7 +15,7 @@ use poulpy_hal::{ VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl, VecZnxSplitRingTmpBytesImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl, VecZnxSubScalarImpl, - VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, + VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, VecZnxZeroImpl, }, reference::{ vec_znx::{ @@ -23,7 +23,7 @@ use poulpy_hal::{ vec_znx_fill_uniform_ref, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_lsh_tmp_bytes, vec_znx_merge_rings, vec_znx_merge_rings_tmp_bytes, vec_znx_mul_xp_minus_one_inplace_tmp_bytes, vec_znx_normalize_tmp_bytes, vec_znx_rotate_inplace_tmp_bytes, vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring, - vec_znx_split_ring_tmp_bytes, vec_znx_switch_ring, + vec_znx_split_ring_tmp_bytes, vec_znx_switch_ring, vec_znx_zero, }, znx::{znx_copy_ref, znx_zero_ref}, }, @@ -35,6 +35,15 @@ use crate::cpu_spqlios::{ ffi::{module::module_info_t, vec_znx, znx}, }; +unsafe impl VecZnxZeroImpl for FFT64Spqlios { + fn vec_znx_zero_impl(_module: &Module, res: &mut R, 
res_col: usize) + where + R: VecZnxToMut, + { + vec_znx_zero::<_, FFT64Spqlios>(res, res_col); + } +} + unsafe impl VecZnxNormalizeTmpBytesImpl for FFT64Spqlios { fn vec_znx_normalize_tmp_bytes_impl(module: &Module) -> usize { vec_znx_normalize_tmp_bytes(module.n()) diff --git a/poulpy-core/src/operations/glwe.rs b/poulpy-core/src/operations/glwe.rs index 6dcb52c..95df49a 100644 --- a/poulpy-core/src/operations/glwe.rs +++ b/poulpy-core/src/operations/glwe.rs @@ -2,7 +2,7 @@ use poulpy_hal::{ api::{ ModuleN, VecZnxAdd, VecZnxAddInplace, VecZnxCopy, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, - VecZnxSubInplace, VecZnxSubNegateInplace, + VecZnxSubInplace, VecZnxSubNegateInplace, VecZnxZero, }, layouts::{Backend, Module, Scratch, VecZnx, ZnxZero}, reference::vec_znx::vec_znx_rotate_inplace_tmp_bytes, @@ -262,11 +262,11 @@ where } } -impl GLWECopy for Module where Self: ModuleN + VecZnxCopy {} +impl GLWECopy for Module where Self: ModuleN + VecZnxCopy + VecZnxZero {} pub trait GLWECopy where - Self: ModuleN + VecZnxCopy, + Self: ModuleN + VecZnxCopy + VecZnxZero, { fn glwe_copy(&self, res: &mut R, a: &A) where @@ -278,12 +278,17 @@ where assert_eq!(res.n(), self.n() as u32); assert_eq!(a.n(), self.n() as u32); - assert_eq!(res.rank(), a.rank()); - for i in 0..res.rank().as_usize() + 1 { + let min_rank: usize = res.rank().min(a.rank()).as_usize() + 1; + + for i in 0..min_rank { self.vec_znx_copy(res.data_mut(), i, a.data(), i); } + for i in min_rank..(res.rank() + 1).into() { + self.vec_znx_zero(res.data_mut(), i); + } + res.set_k(a.k().min(res.max_k())); res.set_base2k(a.base2k()); } diff --git a/poulpy-hal/src/api/vec_znx.rs b/poulpy-hal/src/api/vec_znx.rs index afa30b8..8bf0e65 100644 --- a/poulpy-hal/src/api/vec_znx.rs +++ b/poulpy-hal/src/api/vec_znx.rs @@ -8,6 +8,12 @@ pub trait VecZnxNormalizeTmpBytes { fn 
vec_znx_normalize_tmp_bytes(&self) -> usize; } +pub trait VecZnxZero { + fn vec_znx_zero(&self, res: &mut R, res_col: usize) + where + R: VecZnxToMut; +} + pub trait VecZnxNormalize { #[allow(clippy::too_many_arguments)] /// Normalizes the selected column of `a` and stores the result into the selected column of `res`. diff --git a/poulpy-hal/src/delegates/vec_znx.rs b/poulpy-hal/src/delegates/vec_znx.rs index 60a961e..02f512a 100644 --- a/poulpy-hal/src/delegates/vec_znx.rs +++ b/poulpy-hal/src/delegates/vec_znx.rs @@ -6,7 +6,7 @@ use crate::{ VecZnxMulXpMinusOneInplace, VecZnxMulXpMinusOneInplaceTmpBytes, VecZnxNegate, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxRsh, VecZnxRshInplace, VecZnxRshTmpBytes, VecZnxSplitRing, VecZnxSplitRingTmpBytes, VecZnxSub, VecZnxSubInplace, - VecZnxSubNegateInplace, VecZnxSubScalar, VecZnxSubScalarInplace, VecZnxSwitchRing, + VecZnxSubNegateInplace, VecZnxSubScalar, VecZnxSubScalarInplace, VecZnxSwitchRing, VecZnxZero, }, layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef}, oep::{ @@ -18,11 +18,23 @@ use crate::{ VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl, VecZnxSplitRingTmpBytesImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl, VecZnxSubScalarImpl, - VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, + VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, VecZnxZeroImpl, }, source::Source, }; +impl VecZnxZero for Module +where + B: Backend + VecZnxZeroImpl, +{ + fn vec_znx_zero(&self, res: &mut R, res_col: usize) + where + R: VecZnxToMut, + { + B::vec_znx_zero_impl(self, res, res_col); + } +} + impl VecZnxNormalizeTmpBytes for Module where B: Backend + VecZnxNormalizeTmpBytesImpl, diff --git 
a/poulpy-hal/src/oep/vec_znx.rs b/poulpy-hal/src/oep/vec_znx.rs index 380253e..47bc94a 100644 --- a/poulpy-hal/src/oep/vec_znx.rs +++ b/poulpy-hal/src/oep/vec_znx.rs @@ -3,6 +3,16 @@ use crate::{ source::Source, }; +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. +/// * See [crate::api::VecZnxZero] for corresponding public API. +/// # Safety [crate::doc::backend_safety] for safety contract. +pub unsafe trait VecZnxZeroImpl { + fn vec_znx_zero_impl(module: &Module, res: &mut R, res_col: usize) + where + R: VecZnxToMut; +} + /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxNormalizeTmpBytes] for corresponding public API. diff --git a/poulpy-hal/src/reference/vec_znx/mod.rs b/poulpy-hal/src/reference/vec_znx/mod.rs index 4edb574..cb945ea 100644 --- a/poulpy-hal/src/reference/vec_znx/mod.rs +++ b/poulpy-hal/src/reference/vec_znx/mod.rs @@ -13,6 +13,7 @@ mod split_ring; mod sub; mod sub_scalar; mod switch_ring; +mod zero; pub use add::*; pub use add_scalar::*; @@ -29,3 +30,4 @@ pub use split_ring::*; pub use sub::*; pub use sub_scalar::*; pub use switch_ring::*; +pub use zero::*; diff --git a/poulpy-hal/src/reference/vec_znx/zero.rs b/poulpy-hal/src/reference/vec_znx/zero.rs new file mode 100644 index 0000000..7febbf5 --- /dev/null +++ b/poulpy-hal/src/reference/vec_znx/zero.rs @@ -0,0 +1,16 @@ +use crate::{ + layouts::{VecZnx, VecZnxToMut, ZnxInfos, ZnxViewMut}, + reference::znx::ZnxZero, +}; + +pub fn vec_znx_zero(res: &mut R, res_col: usize) +where + R: VecZnxToMut, + ZNXARI: ZnxZero, +{ + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + let res_size = res.size(); + for j in 0..res_size { + ZNXARI::znx_zero(res.at_mut(res_col, j)); + } +} diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_rotation.rs 
b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_rotation.rs index bf2116e..a457208 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_rotation.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_rotation.rs @@ -1,11 +1,18 @@ use poulpy_core::{ GLWECopy, GLWERotate, ScratchTakeCore, - layouts::{GGSW, GGSWInfos, GGSWToMut, GLWE, GLWEInfos, GLWEToMut}, + layouts::{GGSW, GGSWInfos, GGSWToMut, GGSWToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef}, }; -use poulpy_hal::layouts::{Backend, Scratch}; +use poulpy_hal::layouts::{Backend, Module, Scratch}; use crate::tfhe::bdd_arithmetic::{Cmux, GetGGSWBit, UnsignedInteger}; +impl GGSWBlindRotation for Module +where + Self: GLWEBlindRotation, + Scratch: ScratchTakeCore, +{ +} + pub trait GGSWBlindRotation where Self: GLWEBlindRotation, @@ -19,9 +26,10 @@ where self.glwe_blind_rotation_tmp_bytes(res_infos, k_infos) } - fn ggsw_blind_rotation( + fn ggsw_blind_rotation( &self, res: &mut R, + test_ggsw: &G, k: &K, bit_start: usize, bit_size: usize, @@ -29,15 +37,18 @@ where scratch: &mut Scratch, ) where R: GGSWToMut, + G: GGSWToRef, K: GetGGSWBit, Scratch: ScratchTakeCore, { let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + let test_ggsw: &GGSW<&[u8]> = &test_ggsw.to_ref(); for row in 0..res.dnum().into() { for col in 0..(res.rank() + 1).into() { self.glwe_blind_rotation( &mut res.at_mut(row, col), + &test_ggsw.at(row, col), k, bit_start, bit_size, @@ -49,6 +60,13 @@ where } } +impl GLWEBlindRotation for Module +where + Self: GLWECopy + GLWERotate + Cmux, + Scratch: ScratchTakeCore, +{ +} + pub trait GLWEBlindRotation where Self: GLWECopy + GLWERotate + Cmux, @@ -63,9 +81,10 @@ where } /// Homomorphic multiplication of res by X^{k[bit_start..bit_start + bit_size] * bit_step}. 
- fn glwe_blind_rotation( + fn glwe_blind_rotation( &self, res: &mut R, + test_glwe: &G, k: &K, bit_start: usize, bit_size: usize, @@ -73,21 +92,43 @@ where scratch: &mut Scratch, ) where R: GLWEToMut, + G: GLWEToRef, K: GetGGSWBit, Scratch: ScratchTakeCore, { - let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + assert!(bit_start + bit_size <= T::WORD_SIZE); - let (mut tmp_res, scratch_1) = scratch.take_glwe(res); + let mut res: GLWE<&mut [u8]> = res.to_mut(); - self.glwe_copy(&mut tmp_res, res); + let (mut tmp_res, scratch_1) = scratch.take_glwe(&res); - for i in 1..bit_size { - // res' = res * X^2^(i * bit_step) - self.glwe_rotate(1 << (i + bit_step), &mut tmp_res, res); + // res <- test_glwe + self.glwe_copy(&mut res, test_glwe); - // res = (res - res') * GGSW(b[i]) + res' - self.cmux_inplace(res, &tmp_res, &k.get_bit(i + bit_start), scratch_1); + // a_is_res = true => (a, b) = (&mut res, &mut tmp_res) + // a_is_res = false => (a, b) = (&mut tmp_res, &mut res) + let mut a_is_res: bool = true; + + for i in 0..bit_size { + let (a, b) = if a_is_res { + (&mut res, &mut tmp_res) + } else { + (&mut tmp_res, &mut res) + }; + + // a <- a ; b <- a * X^{-2^{i + bit_step}} + self.glwe_rotate(-1 << (i + bit_step), b, a); + + // b <- (b - a) * GGSW(b[i]) + a + self.cmux_inplace(b, a, &k.get_bit(i + bit_start), scratch_1); + + // ping-pong roles for next iter + a_is_res = !a_is_res; + } + + // Ensure the final value ends up in `res` + if !a_is_res { + self.glwe_copy(&mut res, &tmp_res); } } } diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs index fbae3c8..05c5085 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs @@ -3,11 +3,16 @@ use poulpy_backend::FFT64Ref; use crate::tfhe::{ bdd_arithmetic::tests::test_suite::{ test_bdd_add, test_bdd_and, test_bdd_or, test_bdd_prepare, test_bdd_sll, test_bdd_slt, 
test_bdd_sltu, test_bdd_sra, - test_bdd_srl, test_bdd_sub, test_bdd_xor, + test_bdd_srl, test_bdd_sub, test_bdd_xor, test_glwe_blind_rotation, }, blind_rotation::CGGI, }; +#[test] +fn test_glwe_blind_rotation_fft64_ref() { + test_glwe_blind_rotation::() +} + #[test] fn test_bdd_prepare_fft64_ref() { test_bdd_prepare::() diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/glwe_blind_rotation.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/glwe_blind_rotation.rs new file mode 100644 index 0000000..ed1fdd0 --- /dev/null +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/glwe_blind_rotation.rs @@ -0,0 +1,134 @@ +use poulpy_core::{ + GGSWEncryptSk, GLWEDecrypt, GLWEEncryptSk, ScratchTakeCore, + layouts::{ + Base2K, Degree, Dnum, Dsize, GGSWLayout, GGSWPreparedFactory, GLWE, GLWELayout, GLWEPlaintext, GLWESecret, + GLWESecretPrepared, GLWESecretPreparedFactory, LWEInfos, Rank, TorusPrecision, + }, +}; +use poulpy_hal::{ + api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow}, + layouts::{Backend, Module, Scratch, ScratchOwned}, + source::Source, +}; +use rand::RngCore; + +use crate::tfhe::bdd_arithmetic::{FheUintBlocksPrepared, GLWEBlindRotation}; + +pub fn test_glwe_blind_rotation() +where + Module: ModuleNew + + GLWESecretPreparedFactory + + GGSWPreparedFactory + + GGSWEncryptSk + + GLWEBlindRotation + + GLWEDecrypt + + GLWEEncryptSk, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchTakeCore, +{ + let n: Degree = Degree(1 << 11); + let base2k: Base2K = Base2K(13); + let rank: Rank = Rank(1); + let k_glwe: TorusPrecision = TorusPrecision(26); + let k_ggsw: TorusPrecision = TorusPrecision(39); + let dnum: Dnum = Dnum(3); + + let glwe_infos: GLWELayout = GLWELayout { + n, + base2k, + k: k_glwe, + rank, + }; + let ggsw_infos: GGSWLayout = GGSWLayout { + n, + base2k, + k: k_ggsw, + rank, + dnum, + dsize: Dsize(1), + }; + + let n_glwe: usize = glwe_infos.n().into(); + + let module: Module = 
Module::::new(n_glwe as u64); + let mut source: Source = Source::new([6u8; 32]); + let mut source_xs: Source = Source::new([1u8; 32]); + let mut source_xa: Source = Source::new([2u8; 32]); + let mut source_xe: Source = Source::new([3u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(1 << 22); + + let mut sk_glwe: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_infos); + sk_glwe.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_glwe_prep: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc_from_infos(&module, &glwe_infos); + sk_glwe_prep.prepare(&module, &sk_glwe); + + let mut res: GLWE> = GLWE::alloc_from_infos(&glwe_infos); + + let mut test_glwe: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_infos); + let mut data: Vec = vec![0i64; module.n()]; + data.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); + test_glwe.encode_vec_i64(&data, base2k.as_usize().into()); + + println!("pt: {}", test_glwe); + + let k: u32 = source.next_u32(); + + println!("k: {k}"); + + let mut k_enc_prep: FheUintBlocksPrepared, u32, BE> = + FheUintBlocksPrepared::, u32, BE>::alloc(&module, &ggsw_infos); + k_enc_prep.encrypt_sk( + &module, + k, + &sk_glwe_prep, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + let base: [usize; 2] = [6, 5]; + + assert_eq!(base.iter().sum::(), module.log_n()); + + // Starting bit + let mut bit_start: usize = 0; + + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_infos); + + for _ in 0..32_usize.div_ceil(module.log_n()) { + // By how many bits to left shift + let mut bit_step: usize = 0; + + for digit in base { + let mask: u32 = (1 << digit) - 1; + + // How many bits to take + let bit_size: usize = (32 - bit_start).min(digit); + + module.glwe_blind_rotation( + &mut res, + &test_glwe, + &k_enc_prep, + bit_start, + bit_size, + bit_step, + scratch.borrow(), + ); + + res.decrypt(&module, &mut pt, &sk_glwe_prep, scratch.borrow()); + + assert_eq!( + (((k >> bit_start) & mask) << bit_step) as i64, + 
pt.decode_coeff_i64(base2k.as_usize().into(), 0) + ); + + bit_step += digit; + bit_start += digit; + + if bit_start >= 32 { + break; + } + } + } +} diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs index 73c96d3..76ef5e9 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs @@ -1,5 +1,6 @@ mod add; mod and; +mod glwe_blind_rotation; mod or; mod prepare; mod sll; @@ -12,6 +13,7 @@ mod xor; pub use add::*; pub use and::*; +pub use glwe_blind_rotation::*; pub use or::*; pub use prepare::*; pub use sll::*;