Add test for ggsw scalar blind rotation

This commit is contained in:
Pro7ech
2025-10-26 10:28:13 +01:00
parent 98208d5e67
commit 6dd93ceaea
5 changed files with 254 additions and 31 deletions

View File

@@ -162,6 +162,7 @@ where
sk_prepared, sk_prepared,
scratch.borrow(), scratch.borrow(),
); );
self.vec_znx_sub_inplace(&mut pt_have.data, 0, &pt.data, 0); self.vec_znx_sub_inplace(&mut pt_have.data, 0, &pt.data, 0);
let std_pt: f64 = pt_have.data.std(base2k, 0).log2(); let std_pt: f64 = pt_have.data.std(base2k, 0).log2();

View File

@@ -1,24 +1,27 @@
use poulpy_core::{ use poulpy_core::{
GLWECopy, GLWERotate, ScratchTakeCore, GLWECopy, GLWERotate, ScratchTakeCore,
layouts::{GGSW, GGSWInfos, GGSWToMut, GGSWToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef}, layouts::{GGSW, GGSWInfos, GGSWToMut, GGSWToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos},
};
use poulpy_hal::{
api::{VecZnxAddScalarInplace, VecZnxNormalizeInplace},
layouts::{Backend, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxZero},
}; };
use poulpy_hal::layouts::{Backend, Module, Scratch};
use crate::tfhe::bdd_arithmetic::{Cmux, GetGGSWBit, UnsignedInteger}; use crate::tfhe::bdd_arithmetic::{Cmux, GetGGSWBit, UnsignedInteger};
impl<T: UnsignedInteger, BE: Backend> GGSWBlindRotation<T, BE> for Module<BE> impl<T: UnsignedInteger, BE: Backend> GGSWBlindRotation<T, BE> for Module<BE>
where where
Self: GLWEBlindRotation<T, BE>, Self: GLWEBlindRotation<T, BE> + VecZnxAddScalarInplace + VecZnxNormalizeInplace<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
} }
pub trait GGSWBlindRotation<T: UnsignedInteger, BE: Backend> pub trait GGSWBlindRotation<T: UnsignedInteger, BE: Backend>
where where
Self: GLWEBlindRotation<T, BE>, Self: GLWEBlindRotation<T, BE> + VecZnxAddScalarInplace + VecZnxNormalizeInplace<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
fn ggsw_blind_rotation_tmp_bytes<R, K>(&self, res_infos: &R, k_infos: &K) -> usize fn ggsw_blind_rotate_from_ggsw_tmp_bytes<R, K>(&self, res_infos: &R, k_infos: &K) -> usize
where where
R: GLWEInfos, R: GLWEInfos,
K: GGSWInfos, K: GGSWInfos,
@@ -26,38 +29,98 @@ where
self.glwe_blind_rotation_tmp_bytes(res_infos, k_infos) self.glwe_blind_rotation_tmp_bytes(res_infos, k_infos)
} }
fn ggsw_blind_rotation<R, G, K>( /// res <- a * X^{((k>>bit_rsh) % 2^bit_mask) << bit_lsh}.
fn ggsw_blind_rotate_from_ggsw<R, A, K>(
&self, &self,
res: &mut R, res: &mut R,
test_ggsw: &G, a: &A,
k: &K, k: &K,
bit_start: usize, bit_start: usize,
bit_size: usize, bit_mask: usize,
bit_step: usize, bit_lsh: usize,
scratch: &mut Scratch<BE>, scratch: &mut Scratch<BE>,
) where ) where
R: GGSWToMut, R: GGSWToMut,
G: GGSWToRef, A: GGSWToRef,
K: GetGGSWBit<T, BE>, K: GetGGSWBit<T, BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
let test_ggsw: &GGSW<&[u8]> = &test_ggsw.to_ref(); let a: &GGSW<&[u8]> = &a.to_ref();
assert!(res.dnum() <= a.dnum());
assert_eq!(res.dsize(), a.dsize());
for row in 0..res.dnum().into() {
for col in 0..(res.rank() + 1).into() { for col in 0..(res.rank() + 1).into() {
for row in 0..res.dnum().into() {
self.glwe_blind_rotation( self.glwe_blind_rotation(
&mut res.at_mut(row, col), &mut res.at_mut(row, col),
&test_ggsw.at(row, col), &a.at(row, col),
k, k,
bit_start, bit_start,
bit_size, bit_mask,
bit_step, bit_lsh,
scratch, scratch,
); );
} }
} }
} }
fn ggsw_blind_rotate_from_scalar_tmp_bytes<R, K>(&self, res_infos: &R, k_infos: &K) -> usize
where
R: GLWEInfos,
K: GGSWInfos,
{
self.glwe_blind_rotation_tmp_bytes(res_infos, k_infos) + GLWE::bytes_of_from_infos(res_infos)
}
fn ggsw_blind_rotate_from_scalar<R, S, K>(
&self,
res: &mut R,
test_vector: &S,
k: &K,
bit_start: usize,
bit_mask: usize,
bit_lsh: usize,
scratch: &mut Scratch<BE>,
) where
R: GGSWToMut,
S: ScalarZnxToRef,
K: GetGGSWBit<T, BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
let test_vector: &ScalarZnx<&[u8]> = &test_vector.to_ref();
let base2k: usize = res.base2k().into();
let dsize: usize = res.dsize().into();
let (mut tmp_glwe, scratch_1) = scratch.take_glwe(res);
for col in 0..(res.rank() + 1).into() {
for row in 0..res.dnum().into() {
tmp_glwe.data_mut().zero();
self.vec_znx_add_scalar_inplace(
tmp_glwe.data_mut(),
col,
(dsize - 1) + row * dsize,
test_vector,
0,
);
self.vec_znx_normalize_inplace(base2k, tmp_glwe.data_mut(), col, scratch_1);
self.glwe_blind_rotation(
&mut res.at_mut(row, col),
&tmp_glwe,
k,
bit_start,
bit_mask,
bit_lsh,
scratch_1,
);
}
}
}
} }
impl<T: UnsignedInteger, BE: Backend> GLWEBlindRotation<T, BE> for Module<BE> impl<T: UnsignedInteger, BE: Backend> GLWEBlindRotation<T, BE> for Module<BE>
@@ -80,47 +143,50 @@ where
self.cmux_tmp_bytes(res_infos, res_infos, k_infos) + GLWE::bytes_of_from_infos(res_infos) self.cmux_tmp_bytes(res_infos, res_infos, k_infos) + GLWE::bytes_of_from_infos(res_infos)
} }
/// Homomorphic multiplication of res by X^{k[bit_start..bit_start + bit_size] * bit_step}. /// res <- a * X^{((k>>bit_rsh) % 2^bit_mask) << bit_lsh}.
fn glwe_blind_rotation<R, G, K>( fn glwe_blind_rotation<R, A, K>(
&self, &self,
res: &mut R, res: &mut R,
test_glwe: &G, a: &A,
k: &K, k: &K,
bit_start: usize, bit_rsh: usize,
bit_size: usize, bit_mask: usize,
bit_step: usize, bit_lsh: usize,
scratch: &mut Scratch<BE>, scratch: &mut Scratch<BE>,
) where ) where
R: GLWEToMut, R: GLWEToMut,
G: GLWEToRef, A: GLWEToRef,
K: GetGGSWBit<T, BE>, K: GetGGSWBit<T, BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
assert!(bit_start + bit_size <= T::WORD_SIZE); assert!(bit_rsh + bit_mask <= T::WORD_SIZE);
let mut res: GLWE<&mut [u8]> = res.to_mut(); let mut res: GLWE<&mut [u8]> = res.to_mut();
let (mut tmp_res, scratch_1) = scratch.take_glwe(&res); let (mut tmp_res, scratch_1) = scratch.take_glwe(&res);
// res <- test_glwe // a <- a ; b <- a * X^{-2^{i + bit_lsh}}
self.glwe_copy(&mut res, test_glwe); self.glwe_rotate(-1 << bit_lsh, &mut res, a);
// b <- (b - a) * GGSW(b[i]) + a
self.cmux_inplace(&mut res, a, &k.get_bit(bit_rsh), scratch_1);
// a_is_res = true => (a, b) = (&mut res, &mut tmp_res) // a_is_res = true => (a, b) = (&mut res, &mut tmp_res)
// a_is_res = false => (a, b) = (&mut tmp_res, &mut res) // a_is_res = false => (a, b) = (&mut tmp_res, &mut res)
let mut a_is_res: bool = true; let mut a_is_res: bool = true;
for i in 0..bit_size { for i in 1..bit_mask {
let (a, b) = if a_is_res { let (a, b) = if a_is_res {
(&mut res, &mut tmp_res) (&mut res, &mut tmp_res)
} else { } else {
(&mut tmp_res, &mut res) (&mut tmp_res, &mut res)
}; };
// a <- a ; b <- a * X^{-2^{i + bit_step}} // a <- a ; b <- a * X^{-2^{i + bit_lsh}}
self.glwe_rotate(-1 << (i + bit_step), b, a); self.glwe_rotate(-1 << (i + bit_lsh), b, a);
// b <- (b - a) * GGSW(b[i]) + a // b <- (b - a) * GGSW(b[i]) + a
self.cmux_inplace(b, a, &k.get_bit(i + bit_start), scratch_1); self.cmux_inplace(b, a, &k.get_bit(i + bit_rsh), scratch_1);
// ping-pong roles for next iter // ping-pong roles for next iter
a_is_res = !a_is_res; a_is_res = !a_is_res;

View File

@@ -3,7 +3,7 @@ use poulpy_backend::FFT64Ref;
use crate::tfhe::{ use crate::tfhe::{
bdd_arithmetic::tests::test_suite::{ bdd_arithmetic::tests::test_suite::{
test_bdd_add, test_bdd_and, test_bdd_or, test_bdd_prepare, test_bdd_sll, test_bdd_slt, test_bdd_sltu, test_bdd_sra, test_bdd_add, test_bdd_and, test_bdd_or, test_bdd_prepare, test_bdd_sll, test_bdd_slt, test_bdd_sltu, test_bdd_sra,
test_bdd_srl, test_bdd_sub, test_bdd_xor, test_glwe_blind_rotation, test_bdd_srl, test_bdd_sub, test_bdd_xor, test_ggsw_blind_rotation, test_glwe_blind_rotation,
}, },
blind_rotation::CGGI, blind_rotation::CGGI,
}; };
@@ -13,6 +13,11 @@ fn test_glwe_blind_rotation_fft64_ref() {
test_glwe_blind_rotation::<FFT64Ref>() test_glwe_blind_rotation::<FFT64Ref>()
} }
#[test]
fn test_ggsw_blind_rotation_fft64_ref() {
test_ggsw_blind_rotation::<FFT64Ref>()
}
#[test] #[test]
fn test_bdd_prepare_fft64_ref() { fn test_bdd_prepare_fft64_ref() {
test_bdd_prepare::<CGGI, FFT64Ref>() test_bdd_prepare::<CGGI, FFT64Ref>()

View File

@@ -0,0 +1,149 @@
use poulpy_core::{
GGSWEncryptSk, GGSWNoise, GLWEDecrypt, GLWEEncryptSk, SIGMA, ScratchTakeCore,
layouts::{
Base2K, Degree, Dnum, Dsize, GGSW, GGSWLayout, GGSWPreparedFactory, GLWESecret, GLWESecretPrepared,
GLWESecretPreparedFactory, LWEInfos, Rank, TorusPrecision,
},
};
use poulpy_hal::{
api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxRotateInplace},
layouts::{Backend, Module, ScalarZnx, Scratch, ScratchOwned, ZnxView, ZnxViewMut},
source::Source,
};
use rand::RngCore;
use crate::tfhe::bdd_arithmetic::{FheUintBlocksPrepared, GGSWBlindRotation};
pub fn test_ggsw_blind_rotation<BE: Backend>()
where
Module<BE>: ModuleNew<BE>
+ GLWESecretPreparedFactory<BE>
+ GGSWPreparedFactory<BE>
+ GGSWEncryptSk<BE>
+ GGSWBlindRotation<u32, BE>
+ GGSWNoise<BE>
+ GLWEDecrypt<BE>
+ GLWEEncryptSk<BE>
+ VecZnxRotateInplace<BE>,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
let n: Degree = Degree(1 << 11);
let base2k: Base2K = Base2K(13);
let rank: Rank = Rank(1);
let k_ggsw_res: TorusPrecision = TorusPrecision(39);
let k_ggsw_apply: TorusPrecision = TorusPrecision(52);
let ggsw_res_infos: GGSWLayout = GGSWLayout {
n,
base2k,
k: k_ggsw_res,
rank,
dnum: Dnum(2),
dsize: Dsize(1),
};
let ggsw_k_infos: GGSWLayout = GGSWLayout {
n,
base2k,
k: k_ggsw_apply,
rank,
dnum: Dnum(3),
dsize: Dsize(1),
};
let n_glwe: usize = n.into();
let module: Module<BE> = Module::<BE>::new(n_glwe as u64);
let mut source: Source = Source::new([6u8; 32]);
let mut source_xs: Source = Source::new([1u8; 32]);
let mut source_xa: Source = Source::new([2u8; 32]);
let mut source_xe: Source = Source::new([3u8; 32]);
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(1 << 22);
let mut sk_glwe: GLWESecret<Vec<u8>> = GLWESecret::alloc(n, rank);
sk_glwe.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_glwe_prep: GLWESecretPrepared<Vec<u8>, BE> = GLWESecretPrepared::alloc(&module, rank);
sk_glwe_prep.prepare(&module, &sk_glwe);
let mut res: GGSW<Vec<u8>> = GGSW::alloc_from_infos(&ggsw_res_infos);
let mut scalar: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n_glwe, 1);
scalar
.raw_mut()
.iter_mut()
.enumerate()
.for_each(|(i, x)| *x = i as i64);
let k: u32 = source.next_u32();
// println!("k: {k}");
let mut k_enc_prep: FheUintBlocksPrepared<Vec<u8>, u32, BE> =
FheUintBlocksPrepared::<Vec<u8>, u32, BE>::alloc(&module, &ggsw_k_infos);
k_enc_prep.encrypt_sk(
&module,
k,
&sk_glwe_prep,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
);
let base: [usize; 2] = [6, 5];
assert_eq!(base.iter().sum::<usize>(), module.log_n());
// Starting bit
let mut bit_start: usize = 0;
let max_noise = |col_i: usize| {
let mut noise: f64 = -(ggsw_res_infos.size() as f64 * base2k.as_usize() as f64) + SIGMA.log2() + 2.0;
noise += 0.5 * ggsw_res_infos.log_n() as f64;
if col_i != 0 {
noise += 0.5 * ggsw_res_infos.log_n() as f64
}
noise
};
for _ in 0..32_usize.div_ceil(module.log_n()) {
// By how many bits to left shift
let mut bit_step: usize = 0;
for digit in base {
let mask: u32 = (1 << digit) - 1;
// How many bits to take
let bit_size: usize = (32 - bit_start).min(digit);
module.ggsw_blind_rotate_from_scalar(
&mut res,
&scalar,
&k_enc_prep,
bit_start,
bit_size,
bit_step,
scratch.borrow(),
);
let rot: i64 = (((k >> bit_start) & mask) << bit_step) as i64;
let mut scalar_want: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(module.n(), 1);
scalar_want.raw_mut().copy_from_slice(scalar.raw());
module.vec_znx_rotate_inplace(-rot, &mut scalar_want.as_vec_znx_mut(), 0, scratch.borrow());
// res.print_noise(&module, &sk_glwe_prep, &scalar_want);
res.assert_noise(&module, &sk_glwe_prep, &scalar_want, &max_noise);
bit_step += digit;
bit_start += digit;
if bit_start >= 32 {
break;
}
}
}
}

View File

@@ -1,5 +1,6 @@
mod add; mod add;
mod and; mod and;
mod ggsw_blind_rotations;
mod glwe_blind_rotation; mod glwe_blind_rotation;
mod or; mod or;
mod prepare; mod prepare;
@@ -13,6 +14,7 @@ mod xor;
pub use add::*; pub use add::*;
pub use and::*; pub use and::*;
pub use ggsw_blind_rotations::*;
pub use glwe_blind_rotation::*; pub use glwe_blind_rotation::*;
pub use or::*; pub use or::*;
pub use prepare::*; pub use prepare::*;