Add Zn type

This commit is contained in:
Pro7ech
2025-08-21 12:16:53 +02:00
parent ccd94e36cc
commit bf513dc555
129 changed files with 1400 additions and 686 deletions

View File

@@ -23,7 +23,6 @@ use crate::tfhe::blind_rotation::{
#[allow(clippy::too_many_arguments)]
pub fn cggi_blind_rotate_scratch_space<B: Backend>(
module: &Module<B>,
n: usize,
block_size: usize,
extension_factor: usize,
basek: usize,
@@ -44,14 +43,14 @@ where
if block_size > 1 {
let cols: usize = rank + 1;
let acc_dft: usize = module.vec_znx_dft_alloc_bytes(n, cols, rows) * extension_factor;
let acc_big: usize = module.vec_znx_big_alloc_bytes(n, 1, brk_size);
let vmp_res: usize = module.vec_znx_dft_alloc_bytes(n, cols, brk_size) * extension_factor;
let vmp_xai: usize = module.vec_znx_dft_alloc_bytes(n, 1, brk_size);
let acc_dft: usize = module.vec_znx_dft_alloc_bytes(cols, rows) * extension_factor;
let acc_big: usize = module.vec_znx_big_alloc_bytes(1, brk_size);
let vmp_res: usize = module.vec_znx_dft_alloc_bytes(cols, brk_size) * extension_factor;
let vmp_xai: usize = module.vec_znx_dft_alloc_bytes(1, brk_size);
let acc_dft_add: usize = vmp_res;
let vmp: usize = module.vmp_apply_tmp_bytes(n, brk_size, rows, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2)
let vmp: usize = module.vmp_apply_tmp_bytes(brk_size, rows, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2)
let acc: usize = if extension_factor > 1 {
VecZnx::alloc_bytes(n, cols, k_res.div_ceil(basek)) * extension_factor
VecZnx::alloc_bytes(module.n(), cols, k_res.div_ceil(basek)) * extension_factor
} else {
0
};
@@ -60,10 +59,10 @@ where
+ acc_dft_add
+ vmp_res
+ vmp_xai
+ (vmp | (acc_big + (module.vec_znx_big_normalize_tmp_bytes(n) | module.vec_znx_dft_to_vec_znx_big_tmp_bytes(n))))
+ (vmp | (acc_big + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_dft_to_vec_znx_big_tmp_bytes())))
} else {
GLWECiphertext::bytes_of(n, basek, k_res, rank)
+ GLWECiphertext::external_product_scratch_space(module, n, basek, k_res, k_res, k_brk, 1, rank)
GLWECiphertext::bytes_of(module.n(), basek, k_res, rank)
+ GLWECiphertext::external_product_scratch_space(module, basek, k_res, k_res, k_brk, 1, rank)
}
}