diff --git a/core/src/blind_rotation/ccgi.rs b/core/src/blind_rotation/ccgi.rs index 2b5e877..0d4f6fd 100644 --- a/core/src/blind_rotation/ccgi.rs +++ b/core/src/blind_rotation/ccgi.rs @@ -8,43 +8,50 @@ use itertools::izip; use crate::{ GLWECiphertext, GLWECiphertextToMut, GLWEOps, Infos, LWECiphertext, ScratchCore, blind_rotation::{key::BlindRotationKeyCGGI, lut::LookUpTable}, + dist::Distribution, lwe::ciphertext::LWECiphertextToRef, }; pub fn cggi_blind_rotate_scratch_space( module: &Module, + block_size: usize, extension_factor: usize, basek: usize, - k_lut: usize, + k_res: usize, k_brk: usize, rows: usize, rank: usize, ) -> usize { - let cols: usize = rank + 1; let brk_size: usize = k_brk.div_ceil(basek); - let acc_dft: usize = module.bytes_of_vec_znx_dft(cols, rows) * extension_factor; - let acc_big: usize = module.bytes_of_vec_znx_big(1, brk_size); - let vmp_res: usize = module.bytes_of_vec_znx_dft(cols, brk_size) * extension_factor; - let acc_dft_add: usize = vmp_res; - let xai_plus_y: usize = module.bytes_of_scalar_znx(1); - let xai_plus_y_dft: usize = module.bytes_of_scalar_znx_dft(1); - let vmp: usize = module.vmp_apply_tmp_bytes(brk_size, rows, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2) + if block_size > 1 { + let cols: usize = rank + 1; + let acc_dft: usize = module.bytes_of_vec_znx_dft(cols, rows) * extension_factor; + let acc_big: usize = module.bytes_of_vec_znx_big(1, brk_size); + let vmp_res: usize = module.bytes_of_vec_znx_dft(cols, brk_size) * extension_factor; + let acc_dft_add: usize = vmp_res; + let xai_plus_y: usize = module.bytes_of_scalar_znx(1); + let xai_plus_y_dft: usize = module.bytes_of_scalar_znx_dft(1); + let vmp: usize = module.vmp_apply_tmp_bytes(brk_size, rows, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2) - let acc: usize; - if extension_factor > 1 { - acc = module.bytes_of_vec_znx(cols, k_lut.div_ceil(basek)) * extension_factor; + let acc: usize; + if extension_factor > 1 { + acc = module.bytes_of_vec_znx(cols, k_res.div_ceil(basek)) * extension_factor; + } else { + acc = 0; + } + + return acc + + acc_dft + + acc_dft_add + + vmp_res + + xai_plus_y + + xai_plus_y_dft + + (vmp | (acc_big + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()))); } else { - acc = 0; + 2 * GLWECiphertext::bytes_of(module, basek, k_res, rank) + + GLWECiphertext::external_product_scratch_space(module, basek, k_res, k_res, k_brk, 1, rank) } - - return acc - + acc_dft - + acc_dft_add - + vmp_res - + xai_plus_y - + xai_plus_y_dft - + (vmp | (acc_big + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()))); } pub fn cggi_blind_rotate( @@ -58,12 +65,20 @@ pub fn cggi_blind_rotate( DataRes: AsRef<[u8]> + AsMut<[u8]>, DataIn: AsRef<[u8]>, { - if lut.extension_factor() > 1 { - cggi_blind_rotate_block_binary_extended(module, res, lwe, lut, brk, scratch); - } else if brk.block_size() > 1 { - cggi_blind_rotate_block_binary(module, res, lwe, lut, brk, scratch); - } else { - cggi_blind_rotate_standard(module, res, lwe, lut, brk, scratch); + match brk.dist { + Distribution::BinaryBlock(_) | Distribution::BinaryFixed(_) | Distribution::BinaryProb(_) | Distribution::ZERO => { + if lut.extension_factor() > 1 { + cggi_blind_rotate_block_binary_extended(module, res, lwe, lut, brk, scratch); + } else if brk.block_size() > 1 { + cggi_blind_rotate_block_binary(module, res, lwe, lut, brk, scratch); + } else { + cggi_blind_rotate_binary_standard(module, res, lwe, lut, brk, scratch); + } + } + // TODO: ternary distribution ? + _ => panic!( + "invalid BlindRotationKeyCGGI distribution: must be BinaryBlock, BinaryFixed or BinaryProb (or ZERO for debugging)" + ), } } @@ -322,7 +337,7 @@ pub(crate) fn cggi_blind_rotate_block_binary( }); } -pub(crate) fn cggi_blind_rotate_standard( +pub(crate) fn cggi_blind_rotate_binary_standard( module: &Module, res: &mut GLWECiphertext, lwe: &LWECiphertext, diff --git a/core/src/blind_rotation/test_fft64/cggi.rs b/core/src/blind_rotation/test_fft64/cggi.rs index 785246e..ea8291a 100644 --- a/core/src/blind_rotation/test_fft64/cggi.rs +++ b/core/src/blind_rotation/test_fft64/cggi.rs @@ -33,7 +33,8 @@ fn blind_rotatio_test(n_lwe: usize, block_size: usize, extension_factor: usize) let k_lwe: usize = 24; let k_brk: usize = 3 * basek; let rows_brk: usize = 2; // Ensures first limb is noise-free. - let k_lut: usize = 2 * basek; + let k_lut: usize = 1 * basek; + let k_res: usize = 2 * basek; let rank: usize = 1; let message_modulus: usize = 1 << 4; @@ -55,9 +56,10 @@ fn blind_rotatio_test(n_lwe: usize, block_size: usize, extension_factor: usize) let mut scratch_br: ScratchOwned = ScratchOwned::new(cggi_blind_rotate_scratch_space( &module, + block_size, extension_factor, basek, - k_lut, + k_res, k_brk, rows_brk, rank, @@ -100,13 +102,13 @@ fn blind_rotatio_test(n_lwe: usize, block_size: usize, extension_factor: usize) let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, extension_factor); lut.set(&module, &f, message_modulus); - let mut res: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_lut, rank); + let mut res: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_res, rank); cggi_blind_rotate(&module, &mut res, &lwe, &lut, &brk, scratch_br.borrow()); println!("out_mut.data: {}", res.data); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_lut); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_res); res.decrypt(&module, &mut pt_have, &sk_glwe_dft, scratch.borrow()); diff --git a/core/src/glwe/external_product.rs b/core/src/glwe/external_product.rs index 3ebf339..e7ee778 100644 --- a/core/src/glwe/external_product.rs +++ b/core/src/glwe/external_product.rs @@ -14,7 +14,7 @@ impl GLWECiphertext> { digits: usize, rank: usize, ) -> usize { - let res_dft: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank); + let res_dft: usize = FourierGLWECiphertext::bytes_of(module, basek, k_ggsw, rank); let in_size: usize = k_in.div_ceil(basek).div_ceil(digits); let out_size: usize = k_out.div_ceil(basek); let ggsw_size: usize = k_ggsw.div_ceil(basek);