diff --git a/rlwe/src/encryption.rs b/rlwe/src/encryption.rs index ca4b837..f8b11c9 100644 --- a/rlwe/src/encryption.rs +++ b/rlwe/src/encryption.rs @@ -48,24 +48,29 @@ where let mut ct_mut: VecZnx<&mut [u8]> = ct.data_mut().to_mut(); let size: usize = ct_mut.size(); + // c1 = a ct_mut.fill_uniform(log_base2k, 1, size, source_xa); - // c1_dft = DFT(a) * DFT(s) - let (mut c1_dft, scratch_1) = scratch.tmp_vec_znx_dft(module, 1, size); - module.svp_apply(&mut c1_dft, 0, &sk.data().to_ref(), 0, &ct_mut, 1); + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size); - // c1_big = IDFT(c1_dft) - let (mut c1_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, size); - module.vec_znx_idft_tmp_a(&mut c1_big, 0, &mut c1_dft, 0); + { + let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size); - // c1_big = m - c1_big - if let Some(pt) = pt { - module.vec_znx_big_sub_small_b_inplace(&mut c1_big, 0, &pt.data().to_ref(), 0); + // c0_dft = DFT(a) * DFT(s) + module.svp_apply(&mut c0_dft, 0, &sk.data().to_ref(), 0, &ct_mut, 1); + + // c0_big = IDFT(c0_dft) + module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); } - // c1_big += e - c1_big.add_normal(log_base2k, 0, log_q, source_xe, sigma, bound); - // c0 = norm(c1_big) - module.vec_znx_big_normalize(log_base2k, &mut ct_mut, 0, &c1_big, 0, scratch_2); + // c0_big = m - c0_big + if let Some(pt) = pt { + module.vec_znx_big_sub_small_b_inplace(&mut c0_big, 0, &pt.data().to_ref(), 0); + } + // c0_big += e + c0_big.add_normal(log_base2k, 0, log_q, source_xe, sigma, bound); + + // c0 = norm(c0_big = -as + m + e) + module.vec_znx_big_normalize(log_base2k, &mut ct_mut, 0, &c0_big, 0, scratch_1); } }