From 992cb3fa37825fbb2ee8441c2338e6133305d566 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 8 Jul 2025 13:23:38 +0200 Subject: [PATCH] Added missing tests for CGGI & added standard blind rotation --- core/src/blind_rotation/ccgi.rs | 97 +++++++++++++++++++++- core/src/blind_rotation/key.rs | 5 ++ core/src/blind_rotation/test_fft64/cggi.rs | 64 +++++++------- 3 files changed, 130 insertions(+), 36 deletions(-) diff --git a/core/src/blind_rotation/ccgi.rs b/core/src/blind_rotation/ccgi.rs index 684ef72..2b5e877 100644 --- a/core/src/blind_rotation/ccgi.rs +++ b/core/src/blind_rotation/ccgi.rs @@ -6,7 +6,7 @@ use backend::{ use itertools::izip; use crate::{ - GLWECiphertext, GLWECiphertextToMut, Infos, LWECiphertext, + GLWECiphertext, GLWECiphertextToMut, GLWEOps, Infos, LWECiphertext, ScratchCore, blind_rotation::{key::BlindRotationKeyCGGI, lut::LookUpTable}, lwe::ciphertext::LWECiphertextToRef, }; @@ -63,7 +63,7 @@ pub fn cggi_blind_rotate( } else if brk.block_size() > 1 { cggi_blind_rotate_block_binary(module, res, lwe, lut, brk, scratch); } else { - todo!("implement this case") + cggi_blind_rotate_standard(module, res, lwe, lut, brk, scratch); } } @@ -121,8 +121,7 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended( a.chunks_exact(block_size), brk.data.chunks_exact(block_size) ) - .enumerate() - .for_each(|(i, (ai, ski))| { + .for_each(|(ai, ski)| { (0..extension_factor).for_each(|i| { (0..cols).for_each(|j| { module.vec_znx_dft(1, 0, &mut acc_dft[i], j, &acc[i], j); @@ -323,6 +322,96 @@ pub(crate) fn cggi_blind_rotate_block_binary( }); } +pub(crate) fn cggi_blind_rotate_standard( + module: &Module, + res: &mut GLWECiphertext, + lwe: &LWECiphertext, + lut: &LookUpTable, + brk: &BlindRotationKeyCGGI, + scratch: &mut Scratch, +) where + DataRes: AsRef<[u8]> + AsMut<[u8]>, + DataIn: AsRef<[u8]>, +{ + #[cfg(debug_assertions)] + { + assert_eq!( + res.n(), + module.n(), + "res.n(): {} != brk.n(): {}", + res.n(), + module.n() + ); + assert_eq!( + lut.domain_size(), + module.n(), + "lut.n(): {} != brk.n(): {}", + lut.domain_size(), + module.n() + ); + assert_eq!( + brk.n(), + module.n(), + "brk.n(): {} != brk.n(): {}", + brk.n(), + module.n() + ); + assert_eq!( + res.rank(), + brk.rank(), + "res.rank(): {} != brk.rank(): {}", + res.rank(), + brk.rank() + ); + assert_eq!( + lwe.n(), + brk.data.len(), + "lwe.n(): {} != brk.data.len(): {}", + lwe.n(), + brk.data.len() + ); + } + + let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space + let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut(); + let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref(); + let basek: usize = brk.basek(); + + negate_and_mod_switch_2n(2 * lut.domain_size(), &mut lwe_2n, &lwe_ref); + + let a: &[i64] = &lwe_2n[1..]; + let b: i64 = lwe_2n[0]; + + out_mut.data.zero(); + + // Initialize out to X^{b} * LUT(X) + module.vec_znx_rotate(b, &mut out_mut.data, 0, &lut.data[0], 0); + + // ACC + [sum DFT(X^ai -1) * (DFT(ACC) x BRKi)] + let (mut acc_tmp, scratch1) = scratch.tmp_glwe_ct(module, basek, out_mut.k(), out_mut.rank()); + let (mut acc_tmp_rot, scratch2) = scratch1.tmp_glwe_ct(module, basek, out_mut.k(), out_mut.rank()); + + // TODO: see if faster by skipping normalization in external product and keeping acc in big coeffs + // TODO: first iteration can be optimized to be a gglwe product + izip!(a.iter(), brk.data.iter()).for_each(|(ai, ski)| { + // acc_tmp = sk[i] * acc + acc_tmp.external_product(module, &out_mut, ski, scratch2); + + // acc_tmp = (sk[i] * acc) * X^{ai} + acc_tmp_rot.rotate(module, *ai, &acc_tmp); + + // acc = acc + (sk[i] * acc) * X^{ai} + out_mut.add_inplace(module, &acc_tmp_rot); + + // acc = acc + (sk[i] * acc) * X^{ai} - (sk[i] * acc) = acc + (sk[i] * acc) * (X^{ai} - 1) + out_mut.sub_inplace_ab(module, &acc_tmp); + }); + + // We can normalize only at the end because we add normalized values in [-2^{basek-1}, 2^{basek-1}] + // on top of each others, thus ~ 2^{63-basek} additions are supported before overflow. + out_mut.normalize_inplace(module, scratch2); +} + pub(crate) fn negate_and_mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]>) { let basek: usize = lwe.basek(); diff --git a/core/src/blind_rotation/key.rs b/core/src/blind_rotation/key.rs index b7f9c3f..b83d60c 100644 --- a/core/src/blind_rotation/key.rs +++ b/core/src/blind_rotation/key.rs @@ -74,6 +74,11 @@ impl BlindRotationKeyCGGI { } } + #[allow(dead_code)] + pub(crate) fn n(&self) -> usize { + self.data[0].n() + } + #[allow(dead_code)] pub(crate) fn rows(&self) -> usize { self.data[0].rows() diff --git a/core/src/blind_rotation/test_fft64/cggi.rs b/core/src/blind_rotation/test_fft64/cggi.rs index 4a5c319..785246e 100644 --- a/core/src/blind_rotation/test_fft64/cggi.rs +++ b/core/src/blind_rotation/test_fft64/cggi.rs @@ -1,6 +1,4 @@ -use std::time::Instant; - -use backend::{Encoding, FFT64, Module, ScratchOwned, Stats, VecZnxOps, ZnxView}; +use backend::{Encoding, FFT64, Module, ScratchOwned, ZnxView}; use sampling::source::Source; use crate::{ @@ -14,22 +12,31 @@ use crate::{ }; #[test] -fn blind_rotation() { - let module: Module = Module::::new(2048); - let basek: usize = 19; +fn standard() { + blind_rotatio_test(224, 1, 1); +} - let n_lwe: usize = 1071; +#[test] +fn block_binary() { + blind_rotatio_test(224, 7, 1); +} + +#[test] +fn block_binary_extended() { + blind_rotatio_test(224, 7, 2); +} + +fn blind_rotatio_test(n_lwe: usize, block_size: usize, extension_factor: usize) { + let module: Module = Module::::new(512); + let basek: usize = 19; let k_lwe: usize = 24; let k_brk: usize = 3 * basek; - let rows_brk: usize = 1; + let rows_brk: usize = 2; // Ensures first limb is noise-free. let k_lut: usize = 2 * basek; let rank: usize = 1; - let block_size: usize = 7; - let extension_factor: usize = 2; - - let message_modulus: usize = 1 << 6; + let message_modulus: usize = 1 << 4; let mut source_xs: Source = Source::new([1u8; 32]); let mut source_xe: Source = Source::new([1u8; 32]); @@ -56,7 +63,6 @@ fn blind_rotation() { rank, )); - let start: Instant = Instant::now(); let mut brk: BlindRotationKeyCGGI = BlindRotationKeyCGGI::allocate(&module, n_lwe, basek, k_brk, rows_brk, rank); brk.generate_from_sk( @@ -69,9 +75,6 @@ fn blind_rotation() { scratch.borrow(), ); - let duration: std::time::Duration = start.elapsed(); - println!("brk-gen: {} ms", duration.as_millis()); - let mut lwe: LWECiphertext> = LWECiphertext::alloc(n_lwe, basek, k_lwe); let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc(basek, k_lwe); @@ -81,13 +84,13 @@ fn blind_rotation() { pt_lwe.data.encode_coeff_i64(0, basek, bits, 0, x, bits); - println!("{}", pt_lwe.data); + // println!("{}", pt_lwe.data); lwe.encrypt_sk(&pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe, 3.2); lwe.decrypt(&mut pt_lwe, &sk_lwe); - println!("{}", pt_lwe.data); + // println!("{}", pt_lwe.data); let mut f: Vec = vec![0i64; message_modulus]; f.iter_mut() @@ -99,13 +102,9 @@ fn blind_rotation() { let mut res: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_lut, rank); - let start: Instant = Instant::now(); - (0..32).for_each(|_| { - cggi_blind_rotate(&module, &mut res, &lwe, &lut, &brk, scratch_br.borrow()); - }); + cggi_blind_rotate(&module, &mut res, &lwe, &lut, &brk, scratch_br.borrow()); - let duration: std::time::Duration = start.elapsed(); - println!("blind-rotate: {} ms", duration.as_millis()); + println!("out_mut.data: {}", res.data); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_lut); @@ -125,20 +124,21 @@ fn blind_rotation() { .sum::()) % (2 * lut.domain_size()) as i64; - println!("pt_want: {}", pt_want); + // println!("pt_want: {}", pt_want); lut.rotate(pt_want); - lut.data.iter().for_each(|d| { - println!("{}", d); - }); + // lut.data.iter().for_each(|d| { + // println!("{}", d); + // }); // First limb should be exactly equal (test are parameterized such that the noise does not reach // the first limb) - // assert_eq!(pt_have.data.at_mut(0, 0), lut.data[0].at_mut(0, 0)); + assert_eq!(pt_have.data.at(0, 0), lut.data[0].at(0, 0)); // Then checks the noise - module.vec_znx_sub_ab_inplace(&mut lut.data[0], 0, &pt_have.data, 0); - let noise: f64 = lut.data[0].std(0, basek); - println!("noise: {}", noise); + // module.vec_znx_sub_ab_inplace(&mut lut.data[0], 0, &pt_have.data, 0); + // let noise: f64 = lut.data[0].std(0, basek); + // println!("noise: {}", noise); + // assert!(noise < 1e-3); }