diff --git a/backend/src/encoding.rs b/backend/src/encoding.rs index 48501b6..73b86a3 100644 --- a/backend/src/encoding.rs +++ b/backend/src/encoding.rs @@ -267,7 +267,7 @@ fn decode_coeff_i64>(a: &VecZnx, col_i: usize, basek: usize, k let data: &[i64] = a.raw(); let mut res: i64 = data[i]; let rem: usize = basek - (k % basek); - let slice_size: usize = a.n() * a.size(); + let slice_size: usize = a.n() * a.cols(); (1..size).for_each(|i| { let x: i64 = data[i * slice_size]; if i == size - 1 && rem != basek { diff --git a/core/src/ggsw_ciphertext.rs b/core/src/ggsw_ciphertext.rs index ca22faf..7215adf 100644 --- a/core/src/ggsw_ciphertext.rs +++ b/core/src/ggsw_ciphertext.rs @@ -142,14 +142,13 @@ impl GGSWCiphertext, FFT64> { tensor_key_size: usize, rank: usize, ) -> usize { - GGSWCiphertext::keyswitch_scratch_space( - module, - out_size, - in_size, - auto_key_size, - tensor_key_size, - rank, - ) + let cols: usize = rank + 1; + let res: usize = module.bytes_of_vec_znx(cols, out_size); + let res_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size); + let ci_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size); + let ks_internal: usize = GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, out_size, in_size, auto_key_size, rank); + let expand: usize = GGSWCiphertext::expand_row_scratch_space(module, out_size, tensor_key_size, rank); + res + ci_dft + (ks_internal | expand | res_dft) } pub fn automorphism_inplace_scratch_space( @@ -288,6 +287,8 @@ where { let cols: usize = self.rank() + 1; + assert!(scratch.available() >= GGSWCiphertext::expand_row_scratch_space(module, self.size(), tsk.size(), tsk.rank())); + // Example for rank 3: // // Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is @@ -383,7 +384,7 @@ where k: self.k(), }; - let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, self.size()); + let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, lhs.size()); // Keyswitch the j-th row of the col 0 (0..lhs.rows()).for_each(|row_i| { @@ -467,6 +468,7 @@ where self.rank(), tensor_key.rank() ); + assert!(scratch.available() >= GGSWCiphertext::automorphism_scratch_space(module, self.size(), lhs.size(), auto_key.size(), tensor_key.size(), self.rank())) }; let cols: usize = self.rank() + 1; @@ -646,6 +648,7 @@ where { assert_eq!(self.rank(), ksk.rank()); assert_eq!(res.rank(), ksk.rank()); + assert!(scratch.available() >= GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, res.size(), self.size(), ksk.size(), ksk.rank())) } let (tmp_dft_in_data, scratch2) = scratch.tmp_vec_znx_dft(module, self.rank() + 1, self.size()); diff --git a/core/src/tensor_key.rs b/core/src/tensor_key.rs index 105dace..89ce299 100644 --- a/core/src/tensor_key.rs +++ b/core/src/tensor_key.rs @@ -17,7 +17,7 @@ pub struct TensorKey { impl TensorKey, FFT64> { pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, rank: usize) -> Self { let mut keys: Vec, FFT64>> = Vec::new(); - let pairs: usize = ((rank + 1) * rank) >> 1; + let pairs: usize = (((rank + 1) * rank) >> 1).max(1); (0..pairs).for_each(|_| { keys.push(GLWESwitchingKey::alloc(module, basek, k, rows, 1, rank)); });