updated key-switch for rank switching & updated glwe key-switching test

This commit is contained in:
Jean-Philippe Bossuat
2025-05-14 16:34:52 +02:00
parent cb1928802a
commit f517a730a3
10 changed files with 1806 additions and 1676 deletions

View File

@@ -82,27 +82,34 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
+ module.bytes_of_vec_znx_dft(rank + 1, size)
}
pub fn keyswitch_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize) -> usize {
pub fn keyswitch_scratch_space(
module: &Module<FFT64>,
res_size: usize,
lhs: usize,
rhs: usize,
rank_in: usize,
rank_out: usize,
) -> usize {
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space(
module, res_size, lhs, rhs,
module, res_size, lhs, rhs, rank_in, rank_out,
)
}
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize) -> usize {
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_inplace_scratch_space(
module, res_size, rhs,
module, res_size, rhs, rank,
)
}
pub fn external_product_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize) -> usize {
pub fn external_product_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize {
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space(
module, res_size, lhs, rhs,
module, res_size, lhs, rhs, rank, rank,
)
}
pub fn external_product_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize) -> usize {
pub fn external_product_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space(
module, res_size, rhs,
module, res_size, rhs, rank,
)
}
}
@@ -265,9 +272,24 @@ where
}
impl VecGLWEProductScratchSpace for GGSWCiphertext<Vec<u8>, FFT64> {
fn prod_with_glwe_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, rgsw_size: usize) -> usize {
module.bytes_of_vec_znx_dft(2, rgsw_size)
+ ((module.bytes_of_vec_znx_dft(2, a_size) + module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 2, 2, rgsw_size))
fn prod_with_glwe_scratch_space(
module: &Module<FFT64>,
res_size: usize,
a_size: usize,
rgsw_size: usize,
rank_in: usize,
rank_out: usize,
) -> usize {
module.bytes_of_vec_znx_dft(rank_out + 1, rgsw_size)
+ ((module.bytes_of_vec_znx_dft(rank_in + 1, a_size)
+ module.vmp_apply_tmp_bytes(
res_size,
a_size,
a_size,
rank_in + 1,
rank_out + 1,
rgsw_size,
))
| module.vec_znx_big_normalize_tmp_bytes())
}
}
@@ -290,6 +312,8 @@ where
#[cfg(debug_assertions)]
{
assert_eq!(self.rank(), a.rank());
assert_eq!(self.rank(), res.rank());
assert_eq!(res.basek(), log_base2k);
assert_eq!(a.basek(), log_base2k);
assert_eq!(self.n(), module.n());
@@ -297,18 +321,22 @@ where
assert_eq!(a.n(), module.n());
}
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, 2, self.size()); // Todo optimise
let cols: usize = self.rank() + 1;
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, self.size()); // Todo optimise
{
let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, 2, a.size());
module.vec_znx_dft(&mut a_dft, 0, a, 0);
module.vec_znx_dft(&mut a_dft, 1, a, 1);
let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, a.size());
(0..cols).for_each(|col_i| {
module.vec_znx_dft(&mut a_dft, col_i, a, col_i);
});
module.vmp_apply(&mut res_dft, &a_dft, self, scratch2);
}
let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft);
module.vec_znx_big_normalize(log_base2k, res, 0, &res_big, 0, scratch1);
module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1);
(0..cols).for_each(|i| {
module.vec_znx_big_normalize(log_base2k, res, i, &res_big, i, scratch1);
});
}
}