fixed standard binary cggi blind rotation & fixed GLWECiphertext::external_product_scratch_space returning too small values

This commit is contained in:
Jean-Philippe Bossuat
2025-07-08 13:37:35 +02:00
parent 992cb3fa37
commit f7c94cd84a
3 changed files with 50 additions and 33 deletions

View File

@@ -8,21 +8,24 @@ use itertools::izip;
use crate::{ use crate::{
GLWECiphertext, GLWECiphertextToMut, GLWEOps, Infos, LWECiphertext, ScratchCore, GLWECiphertext, GLWECiphertextToMut, GLWEOps, Infos, LWECiphertext, ScratchCore,
blind_rotation::{key::BlindRotationKeyCGGI, lut::LookUpTable}, blind_rotation::{key::BlindRotationKeyCGGI, lut::LookUpTable},
dist::Distribution,
lwe::ciphertext::LWECiphertextToRef, lwe::ciphertext::LWECiphertextToRef,
}; };
pub fn cggi_blind_rotate_scratch_space( pub fn cggi_blind_rotate_scratch_space(
module: &Module<FFT64>, module: &Module<FFT64>,
block_size: usize,
extension_factor: usize, extension_factor: usize,
basek: usize, basek: usize,
k_lut: usize, k_res: usize,
k_brk: usize, k_brk: usize,
rows: usize, rows: usize,
rank: usize, rank: usize,
) -> usize { ) -> usize {
let cols: usize = rank + 1;
let brk_size: usize = k_brk.div_ceil(basek); let brk_size: usize = k_brk.div_ceil(basek);
if block_size > 1 {
let cols: usize = rank + 1;
let acc_dft: usize = module.bytes_of_vec_znx_dft(cols, rows) * extension_factor; let acc_dft: usize = module.bytes_of_vec_znx_dft(cols, rows) * extension_factor;
let acc_big: usize = module.bytes_of_vec_znx_big(1, brk_size); let acc_big: usize = module.bytes_of_vec_znx_big(1, brk_size);
let vmp_res: usize = module.bytes_of_vec_znx_dft(cols, brk_size) * extension_factor; let vmp_res: usize = module.bytes_of_vec_znx_dft(cols, brk_size) * extension_factor;
@@ -33,7 +36,7 @@ pub fn cggi_blind_rotate_scratch_space(
let acc: usize; let acc: usize;
if extension_factor > 1 { if extension_factor > 1 {
acc = module.bytes_of_vec_znx(cols, k_lut.div_ceil(basek)) * extension_factor; acc = module.bytes_of_vec_znx(cols, k_res.div_ceil(basek)) * extension_factor;
} else { } else {
acc = 0; acc = 0;
} }
@@ -45,6 +48,10 @@ pub fn cggi_blind_rotate_scratch_space(
+ xai_plus_y + xai_plus_y
+ xai_plus_y_dft + xai_plus_y_dft
+ (vmp | (acc_big + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()))); + (vmp | (acc_big + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes())));
} else {
2 * GLWECiphertext::bytes_of(module, basek, k_res, rank)
+ GLWECiphertext::external_product_scratch_space(module, basek, k_res, k_res, k_brk, 1, rank)
}
} }
pub fn cggi_blind_rotate<DataRes, DataIn>( pub fn cggi_blind_rotate<DataRes, DataIn>(
@@ -58,12 +65,20 @@ pub fn cggi_blind_rotate<DataRes, DataIn>(
DataRes: AsRef<[u8]> + AsMut<[u8]>, DataRes: AsRef<[u8]> + AsMut<[u8]>,
DataIn: AsRef<[u8]>, DataIn: AsRef<[u8]>,
{ {
match brk.dist {
Distribution::BinaryBlock(_) | Distribution::BinaryFixed(_) | Distribution::BinaryProb(_) | Distribution::ZERO => {
if lut.extension_factor() > 1 { if lut.extension_factor() > 1 {
cggi_blind_rotate_block_binary_extended(module, res, lwe, lut, brk, scratch); cggi_blind_rotate_block_binary_extended(module, res, lwe, lut, brk, scratch);
} else if brk.block_size() > 1 { } else if brk.block_size() > 1 {
cggi_blind_rotate_block_binary(module, res, lwe, lut, brk, scratch); cggi_blind_rotate_block_binary(module, res, lwe, lut, brk, scratch);
} else { } else {
cggi_blind_rotate_standard(module, res, lwe, lut, brk, scratch); cggi_blind_rotate_binary_standard(module, res, lwe, lut, brk, scratch);
}
}
// TODO: ternary distribution ?
_ => panic!(
"invalid BlindRotationKeyCGGI distribution: must be BinaryBlock, BinaryFixed or BinaryProb (or ZERO for debugging)"
),
} }
} }
@@ -322,7 +337,7 @@ pub(crate) fn cggi_blind_rotate_block_binary<DataRes, DataIn>(
}); });
} }
pub(crate) fn cggi_blind_rotate_standard<DataRes, DataIn>( pub(crate) fn cggi_blind_rotate_binary_standard<DataRes, DataIn>(
module: &Module<FFT64>, module: &Module<FFT64>,
res: &mut GLWECiphertext<DataRes>, res: &mut GLWECiphertext<DataRes>,
lwe: &LWECiphertext<DataIn>, lwe: &LWECiphertext<DataIn>,

View File

@@ -33,7 +33,8 @@ fn blind_rotatio_test(n_lwe: usize, block_size: usize, extension_factor: usize)
let k_lwe: usize = 24; let k_lwe: usize = 24;
let k_brk: usize = 3 * basek; let k_brk: usize = 3 * basek;
let rows_brk: usize = 2; // Ensures first limb is noise-free. let rows_brk: usize = 2; // Ensures first limb is noise-free.
let k_lut: usize = 2 * basek; let k_lut: usize = 1 * basek;
let k_res: usize = 2 * basek;
let rank: usize = 1; let rank: usize = 1;
let message_modulus: usize = 1 << 4; let message_modulus: usize = 1 << 4;
@@ -55,9 +56,10 @@ fn blind_rotatio_test(n_lwe: usize, block_size: usize, extension_factor: usize)
let mut scratch_br: ScratchOwned = ScratchOwned::new(cggi_blind_rotate_scratch_space( let mut scratch_br: ScratchOwned = ScratchOwned::new(cggi_blind_rotate_scratch_space(
&module, &module,
block_size,
extension_factor, extension_factor,
basek, basek,
k_lut, k_res,
k_brk, k_brk,
rows_brk, rows_brk,
rank, rank,
@@ -100,13 +102,13 @@ fn blind_rotatio_test(n_lwe: usize, block_size: usize, extension_factor: usize)
let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, extension_factor); let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, extension_factor);
lut.set(&module, &f, message_modulus); lut.set(&module, &f, message_modulus);
let mut res: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(&module, basek, k_lut, rank); let mut res: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(&module, basek, k_res, rank);
cggi_blind_rotate(&module, &mut res, &lwe, &lut, &brk, scratch_br.borrow()); cggi_blind_rotate(&module, &mut res, &lwe, &lut, &brk, scratch_br.borrow());
println!("out_mut.data: {}", res.data); println!("out_mut.data: {}", res.data);
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, k_lut); let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, k_res);
res.decrypt(&module, &mut pt_have, &sk_glwe_dft, scratch.borrow()); res.decrypt(&module, &mut pt_have, &sk_glwe_dft, scratch.borrow());

View File

@@ -14,7 +14,7 @@ impl GLWECiphertext<Vec<u8>> {
digits: usize, digits: usize,
rank: usize, rank: usize,
) -> usize { ) -> usize {
let res_dft: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank); let res_dft: usize = FourierGLWECiphertext::bytes_of(module, basek, k_ggsw, rank);
let in_size: usize = k_in.div_ceil(basek).div_ceil(digits); let in_size: usize = k_in.div_ceil(basek).div_ceil(digits);
let out_size: usize = k_out.div_ceil(basek); let out_size: usize = k_out.div_ceil(basek);
let ggsw_size: usize = k_ggsw.div_ceil(basek); let ggsw_size: usize = k_ggsw.div_ceil(basek);