fixed standard binary cggi blind rotation & fixed GLWECiphertext::external_product_scratch_space returning too small values

This commit is contained in:
Jean-Philippe Bossuat
2025-07-08 13:37:35 +02:00
parent 992cb3fa37
commit f7c94cd84a
3 changed files with 50 additions and 33 deletions

View File

@@ -8,43 +8,50 @@ use itertools::izip;
use crate::{
GLWECiphertext, GLWECiphertextToMut, GLWEOps, Infos, LWECiphertext, ScratchCore,
blind_rotation::{key::BlindRotationKeyCGGI, lut::LookUpTable},
dist::Distribution,
lwe::ciphertext::LWECiphertextToRef,
};
pub fn cggi_blind_rotate_scratch_space(
module: &Module<FFT64>,
block_size: usize,
extension_factor: usize,
basek: usize,
k_lut: usize,
k_res: usize,
k_brk: usize,
rows: usize,
rank: usize,
) -> usize {
let cols: usize = rank + 1;
let brk_size: usize = k_brk.div_ceil(basek);
let acc_dft: usize = module.bytes_of_vec_znx_dft(cols, rows) * extension_factor;
let acc_big: usize = module.bytes_of_vec_znx_big(1, brk_size);
let vmp_res: usize = module.bytes_of_vec_znx_dft(cols, brk_size) * extension_factor;
let acc_dft_add: usize = vmp_res;
let xai_plus_y: usize = module.bytes_of_scalar_znx(1);
let xai_plus_y_dft: usize = module.bytes_of_scalar_znx_dft(1);
let vmp: usize = module.vmp_apply_tmp_bytes(brk_size, rows, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2)
if block_size > 1 {
let cols: usize = rank + 1;
let acc_dft: usize = module.bytes_of_vec_znx_dft(cols, rows) * extension_factor;
let acc_big: usize = module.bytes_of_vec_znx_big(1, brk_size);
let vmp_res: usize = module.bytes_of_vec_znx_dft(cols, brk_size) * extension_factor;
let acc_dft_add: usize = vmp_res;
let xai_plus_y: usize = module.bytes_of_scalar_znx(1);
let xai_plus_y_dft: usize = module.bytes_of_scalar_znx_dft(1);
let vmp: usize = module.vmp_apply_tmp_bytes(brk_size, rows, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2)
let acc: usize;
if extension_factor > 1 {
acc = module.bytes_of_vec_znx(cols, k_lut.div_ceil(basek)) * extension_factor;
let acc: usize;
if extension_factor > 1 {
acc = module.bytes_of_vec_znx(cols, k_res.div_ceil(basek)) * extension_factor;
} else {
acc = 0;
}
return acc
+ acc_dft
+ acc_dft_add
+ vmp_res
+ xai_plus_y
+ xai_plus_y_dft
+ (vmp | (acc_big + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes())));
} else {
acc = 0;
2 * GLWECiphertext::bytes_of(module, basek, k_res, rank)
+ GLWECiphertext::external_product_scratch_space(module, basek, k_res, k_res, k_brk, 1, rank)
}
return acc
+ acc_dft
+ acc_dft_add
+ vmp_res
+ xai_plus_y
+ xai_plus_y_dft
+ (vmp | (acc_big + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes())));
}
pub fn cggi_blind_rotate<DataRes, DataIn>(
@@ -58,12 +65,20 @@ pub fn cggi_blind_rotate<DataRes, DataIn>(
DataRes: AsRef<[u8]> + AsMut<[u8]>,
DataIn: AsRef<[u8]>,
{
if lut.extension_factor() > 1 {
cggi_blind_rotate_block_binary_extended(module, res, lwe, lut, brk, scratch);
} else if brk.block_size() > 1 {
cggi_blind_rotate_block_binary(module, res, lwe, lut, brk, scratch);
} else {
cggi_blind_rotate_standard(module, res, lwe, lut, brk, scratch);
match brk.dist {
Distribution::BinaryBlock(_) | Distribution::BinaryFixed(_) | Distribution::BinaryProb(_) | Distribution::ZERO => {
if lut.extension_factor() > 1 {
cggi_blind_rotate_block_binary_extended(module, res, lwe, lut, brk, scratch);
} else if brk.block_size() > 1 {
cggi_blind_rotate_block_binary(module, res, lwe, lut, brk, scratch);
} else {
cggi_blind_rotate_binary_standard(module, res, lwe, lut, brk, scratch);
}
}
// TODO: ternary distribution ?
_ => panic!(
"invalid BlindRotationKeyCGGI distribution: must be BinaryBlock, BinaryFixed or BinaryProb (or ZERO for debugging)"
),
}
}
@@ -322,7 +337,7 @@ pub(crate) fn cggi_blind_rotate_block_binary<DataRes, DataIn>(
});
}
pub(crate) fn cggi_blind_rotate_standard<DataRes, DataIn>(
pub(crate) fn cggi_blind_rotate_binary_standard<DataRes, DataIn>(
module: &Module<FFT64>,
res: &mut GLWECiphertext<DataRes>,
lwe: &LWECiphertext<DataIn>,

View File

@@ -33,7 +33,8 @@ fn blind_rotatio_test(n_lwe: usize, block_size: usize, extension_factor: usize)
let k_lwe: usize = 24;
let k_brk: usize = 3 * basek;
let rows_brk: usize = 2; // Ensures first limb is noise-free.
let k_lut: usize = 2 * basek;
let k_lut: usize = 1 * basek;
let k_res: usize = 2 * basek;
let rank: usize = 1;
let message_modulus: usize = 1 << 4;
@@ -55,9 +56,10 @@ fn blind_rotatio_test(n_lwe: usize, block_size: usize, extension_factor: usize)
let mut scratch_br: ScratchOwned = ScratchOwned::new(cggi_blind_rotate_scratch_space(
&module,
block_size,
extension_factor,
basek,
k_lut,
k_res,
k_brk,
rows_brk,
rank,
@@ -100,13 +102,13 @@ fn blind_rotatio_test(n_lwe: usize, block_size: usize, extension_factor: usize)
let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, extension_factor);
lut.set(&module, &f, message_modulus);
let mut res: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(&module, basek, k_lut, rank);
let mut res: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(&module, basek, k_res, rank);
cggi_blind_rotate(&module, &mut res, &lwe, &lut, &brk, scratch_br.borrow());
println!("out_mut.data: {}", res.data);
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, k_lut);
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, k_res);
res.decrypt(&module, &mut pt_have, &sk_glwe_dft, scratch.borrow());

View File

@@ -14,7 +14,7 @@ impl GLWECiphertext<Vec<u8>> {
digits: usize,
rank: usize,
) -> usize {
let res_dft: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank);
let res_dft: usize = FourierGLWECiphertext::bytes_of(module, basek, k_ggsw, rank);
let in_size: usize = k_in.div_ceil(basek).div_ceil(digits);
let out_size: usize = k_out.div_ceil(basek);
let ggsw_size: usize = k_ggsw.div_ceil(basek);