diff --git a/core/src/ggsw_ciphertext.rs b/core/src/ggsw_ciphertext.rs index 67f4774..577bd6e 100644 --- a/core/src/ggsw_ciphertext.rs +++ b/core/src/ggsw_ciphertext.rs @@ -149,6 +149,8 @@ where }; (0..self.rows()).for_each(|row_j| { + vec_znx_pt.data.zero(); + // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_j, pt, 0); module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt, 0, scrach_2); @@ -177,8 +179,6 @@ where module.vmp_prepare_row(self, row_j, col_i, &vec_znx_dft_ct); } }); - - vec_znx_pt.data.zero(); // zeroes for next iteration }); } diff --git a/core/src/glwe_ciphertext.rs b/core/src/glwe_ciphertext.rs index 245ee26..0fd6242 100644 --- a/core/src/glwe_ciphertext.rs +++ b/core/src/glwe_ciphertext.rs @@ -121,19 +121,33 @@ impl GLWECiphertext> { ksk_size: usize, ) -> usize { let res_dft: usize = module.bytes_of_vec_znx_dft(out_rank + 1, ksk_size); - let vmp: usize = module.vmp_apply_tmp_bytes( - out_size, - in_size, - in_size, - in_rank + 1, - out_rank + 1, - ksk_size, - ) + module.bytes_of_vec_znx_dft(in_rank, in_size); + let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, in_rank, out_rank + 1, ksk_size) + + module.bytes_of_vec_znx_dft(in_rank, in_size); let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); return res_dft + (vmp | normalize); } + pub fn keyswitch_from_fourier_scratch_space( + module: &Module, + out_size: usize, + out_rank: usize, + in_size: usize, + in_rank: usize, + ksk_size: usize, + ) -> usize { + let res_dft = module.bytes_of_vec_znx_dft(out_rank + 1, ksk_size); + + let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, in_rank, out_rank + 1, ksk_size) + + module.bytes_of_vec_znx_dft(in_rank, in_size); + + let a0_big: usize = module.bytes_of_vec_znx_big(1, in_size) + module.vec_znx_idft_tmp_bytes(); + + let norm: usize = module.vec_znx_big_normalize_tmp_bytes(); + + res_dft + (vmp | a0_big | norm) + } + pub fn keyswitch_inplace_scratch_space(module: &Module, out_size: usize, out_rank: usize, ksk_size: usize) -> usize { GLWECiphertext::keyswitch_scratch_space(module, out_size, out_rank, out_size, out_rank, ksk_size) } @@ -322,7 +336,7 @@ where assert_eq!(lhs.n(), module.n()); assert!( scratch.available() - >= GLWECiphertextFourier::keyswitch_scratch_space( + >= GLWECiphertext::keyswitch_from_fourier_scratch_space( module, self.size(), self.rank(), diff --git a/core/src/glwe_ciphertext_fourier.rs b/core/src/glwe_ciphertext_fourier.rs index 4c22507..135a2dd 100644 --- a/core/src/glwe_ciphertext_fourier.rs +++ b/core/src/glwe_ciphertext_fourier.rs @@ -91,22 +91,8 @@ impl GLWECiphertextFourier, FFT64> { in_rank: usize, ksk_size: usize, ) -> usize { - let res_dft: usize = module.bytes_of_vec_znx_dft(out_rank + 1, out_size); - - let vmp = module.bytes_of_vec_znx_dft(in_rank, in_size) - + module.vmp_apply_tmp_bytes( - out_size, - in_size, - in_size, - in_rank + 1, - out_rank + 1, - ksk_size, - ); - let res_small: usize = module.bytes_of_vec_znx(out_rank + 1, out_size); - let add_a0: usize = module.bytes_of_vec_znx_big(1, in_size) + module.vec_znx_idft_tmp_bytes(); - let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); - - res_dft + (vmp | add_a0 | (res_small + normalize)) + module.bytes_of_vec_znx(out_rank + 1, out_size) + + GLWECiphertext::keyswitch_from_fourier_scratch_space(module, out_size, out_rank, in_size, in_rank, ksk_size) } pub fn keyswitch_inplace_scratch_space(module: &Module, out_size: usize, out_rank: usize, ksk_size: usize) -> usize { @@ -181,11 +167,11 @@ where let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { data: res_idft_data, - basek: self.basek, - k: self.k, + basek: lhs.basek, + k: lhs.k, }; - res_idft.keyswitch_from_fourier(module, self, rhs, scratch1); + res_idft.keyswitch_from_fourier(module, lhs, rhs, scratch1); (0..cols_out).for_each(|i| { module.vec_znx_dft(self, i, &res_idft, i);