diff --git a/base2k/src/vec_znx_dft_ops.rs b/base2k/src/vec_znx_dft_ops.rs index 3e5965b..e4d6c33 100644 --- a/base2k/src/vec_znx_dft_ops.rs +++ b/base2k/src/vec_znx_dft_ops.rs @@ -148,9 +148,6 @@ impl VecZnxDftOps for Module { ); }); } - (min_size..res_mut.size()).for_each(|j| { - res_mut.zero_at(res_col, j); - }) } fn vec_znx_dft_copy(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) diff --git a/core/src/automorphism.rs b/core/src/automorphism.rs index 8b4fe3a..8dca7ec 100644 --- a/core/src/automorphism.rs +++ b/core/src/automorphism.rs @@ -352,12 +352,12 @@ where pub fn keyswitch_inplace( &mut self, module: &Module, - rhs: &AutomorphismKey, + rhs: &GLWESwitchingKey, scratch: &mut base2k::Scratch, ) where MatZnxDft: MatZnxDftToRef, { - self.key.keyswitch_inplace(module, &rhs.key, scratch); + self.key.keyswitch_inplace(module, &rhs, scratch); } pub fn external_product( diff --git a/core/src/glwe_ciphertext.rs b/core/src/glwe_ciphertext.rs index ca94db1..e319d21 100644 --- a/core/src/glwe_ciphertext.rs +++ b/core/src/glwe_ciphertext.rs @@ -141,11 +141,9 @@ impl GLWECiphertext> { let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, in_rank, out_rank + 1, ksk_size) + module.bytes_of_vec_znx_dft(in_rank, in_size); - let a0_big: usize = module.bytes_of_vec_znx_big(1, in_size) + module.vec_znx_idft_tmp_bytes(); - let norm: usize = module.vec_znx_big_normalize_tmp_bytes(); - res_dft + (vmp | a0_big | norm) + res_dft + (vmp | norm) } pub fn keyswitch_inplace_scratch_space(module: &Module, out_size: usize, out_rank: usize, ksk_size: usize) -> usize { @@ -362,15 +360,10 @@ where module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2); } - // Switches result of VMP outside of DFT - let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft); + module.vec_znx_dft_add_inplace(&mut res_dft, 0, lhs, 0); - { - // Switches lhs 0-th outside of DFT domain and adds on - let (mut a0_big, scratch2) = scratch1.tmp_vec_znx_big(module, 1, lhs.size()); - module.vec_znx_idft(&mut a0_big, 0, lhs, 0, scratch2); - module.vec_znx_big_add_inplace(&mut res_big, 0, &a0_big, 0); - } + // Switches result of VMP outside of DFT + let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft); (0..cols_out).for_each(|i| { module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1);