mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
bug fixes
This commit is contained in:
@@ -142,14 +142,13 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
|
||||
tensor_key_size: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
GGSWCiphertext::keyswitch_scratch_space(
|
||||
module,
|
||||
out_size,
|
||||
in_size,
|
||||
auto_key_size,
|
||||
tensor_key_size,
|
||||
rank,
|
||||
)
|
||||
let cols: usize = rank + 1;
|
||||
let res: usize = module.bytes_of_vec_znx(cols, out_size);
|
||||
let res_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size);
|
||||
let ci_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size);
|
||||
let ks_internal: usize = GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, out_size, in_size, auto_key_size, rank);
|
||||
let expand: usize = GGSWCiphertext::expand_row_scratch_space(module, out_size, tensor_key_size, rank);
|
||||
res + ci_dft + (ks_internal | expand | res_dft)
|
||||
}
|
||||
|
||||
pub fn automorphism_inplace_scratch_space(
|
||||
@@ -288,6 +287,8 @@ where
|
||||
{
|
||||
let cols: usize = self.rank() + 1;
|
||||
|
||||
assert!(scratch.available() >= GGSWCiphertext::expand_row_scratch_space(module, self.size(), tsk.size(), tsk.rank()));
|
||||
|
||||
// Example for rank 3:
|
||||
//
|
||||
// Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is
|
||||
@@ -383,7 +384,7 @@ where
|
||||
k: self.k(),
|
||||
};
|
||||
|
||||
let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, self.size());
|
||||
let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, lhs.size());
|
||||
|
||||
// Keyswitch the j-th row of the col 0
|
||||
(0..lhs.rows()).for_each(|row_i| {
|
||||
@@ -467,6 +468,7 @@ where
|
||||
self.rank(),
|
||||
tensor_key.rank()
|
||||
);
|
||||
assert!(scratch.available() >= GGSWCiphertext::automorphism_scratch_space(module, self.size(), lhs.size(), auto_key.size(), tensor_key.size(), self.rank()))
|
||||
};
|
||||
|
||||
let cols: usize = self.rank() + 1;
|
||||
@@ -646,6 +648,7 @@ where
|
||||
{
|
||||
assert_eq!(self.rank(), ksk.rank());
|
||||
assert_eq!(res.rank(), ksk.rank());
|
||||
assert!(scratch.available() >= GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, res.size(), self.size(), ksk.size(), ksk.rank()))
|
||||
}
|
||||
|
||||
let (tmp_dft_in_data, scratch2) = scratch.tmp_vec_znx_dft(module, self.rank() + 1, self.size());
|
||||
|
||||
@@ -17,7 +17,7 @@ pub struct TensorKey<C, B: Backend> {
|
||||
impl TensorKey<Vec<u8>, FFT64> {
|
||||
pub fn alloc(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, rank: usize) -> Self {
|
||||
let mut keys: Vec<GLWESwitchingKey<Vec<u8>, FFT64>> = Vec::new();
|
||||
let pairs: usize = ((rank + 1) * rank) >> 1;
|
||||
let pairs: usize = (((rank + 1) * rank) >> 1).max(1);
|
||||
(0..pairs).for_each(|_| {
|
||||
keys.push(GLWESwitchingKey::alloc(module, basek, k, rows, 1, rank));
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user