bug fixes

This commit is contained in:
Jean-Philippe Bossuat
2025-05-26 19:16:43 +02:00
parent e5d6a6f828
commit dec3481a6f
3 changed files with 14 additions and 11 deletions

View File

@@ -142,14 +142,13 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
tensor_key_size: usize,
rank: usize,
) -> usize {
GGSWCiphertext::keyswitch_scratch_space(
module,
out_size,
in_size,
auto_key_size,
tensor_key_size,
rank,
)
let cols: usize = rank + 1;
let res: usize = module.bytes_of_vec_znx(cols, out_size);
let res_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size);
let ci_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size);
let ks_internal: usize = GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, out_size, in_size, auto_key_size, rank);
let expand: usize = GGSWCiphertext::expand_row_scratch_space(module, out_size, tensor_key_size, rank);
res + ci_dft + (ks_internal | expand | res_dft)
}
pub fn automorphism_inplace_scratch_space(
@@ -288,6 +287,8 @@ where
{
let cols: usize = self.rank() + 1;
assert!(scratch.available() >= GGSWCiphertext::expand_row_scratch_space(module, self.size(), tsk.size(), tsk.rank()));
// Example for rank 3:
//
// Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is
@@ -383,7 +384,7 @@ where
k: self.k(),
};
let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, self.size());
let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, lhs.size());
// Keyswitch the j-th row of the col 0
(0..lhs.rows()).for_each(|row_i| {
@@ -467,6 +468,7 @@ where
self.rank(),
tensor_key.rank()
);
assert!(scratch.available() >= GGSWCiphertext::automorphism_scratch_space(module, self.size(), lhs.size(), auto_key.size(), tensor_key.size(), self.rank()))
};
let cols: usize = self.rank() + 1;
@@ -646,6 +648,7 @@ where
{
assert_eq!(self.rank(), ksk.rank());
assert_eq!(res.rank(), ksk.rank());
assert!(scratch.available() >= GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, res.size(), self.size(), ksk.size(), ksk.rank()))
}
let (tmp_dft_in_data, scratch2) = scratch.tmp_vec_znx_dft(module, self.rank() + 1, self.size());