Added missing tests for CGGI & added standard blind rotation

This commit is contained in:
Jean-Philippe Bossuat
2025-07-08 13:23:38 +02:00
parent 5234c3fc63
commit 992cb3fa37
3 changed files with 130 additions and 36 deletions

View File

@@ -6,7 +6,7 @@ use backend::{
use itertools::izip; use itertools::izip;
use crate::{ use crate::{
GLWECiphertext, GLWECiphertextToMut, Infos, LWECiphertext, GLWECiphertext, GLWECiphertextToMut, GLWEOps, Infos, LWECiphertext, ScratchCore,
blind_rotation::{key::BlindRotationKeyCGGI, lut::LookUpTable}, blind_rotation::{key::BlindRotationKeyCGGI, lut::LookUpTable},
lwe::ciphertext::LWECiphertextToRef, lwe::ciphertext::LWECiphertextToRef,
}; };
@@ -63,7 +63,7 @@ pub fn cggi_blind_rotate<DataRes, DataIn>(
} else if brk.block_size() > 1 { } else if brk.block_size() > 1 {
cggi_blind_rotate_block_binary(module, res, lwe, lut, brk, scratch); cggi_blind_rotate_block_binary(module, res, lwe, lut, brk, scratch);
} else { } else {
todo!("implement this case") cggi_blind_rotate_standard(module, res, lwe, lut, brk, scratch);
} }
} }
@@ -121,8 +121,7 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn>(
a.chunks_exact(block_size), a.chunks_exact(block_size),
brk.data.chunks_exact(block_size) brk.data.chunks_exact(block_size)
) )
.enumerate() .for_each(|(ai, ski)| {
.for_each(|(i, (ai, ski))| {
(0..extension_factor).for_each(|i| { (0..extension_factor).for_each(|i| {
(0..cols).for_each(|j| { (0..cols).for_each(|j| {
module.vec_znx_dft(1, 0, &mut acc_dft[i], j, &acc[i], j); module.vec_znx_dft(1, 0, &mut acc_dft[i], j, &acc[i], j);
@@ -323,6 +322,96 @@ pub(crate) fn cggi_blind_rotate_block_binary<DataRes, DataIn>(
}); });
} }
pub(crate) fn cggi_blind_rotate_standard<DataRes, DataIn>(
module: &Module<FFT64>,
res: &mut GLWECiphertext<DataRes>,
lwe: &LWECiphertext<DataIn>,
lut: &LookUpTable,
brk: &BlindRotationKeyCGGI<FFT64>,
scratch: &mut Scratch,
) where
DataRes: AsRef<[u8]> + AsMut<[u8]>,
DataIn: AsRef<[u8]>,
{
#[cfg(debug_assertions)]
{
assert_eq!(
res.n(),
module.n(),
"res.n(): {} != brk.n(): {}",
res.n(),
module.n()
);
assert_eq!(
lut.domain_size(),
module.n(),
"lut.n(): {} != brk.n(): {}",
lut.domain_size(),
module.n()
);
assert_eq!(
brk.n(),
module.n(),
"brk.n(): {} != brk.n(): {}",
brk.n(),
module.n()
);
assert_eq!(
res.rank(),
brk.rank(),
"res.rank(): {} != brk.rank(): {}",
res.rank(),
brk.rank()
);
assert_eq!(
lwe.n(),
brk.data.len(),
"lwe.n(): {} != brk.data.len(): {}",
lwe.n(),
brk.data.len()
);
}
let mut lwe_2n: Vec<i64> = vec![0i64; lwe.n() + 1]; // TODO: from scratch space
let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut();
let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref();
let basek: usize = brk.basek();
negate_and_mod_switch_2n(2 * lut.domain_size(), &mut lwe_2n, &lwe_ref);
let a: &[i64] = &lwe_2n[1..];
let b: i64 = lwe_2n[0];
out_mut.data.zero();
// Initialize out to X^{b} * LUT(X)
module.vec_znx_rotate(b, &mut out_mut.data, 0, &lut.data[0], 0);
// ACC + [sum DFT(X^ai -1) * (DFT(ACC) x BRKi)]
let (mut acc_tmp, scratch1) = scratch.tmp_glwe_ct(module, basek, out_mut.k(), out_mut.rank());
let (mut acc_tmp_rot, scratch2) = scratch1.tmp_glwe_ct(module, basek, out_mut.k(), out_mut.rank());
// TODO: see if faster by skipping normalization in external product and keeping acc in big coeffs
// TODO: first iteration can be optimized to be a gglwe product
izip!(a.iter(), brk.data.iter()).for_each(|(ai, ski)| {
// acc_tmp = sk[i] * acc
acc_tmp.external_product(module, &out_mut, ski, scratch2);
// acc_tmp = (sk[i] * acc) * X^{ai}
acc_tmp_rot.rotate(module, *ai, &acc_tmp);
// acc = acc + (sk[i] * acc) * X^{ai}
out_mut.add_inplace(module, &acc_tmp_rot);
// acc = acc + (sk[i] * acc) * X^{ai} - (sk[i] * acc) = acc + (sk[i] * acc) * (X^{ai} - 1)
out_mut.sub_inplace_ab(module, &acc_tmp);
});
// We can normalize only at the end because we add normalized values in [-2^{basek-1}, 2^{basek-1}]
// on top of each others, thus ~ 2^{63-basek} additions are supported before overflow.
out_mut.normalize_inplace(module, scratch2);
}
pub(crate) fn negate_and_mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]>) { pub(crate) fn negate_and_mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]>) {
let basek: usize = lwe.basek(); let basek: usize = lwe.basek();

View File

@@ -74,6 +74,11 @@ impl BlindRotationKeyCGGI<FFT64> {
} }
} }
#[allow(dead_code)]
pub(crate) fn n(&self) -> usize {
self.data[0].n()
}
#[allow(dead_code)] #[allow(dead_code)]
pub(crate) fn rows(&self) -> usize { pub(crate) fn rows(&self) -> usize {
self.data[0].rows() self.data[0].rows()

View File

@@ -1,6 +1,4 @@
use std::time::Instant; use backend::{Encoding, FFT64, Module, ScratchOwned, ZnxView};
use backend::{Encoding, FFT64, Module, ScratchOwned, Stats, VecZnxOps, ZnxView};
use sampling::source::Source; use sampling::source::Source;
use crate::{ use crate::{
@@ -14,22 +12,31 @@ use crate::{
}; };
#[test] #[test]
fn blind_rotation() { fn standard() {
let module: Module<FFT64> = Module::<FFT64>::new(2048); blind_rotatio_test(224, 1, 1);
let basek: usize = 19; }
let n_lwe: usize = 1071; #[test]
fn block_binary() {
blind_rotatio_test(224, 7, 1);
}
#[test]
fn block_binary_extended() {
blind_rotatio_test(224, 7, 2);
}
fn blind_rotatio_test(n_lwe: usize, block_size: usize, extension_factor: usize) {
let module: Module<FFT64> = Module::<FFT64>::new(512);
let basek: usize = 19;
let k_lwe: usize = 24; let k_lwe: usize = 24;
let k_brk: usize = 3 * basek; let k_brk: usize = 3 * basek;
let rows_brk: usize = 1; let rows_brk: usize = 2; // Ensures first limb is noise-free.
let k_lut: usize = 2 * basek; let k_lut: usize = 2 * basek;
let rank: usize = 1; let rank: usize = 1;
let block_size: usize = 7;
let extension_factor: usize = 2; let message_modulus: usize = 1 << 4;
let message_modulus: usize = 1 << 6;
let mut source_xs: Source = Source::new([1u8; 32]); let mut source_xs: Source = Source::new([1u8; 32]);
let mut source_xe: Source = Source::new([1u8; 32]); let mut source_xe: Source = Source::new([1u8; 32]);
@@ -56,7 +63,6 @@ fn blind_rotation() {
rank, rank,
)); ));
let start: Instant = Instant::now();
let mut brk: BlindRotationKeyCGGI<FFT64> = BlindRotationKeyCGGI::allocate(&module, n_lwe, basek, k_brk, rows_brk, rank); let mut brk: BlindRotationKeyCGGI<FFT64> = BlindRotationKeyCGGI::allocate(&module, n_lwe, basek, k_brk, rows_brk, rank);
brk.generate_from_sk( brk.generate_from_sk(
@@ -69,9 +75,6 @@ fn blind_rotation() {
scratch.borrow(), scratch.borrow(),
); );
let duration: std::time::Duration = start.elapsed();
println!("brk-gen: {} ms", duration.as_millis());
let mut lwe: LWECiphertext<Vec<u8>> = LWECiphertext::alloc(n_lwe, basek, k_lwe); let mut lwe: LWECiphertext<Vec<u8>> = LWECiphertext::alloc(n_lwe, basek, k_lwe);
let mut pt_lwe: LWEPlaintext<Vec<u8>> = LWEPlaintext::alloc(basek, k_lwe); let mut pt_lwe: LWEPlaintext<Vec<u8>> = LWEPlaintext::alloc(basek, k_lwe);
@@ -81,13 +84,13 @@ fn blind_rotation() {
pt_lwe.data.encode_coeff_i64(0, basek, bits, 0, x, bits); pt_lwe.data.encode_coeff_i64(0, basek, bits, 0, x, bits);
println!("{}", pt_lwe.data); // println!("{}", pt_lwe.data);
lwe.encrypt_sk(&pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe, 3.2); lwe.encrypt_sk(&pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe, 3.2);
lwe.decrypt(&mut pt_lwe, &sk_lwe); lwe.decrypt(&mut pt_lwe, &sk_lwe);
println!("{}", pt_lwe.data); // println!("{}", pt_lwe.data);
let mut f: Vec<i64> = vec![0i64; message_modulus]; let mut f: Vec<i64> = vec![0i64; message_modulus];
f.iter_mut() f.iter_mut()
@@ -99,13 +102,9 @@ fn blind_rotation() {
let mut res: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(&module, basek, k_lut, rank); let mut res: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(&module, basek, k_lut, rank);
let start: Instant = Instant::now(); cggi_blind_rotate(&module, &mut res, &lwe, &lut, &brk, scratch_br.borrow());
(0..32).for_each(|_| {
cggi_blind_rotate(&module, &mut res, &lwe, &lut, &brk, scratch_br.borrow());
});
let duration: std::time::Duration = start.elapsed(); println!("out_mut.data: {}", res.data);
println!("blind-rotate: {} ms", duration.as_millis());
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, k_lut); let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, k_lut);
@@ -125,20 +124,21 @@ fn blind_rotation() {
.sum::<i64>()) .sum::<i64>())
% (2 * lut.domain_size()) as i64; % (2 * lut.domain_size()) as i64;
println!("pt_want: {}", pt_want); // println!("pt_want: {}", pt_want);
lut.rotate(pt_want); lut.rotate(pt_want);
lut.data.iter().for_each(|d| { // lut.data.iter().for_each(|d| {
println!("{}", d); // println!("{}", d);
}); // });
// First limb should be exactly equal (test are parameterized such that the noise does not reach // First limb should be exactly equal (test are parameterized such that the noise does not reach
// the first limb) // the first limb)
// assert_eq!(pt_have.data.at_mut(0, 0), lut.data[0].at_mut(0, 0)); assert_eq!(pt_have.data.at(0, 0), lut.data[0].at(0, 0));
// Then checks the noise // Then checks the noise
module.vec_znx_sub_ab_inplace(&mut lut.data[0], 0, &pt_have.data, 0); // module.vec_znx_sub_ab_inplace(&mut lut.data[0], 0, &pt_have.data, 0);
let noise: f64 = lut.data[0].std(0, basek); // let noise: f64 = lut.data[0].std(0, basek);
println!("noise: {}", noise); // println!("noise: {}", noise);
// assert!(noise < 1e-3);
} }