From 6dd93ceaeaa4090f19b7746d0d53fc930fba198a Mon Sep 17 00:00:00 2001 From: Pro7ech Date: Sun, 26 Oct 2025 10:28:13 +0100 Subject: [PATCH] Add test for ggsw scalar blind rotation --- poulpy-core/src/noise/ggsw.rs | 1 + .../src/tfhe/bdd_arithmetic/bdd_rotation.rs | 126 +++++++++++---- .../tfhe/bdd_arithmetic/tests/fft64_ref.rs | 7 +- .../tests/test_suite/ggsw_blind_rotations.rs | 149 ++++++++++++++++++ .../bdd_arithmetic/tests/test_suite/mod.rs | 2 + 5 files changed, 254 insertions(+), 31 deletions(-) create mode 100644 poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/ggsw_blind_rotations.rs diff --git a/poulpy-core/src/noise/ggsw.rs b/poulpy-core/src/noise/ggsw.rs index b6805c6..616d7c2 100644 --- a/poulpy-core/src/noise/ggsw.rs +++ b/poulpy-core/src/noise/ggsw.rs @@ -162,6 +162,7 @@ where sk_prepared, scratch.borrow(), ); + self.vec_znx_sub_inplace(&mut pt_have.data, 0, &pt.data, 0); let std_pt: f64 = pt_have.data.std(base2k, 0).log2(); diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_rotation.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_rotation.rs index a457208..89965ef 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_rotation.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_rotation.rs @@ -1,24 +1,27 @@ use poulpy_core::{ GLWECopy, GLWERotate, ScratchTakeCore, - layouts::{GGSW, GGSWInfos, GGSWToMut, GGSWToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef}, + layouts::{GGSW, GGSWInfos, GGSWToMut, GGSWToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos}, +}; +use poulpy_hal::{ + api::{VecZnxAddScalarInplace, VecZnxNormalizeInplace}, + layouts::{Backend, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxZero}, }; -use poulpy_hal::layouts::{Backend, Module, Scratch}; use crate::tfhe::bdd_arithmetic::{Cmux, GetGGSWBit, UnsignedInteger}; impl GGSWBlindRotation for Module where - Self: GLWEBlindRotation, + Self: GLWEBlindRotation + VecZnxAddScalarInplace + VecZnxNormalizeInplace, Scratch: ScratchTakeCore, { } pub trait GGSWBlindRotation where - Self: GLWEBlindRotation, + Self: GLWEBlindRotation + VecZnxAddScalarInplace + VecZnxNormalizeInplace, Scratch: ScratchTakeCore, { - fn ggsw_blind_rotation_tmp_bytes(&self, res_infos: &R, k_infos: &K) -> usize + fn ggsw_blind_rotate_from_ggsw_tmp_bytes(&self, res_infos: &R, k_infos: &K) -> usize where R: GLWEInfos, K: GGSWInfos, @@ -26,38 +29,98 @@ where self.glwe_blind_rotation_tmp_bytes(res_infos, k_infos) } - fn ggsw_blind_rotation( + /// res <- a * X^{((k>>bit_rsh) % 2^bit_mask) << bit_lsh}. + fn ggsw_blind_rotate_from_ggsw( &self, res: &mut R, - test_ggsw: &G, + a: &A, k: &K, bit_start: usize, - bit_size: usize, - bit_step: usize, + bit_mask: usize, + bit_lsh: usize, scratch: &mut Scratch, ) where R: GGSWToMut, - G: GGSWToRef, + A: GGSWToRef, K: GetGGSWBit, Scratch: ScratchTakeCore, { let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); - let test_ggsw: &GGSW<&[u8]> = &test_ggsw.to_ref(); + let a: &GGSW<&[u8]> = &a.to_ref(); - for row in 0..res.dnum().into() { - for col in 0..(res.rank() + 1).into() { + assert!(res.dnum() <= a.dnum()); + assert_eq!(res.dsize(), a.dsize()); + + for col in 0..(res.rank() + 1).into() { + for row in 0..res.dnum().into() { self.glwe_blind_rotation( &mut res.at_mut(row, col), - &test_ggsw.at(row, col), + &a.at(row, col), k, bit_start, - bit_size, - bit_step, + bit_mask, + bit_lsh, scratch, ); } } } + + fn ggsw_blind_rotate_from_scalar_tmp_bytes(&self, res_infos: &R, k_infos: &K) -> usize + where + R: GLWEInfos, + K: GGSWInfos, + { + self.glwe_blind_rotation_tmp_bytes(res_infos, k_infos) + GLWE::bytes_of_from_infos(res_infos) + } + + fn ggsw_blind_rotate_from_scalar( + &self, + res: &mut R, + test_vector: &S, + k: &K, + bit_start: usize, + bit_mask: usize, + bit_lsh: usize, + scratch: &mut Scratch, + ) where + R: GGSWToMut, + S: ScalarZnxToRef, + K: GetGGSWBit, + Scratch: ScratchTakeCore, + { + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + let test_vector: &ScalarZnx<&[u8]> = &test_vector.to_ref(); + + let base2k: usize = res.base2k().into(); + let dsize: usize = res.dsize().into(); + + let (mut tmp_glwe, scratch_1) = scratch.take_glwe(res); + + for col in 0..(res.rank() + 1).into() { + for row in 0..res.dnum().into() { + tmp_glwe.data_mut().zero(); + self.vec_znx_add_scalar_inplace( + tmp_glwe.data_mut(), + col, + (dsize - 1) + row * dsize, + test_vector, + 0, + ); + self.vec_znx_normalize_inplace(base2k, tmp_glwe.data_mut(), col, scratch_1); + + self.glwe_blind_rotation( + &mut res.at_mut(row, col), + &tmp_glwe, + k, + bit_start, + bit_mask, + bit_lsh, + scratch_1, + ); + } + } + } } impl GLWEBlindRotation for Module @@ -80,47 +143,50 @@ where self.cmux_tmp_bytes(res_infos, res_infos, k_infos) + GLWE::bytes_of_from_infos(res_infos) } - /// Homomorphic multiplication of res by X^{k[bit_start..bit_start + bit_size] * bit_step}. - fn glwe_blind_rotation( + /// res <- a * X^{((k>>bit_rsh) % 2^bit_mask) << bit_lsh}. + fn glwe_blind_rotation( &self, res: &mut R, - test_glwe: &G, + a: &A, k: &K, - bit_start: usize, - bit_size: usize, - bit_step: usize, + bit_rsh: usize, + bit_mask: usize, + bit_lsh: usize, scratch: &mut Scratch, ) where R: GLWEToMut, - G: GLWEToRef, + A: GLWEToRef, K: GetGGSWBit, Scratch: ScratchTakeCore, { - assert!(bit_start + bit_size <= T::WORD_SIZE); + assert!(bit_rsh + bit_mask <= T::WORD_SIZE); let mut res: GLWE<&mut [u8]> = res.to_mut(); let (mut tmp_res, scratch_1) = scratch.take_glwe(&res); - // res <- test_glwe - self.glwe_copy(&mut res, test_glwe); + // a <- a ; b <- a * X^{-2^{i + bit_lsh}} + self.glwe_rotate(-1 << bit_lsh, &mut res, a); + + // b <- (b - a) * GGSW(b[i]) + a + self.cmux_inplace(&mut res, a, &k.get_bit(bit_rsh), scratch_1); // a_is_res = true => (a, b) = (&mut res, &mut tmp_res) // a_is_res = false => (a, b) = (&mut tmp_res, &mut res) let mut a_is_res: bool = true; - for i in 0..bit_size { + for i in 1..bit_mask { let (a, b) = if a_is_res { (&mut res, &mut tmp_res) } else { (&mut tmp_res, &mut res) }; - // a <- a ; b <- a * X^{-2^{i + bit_step}} - self.glwe_rotate(-1 << (i + bit_step), b, a); + // a <- a ; b <- a * X^{-2^{i + bit_lsh}} + self.glwe_rotate(-1 << (i + bit_lsh), b, a); // b <- (b - a) * GGSW(b[i]) + a - self.cmux_inplace(b, a, &k.get_bit(i + bit_start), scratch_1); + self.cmux_inplace(b, a, &k.get_bit(i + bit_rsh), scratch_1); // ping-pong roles for next iter a_is_res = !a_is_res; diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs index 05c5085..94c7352 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs @@ -3,7 +3,7 @@ use poulpy_backend::FFT64Ref; use crate::tfhe::{ bdd_arithmetic::tests::test_suite::{ test_bdd_add, test_bdd_and, test_bdd_or, test_bdd_prepare, test_bdd_sll, test_bdd_slt, test_bdd_sltu, test_bdd_sra, - test_bdd_srl, test_bdd_sub, test_bdd_xor, test_glwe_blind_rotation, + test_bdd_srl, test_bdd_sub, test_bdd_xor, test_ggsw_blind_rotation, test_glwe_blind_rotation, }, blind_rotation::CGGI, }; @@ -13,6 +13,11 @@ fn test_glwe_blind_rotation_fft64_ref() { test_glwe_blind_rotation::() } +#[test] +fn test_ggsw_blind_rotation_fft64_ref() { + test_ggsw_blind_rotation::() +} + #[test] fn test_bdd_prepare_fft64_ref() { test_bdd_prepare::() diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/ggsw_blind_rotations.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/ggsw_blind_rotations.rs new file mode 100644 index 0000000..1e0146f --- /dev/null +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/ggsw_blind_rotations.rs @@ -0,0 +1,149 @@ +use poulpy_core::{ + GGSWEncryptSk, GGSWNoise, GLWEDecrypt, GLWEEncryptSk, SIGMA, ScratchTakeCore, + layouts::{ + Base2K, Degree, Dnum, Dsize, GGSW, GGSWLayout, GGSWPreparedFactory, GLWESecret, GLWESecretPrepared, + GLWESecretPreparedFactory, LWEInfos, Rank, TorusPrecision, + }, +}; +use poulpy_hal::{ + api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxRotateInplace}, + layouts::{Backend, Module, ScalarZnx, Scratch, ScratchOwned, ZnxView, ZnxViewMut}, + source::Source, +}; +use rand::RngCore; + +use crate::tfhe::bdd_arithmetic::{FheUintBlocksPrepared, GGSWBlindRotation}; + +pub fn test_ggsw_blind_rotation() +where + Module: ModuleNew + + GLWESecretPreparedFactory + + GGSWPreparedFactory + + GGSWEncryptSk + + GGSWBlindRotation + + GGSWNoise + + GLWEDecrypt + + GLWEEncryptSk + + VecZnxRotateInplace, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchTakeCore, +{ + let n: Degree = Degree(1 << 11); + let base2k: Base2K = Base2K(13); + let rank: Rank = Rank(1); + let k_ggsw_res: TorusPrecision = TorusPrecision(39); + let k_ggsw_apply: TorusPrecision = TorusPrecision(52); + + let ggsw_res_infos: GGSWLayout = GGSWLayout { + n, + base2k, + k: k_ggsw_res, + rank, + dnum: Dnum(2), + dsize: Dsize(1), + }; + + let ggsw_k_infos: GGSWLayout = GGSWLayout { + n, + base2k, + k: k_ggsw_apply, + rank, + dnum: Dnum(3), + dsize: Dsize(1), + }; + + let n_glwe: usize = n.into(); + + let module: Module = Module::::new(n_glwe as u64); + let mut source: Source = Source::new([6u8; 32]); + let mut source_xs: Source = Source::new([1u8; 32]); + let mut source_xa: Source = Source::new([2u8; 32]); + let mut source_xe: Source = Source::new([3u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(1 << 22); + + let mut sk_glwe: GLWESecret> = GLWESecret::alloc(n, rank); + sk_glwe.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_glwe_prep: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(&module, rank); + sk_glwe_prep.prepare(&module, &sk_glwe); + + let mut res: GGSW> = GGSW::alloc_from_infos(&ggsw_res_infos); + + let mut scalar: ScalarZnx> = ScalarZnx::alloc(n_glwe, 1); + scalar + .raw_mut() + .iter_mut() + .enumerate() + .for_each(|(i, x)| *x = i as i64); + + let k: u32 = source.next_u32(); + + // println!("k: {k}"); + + let mut k_enc_prep: FheUintBlocksPrepared, u32, BE> = + FheUintBlocksPrepared::, u32, BE>::alloc(&module, &ggsw_k_infos); + k_enc_prep.encrypt_sk( + &module, + k, + &sk_glwe_prep, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + let base: [usize; 2] = [6, 5]; + + assert_eq!(base.iter().sum::(), module.log_n()); + + // Starting bit + let mut bit_start: usize = 0; + + let max_noise = |col_i: usize| { + let mut noise: f64 = -(ggsw_res_infos.size() as f64 * base2k.as_usize() as f64) + SIGMA.log2() + 2.0; + noise += 0.5 * ggsw_res_infos.log_n() as f64; + if col_i != 0 { + noise += 0.5 * ggsw_res_infos.log_n() as f64 + } + noise + }; + + for _ in 0..32_usize.div_ceil(module.log_n()) { + // By how many bits to left shift + let mut bit_step: usize = 0; + + for digit in base { + let mask: u32 = (1 << digit) - 1; + + // How many bits to take + let bit_size: usize = (32 - bit_start).min(digit); + + module.ggsw_blind_rotate_from_scalar( + &mut res, + &scalar, + &k_enc_prep, + bit_start, + bit_size, + bit_step, + scratch.borrow(), + ); + + let rot: i64 = (((k >> bit_start) & mask) << bit_step) as i64; + + let mut scalar_want: ScalarZnx> = ScalarZnx::alloc(module.n(), 1); + scalar_want.raw_mut().copy_from_slice(scalar.raw()); + + module.vec_znx_rotate_inplace(-rot, &mut scalar_want.as_vec_znx_mut(), 0, scratch.borrow()); + + // res.print_noise(&module, &sk_glwe_prep, &scalar_want); + + res.assert_noise(&module, &sk_glwe_prep, &scalar_want, &max_noise); + + bit_step += digit; + bit_start += digit; + + if bit_start >= 32 { + break; + } + } + } +} diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs index 76ef5e9..1b2c645 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs @@ -1,5 +1,6 @@ mod add; mod and; +mod ggsw_blind_rotations; mod glwe_blind_rotation; mod or; mod prepare; @@ -13,6 +14,7 @@ mod xor; pub use add::*; pub use and::*; +pub use ggsw_blind_rotations::*; pub use glwe_blind_rotation::*; pub use or::*; pub use prepare::*;