mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
updated key-switch for rank switching & updated glwe key-switching test
This commit is contained in:
@@ -82,27 +82,34 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
|
||||
+ module.bytes_of_vec_znx_dft(rank + 1, size)
|
||||
}
|
||||
|
||||
pub fn keyswitch_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize) -> usize {
|
||||
pub fn keyswitch_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
res_size: usize,
|
||||
lhs: usize,
|
||||
rhs: usize,
|
||||
rank_in: usize,
|
||||
rank_out: usize,
|
||||
) -> usize {
|
||||
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space(
|
||||
module, res_size, lhs, rhs,
|
||||
module, res_size, lhs, rhs, rank_in, rank_out,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize) -> usize {
|
||||
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
|
||||
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_inplace_scratch_space(
|
||||
module, res_size, rhs,
|
||||
module, res_size, rhs, rank,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn external_product_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize) -> usize {
|
||||
pub fn external_product_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize {
|
||||
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space(
|
||||
module, res_size, lhs, rhs,
|
||||
module, res_size, lhs, rhs, rank, rank,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn external_product_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize) -> usize {
|
||||
pub fn external_product_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
|
||||
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space(
|
||||
module, res_size, rhs,
|
||||
module, res_size, rhs, rank,
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -265,9 +272,24 @@ where
|
||||
}
|
||||
|
||||
impl VecGLWEProductScratchSpace for GGSWCiphertext<Vec<u8>, FFT64> {
|
||||
fn prod_with_glwe_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, rgsw_size: usize) -> usize {
|
||||
module.bytes_of_vec_znx_dft(2, rgsw_size)
|
||||
+ ((module.bytes_of_vec_znx_dft(2, a_size) + module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 2, 2, rgsw_size))
|
||||
fn prod_with_glwe_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
rgsw_size: usize,
|
||||
rank_in: usize,
|
||||
rank_out: usize,
|
||||
) -> usize {
|
||||
module.bytes_of_vec_znx_dft(rank_out + 1, rgsw_size)
|
||||
+ ((module.bytes_of_vec_znx_dft(rank_in + 1, a_size)
|
||||
+ module.vmp_apply_tmp_bytes(
|
||||
res_size,
|
||||
a_size,
|
||||
a_size,
|
||||
rank_in + 1,
|
||||
rank_out + 1,
|
||||
rgsw_size,
|
||||
))
|
||||
| module.vec_znx_big_normalize_tmp_bytes())
|
||||
}
|
||||
}
|
||||
@@ -290,6 +312,8 @@ where
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(self.rank(), a.rank());
|
||||
assert_eq!(self.rank(), res.rank());
|
||||
assert_eq!(res.basek(), log_base2k);
|
||||
assert_eq!(a.basek(), log_base2k);
|
||||
assert_eq!(self.n(), module.n());
|
||||
@@ -297,18 +321,22 @@ where
|
||||
assert_eq!(a.n(), module.n());
|
||||
}
|
||||
|
||||
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, 2, self.size()); // Todo optimise
|
||||
let cols: usize = self.rank() + 1;
|
||||
|
||||
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, self.size()); // Todo optimise
|
||||
|
||||
{
|
||||
let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, 2, a.size());
|
||||
module.vec_znx_dft(&mut a_dft, 0, a, 0);
|
||||
module.vec_znx_dft(&mut a_dft, 1, a, 1);
|
||||
let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, a.size());
|
||||
(0..cols).for_each(|col_i| {
|
||||
module.vec_znx_dft(&mut a_dft, col_i, a, col_i);
|
||||
});
|
||||
module.vmp_apply(&mut res_dft, &a_dft, self, scratch2);
|
||||
}
|
||||
|
||||
let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft);
|
||||
|
||||
module.vec_znx_big_normalize(log_base2k, res, 0, &res_big, 0, scratch1);
|
||||
module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1);
|
||||
(0..cols).for_each(|i| {
|
||||
module.vec_znx_big_normalize(log_base2k, res, i, &res_big, i, scratch1);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user