This commit is contained in:
Pro7ech
2025-10-14 18:46:25 +02:00
parent 0533cdff8a
commit 72dca47cbe
153 changed files with 3099 additions and 1956 deletions

View File

@@ -13,7 +13,7 @@ use poulpy_hal::{
use poulpy_core::{
Distribution, GLWEOperations, TakeGLWECt,
layouts::{GGSWInfos, GLWECiphertext, GLWECiphertextToMut, GLWEInfos, LWECiphertext, LWECiphertextToRef, LWEInfos},
layouts::{GGSWInfos, GLWE, GLWEInfos, GLWEToMut, LWE, LWECiphertextToRef, LWEInfos},
};
use crate::tfhe::blind_rotation::{
@@ -43,14 +43,14 @@ where
if block_size > 1 {
let cols: usize = (brk_infos.rank() + 1).into();
let dnum: usize = brk_infos.dnum().into();
let acc_dft: usize = module.vec_znx_dft_alloc_bytes(cols, dnum) * extension_factor;
let acc_big: usize = module.vec_znx_big_alloc_bytes(1, brk_size);
let vmp_res: usize = module.vec_znx_dft_alloc_bytes(cols, brk_size) * extension_factor;
let vmp_xai: usize = module.vec_znx_dft_alloc_bytes(1, brk_size);
let acc_dft: usize = module.vec_znx_dft_bytes_of(cols, dnum) * extension_factor;
let acc_big: usize = module.vec_znx_big_bytes_of(1, brk_size);
let vmp_res: usize = module.vec_znx_dft_bytes_of(cols, brk_size) * extension_factor;
let vmp_xai: usize = module.vec_znx_dft_bytes_of(1, brk_size);
let acc_dft_add: usize = vmp_res;
let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes(brk_size, dnum, dnum, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2)
let acc: usize = if extension_factor > 1 {
VecZnx::alloc_bytes(module.n(), cols, glwe_infos.size()) * extension_factor
VecZnx::bytes_of(module.n(), cols, glwe_infos.size()) * extension_factor
} else {
0
};
@@ -61,8 +61,7 @@ where
+ vmp_xai
+ (vmp | (acc_big + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_apply_tmp_bytes())))
} else {
GLWECiphertext::alloc_bytes(glwe_infos)
+ GLWECiphertext::external_product_inplace_scratch_space(module, glwe_infos, brk_infos)
GLWE::bytes_of(glwe_infos) + GLWE::external_product_inplace_scratch_space(module, glwe_infos, brk_infos)
}
}
@@ -99,8 +98,8 @@ where
fn execute<DR: DataMut, DI: DataRef>(
&self,
module: &Module<B>,
res: &mut GLWECiphertext<DR>,
lwe: &LWECiphertext<DI>,
res: &mut GLWE<DR>,
lwe: &LWE<DI>,
lut: &LookUpTable,
scratch: &mut Scratch<B>,
) {
@@ -121,8 +120,8 @@ where
fn execute_block_binary_extended<DataRes, DataIn, DataBrk, B: Backend>(
module: &Module<B>,
res: &mut GLWECiphertext<DataRes>,
lwe: &LWECiphertext<DataIn>,
res: &mut GLWE<DataRes>,
lwe: &LWE<DataIn>,
lut: &LookUpTable,
brk: &BlindRotationKeyPrepared<DataBrk, CGGI, B>,
scratch: &mut Scratch<B>,
@@ -179,7 +178,7 @@ fn execute_block_binary_extended<DataRes, DataIn, DataBrk, B: Backend>(
}
let mut lwe_2n: Vec<i64> = vec![0i64; (lwe.n() + 1).as_usize()]; // TODO: from scratch space
let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref();
let lwe_ref: LWE<&[u8]> = lwe.to_ref();
let two_n: usize = 2 * n_glwe;
let two_n_ext: usize = 2 * lut.domain_size();
@@ -288,8 +287,8 @@ fn execute_block_binary_extended<DataRes, DataIn, DataBrk, B: Backend>(
fn execute_block_binary<DataRes, DataIn, DataBrk, B: Backend>(
module: &Module<B>,
res: &mut GLWECiphertext<DataRes>,
lwe: &LWECiphertext<DataIn>,
res: &mut GLWE<DataRes>,
lwe: &LWE<DataIn>,
lut: &LookUpTable,
brk: &BlindRotationKeyPrepared<DataBrk, CGGI, B>,
scratch: &mut Scratch<B>,
@@ -324,8 +323,8 @@ fn execute_block_binary<DataRes, DataIn, DataBrk, B: Backend>(
{
let n_glwe: usize = brk.n_glwe().into();
let mut lwe_2n: Vec<i64> = vec![0i64; (lwe.n() + 1).into()]; // TODO: from scratch space
let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut();
let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref();
let mut out_mut: GLWE<&mut [u8]> = res.to_mut();
let lwe_ref: LWE<&[u8]> = lwe.to_ref();
let two_n: usize = n_glwe << 1;
let base2k: usize = brk.base2k().into();
let dnum: usize = brk.dnum().into();
@@ -410,8 +409,8 @@ fn execute_block_binary<DataRes, DataIn, DataBrk, B: Backend>(
fn execute_standard<DataRes, DataIn, DataBrk, B: Backend>(
module: &Module<B>,
res: &mut GLWECiphertext<DataRes>,
lwe: &LWECiphertext<DataIn>,
res: &mut GLWE<DataRes>,
lwe: &LWE<DataIn>,
lut: &LookUpTable,
brk: &BlindRotationKeyPrepared<DataBrk, CGGI, B>,
scratch: &mut Scratch<B>,
@@ -480,8 +479,8 @@ fn execute_standard<DataRes, DataIn, DataBrk, B: Backend>(
}
let mut lwe_2n: Vec<i64> = vec![0i64; (lwe.n() + 1).into()]; // TODO: from scratch space
let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut();
let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref();
let mut out_mut: GLWE<&mut [u8]> = res.to_mut();
let lwe_ref: LWE<&[u8]> = lwe.to_ref();
mod_switch_2n(
2 * lut.domain_size(),
@@ -519,7 +518,7 @@ fn execute_standard<DataRes, DataIn, DataBrk, B: Backend>(
out_mut.normalize_inplace(module, scratch_1);
}
pub fn mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]>, rot_dir: LookUpTableRotationDirection) {
pub fn mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWE<&[u8]>, rot_dir: LookUpTableRotationDirection) {
let base2k: usize = lwe.base2k().into();
let log2n: usize = usize::BITS as usize - (n - 1).leading_zeros() as usize + 1;