From 3634ab774643ac8f48fed88b0cd277e3baa95294 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Mon, 24 Feb 2025 10:17:08 +0100 Subject: [PATCH] finalized raw gadget product test with noise equations --- rlwe/src/gadget_product.rs | 149 +++++++++++++++++++------------------ 1 file changed, 78 insertions(+), 71 deletions(-) diff --git a/rlwe/src/gadget_product.rs b/rlwe/src/gadget_product.rs index 4027575..b1551fd 100644 --- a/rlwe/src/gadget_product.rs +++ b/rlwe/src/gadget_product.rs @@ -67,7 +67,7 @@ pub fn gadget_product_core( Elem: ElemVecZnx, { assert!(b_cols <= b.cols()); - module.vec_znx_dft(res_dft_1, a, a_cols); + module.vec_znx_dft(res_dft_1, a, min(a_cols, b_cols)); module.vmp_apply_dft_to_dft(res_dft_0, res_dft_1, b.at(0), tmp_bytes); module.vmp_apply_dft_to_dft_inplace(res_dft_1, b.at(1), tmp_bytes); } @@ -108,6 +108,7 @@ mod test { VecZnxDftOps, VecZnxOps, VmpPMat, }; use sampling::source::{Source, new_seed}; + use std::cmp::min; #[test] fn test_gadget_product_core() { @@ -220,86 +221,93 @@ mod test { // Iterates over all possible cols values for input/output polynomials and gadget ciphertext. - pt.elem_mut().zero(); - elem_res.zero(); + (1..a.cols() + 1).for_each(|a_cols| { + (1..gadget_ct.cols() + 1).for_each(|b_cols| { + pt.elem_mut().zero(); + elem_res.zero(); - let a_cols: usize = a.cols() - 1; - let b_cols: usize = gadget_ct.cols(); + //let b_cols: usize = min(a_cols+1, gadget_ct.cols()); - println!("a_cols: {} b_cols: {}", a_cols, b_cols); + println!("a_cols: {} b_cols: {}", a_cols, b_cols); - // res_dft_0 = DFT(gct_[0] * ct[1] = a * (-bs' + s + e) = -cs' + as + e') - // res_dft_1 = DFT(gct_[1] * ct[1] = a * b = c) - gadget_product_core::( - params.module(), - &mut res_dft_0, - &mut res_dft_1, - &a, - a_cols, - &gadget_ct, - b_cols, - &mut tmp_bytes, - ); + // res_dft_0 = DFT(gct_[0] * ct[1] = a * (-bs' + s + e) = -cs' + as + e') + // res_dft_1 = DFT(gct_[1] * ct[1] = a * b = c) + gadget_product_core::( + params.module(), + &mut res_dft_0, + &mut res_dft_1, + &a, + a_cols, + &gadget_ct, + b_cols, + &mut tmp_bytes, + ); - // res_big_0 = IDFT(res_dft_0) - params - .module() - .vec_znx_idft_tmp_a(&mut res_big_0, &mut res_dft_0, b_cols); - // res_big_1 = IDFT(res_dft_1); - params - .module() - .vec_znx_idft_tmp_a(&mut res_big_1, &mut res_dft_1, b_cols); + // res_big_0 = IDFT(res_dft_0) + params + .module() + .vec_znx_idft_tmp_a(&mut res_big_0, &mut res_dft_0, b_cols); + // res_big_1 = IDFT(res_dft_1); + params + .module() + .vec_znx_idft_tmp_a(&mut res_big_1, &mut res_dft_1, b_cols); - // res_big_0 = normalize(res_big_0) - params.module().vec_znx_big_normalize( - log_base2k, - elem_res.at_mut(0), - &res_big_0, - &mut tmp_bytes, - ); + // res_big_0 = normalize(res_big_0) + params.module().vec_znx_big_normalize( + log_base2k, + elem_res.at_mut(0), + &res_big_0, + &mut tmp_bytes, + ); - // res_big_1 = normalize(res_big_1) - params.module().vec_znx_big_normalize( - log_base2k, - elem_res.at_mut(1), - &res_big_1, - &mut tmp_bytes, - ); + // res_big_1 = normalize(res_big_1) + params.module().vec_znx_big_normalize( + log_base2k, + elem_res.at_mut(1), + &res_big_1, + &mut tmp_bytes, + ); - // <(-c*sk1 + a*sk0 + e, a), (1, sk1)> = a*sk0 + e - decrypt_rlwe( - params.module(), - pt.elem_mut(), - &elem_res, - &sk1_svp_ppol, - &mut tmp_bytes, - ); + // <(-c*sk1 + a*sk0 + e, a), (1, sk1)> = a*sk0 + e + decrypt_rlwe( + params.module(), + pt.elem_mut(), + &elem_res, + &sk1_svp_ppol, + &mut tmp_bytes, + ); - // a * sk0 + e - a*sk0 = e - params - .module() - .vec_znx_sub_inplace(pt.at_mut(0), &mut a_times_s); - pt.at_mut(0).normalize(log_base2k, &mut tmp_bytes); + // a * sk0 + e - a*sk0 = e + params + .module() + .vec_znx_sub_inplace(pt.at_mut(0), &mut a_times_s); + pt.at_mut(0).normalize(log_base2k, &mut tmp_bytes); - pt.at(0).print(pt.elem().cols(), 16); + //pt.at(0).print(pt.elem().cols(), 16); - println!("noise_have: {}", pt.at(0).std(log_base2k).log2()); + let noise_have: f64 = pt.at(0).std(log_base2k).log2(); - let var_a_err: f64; + let var_a_err: f64; - if a_cols < a.cols() { - var_a_err = 1f64 / 12f64; - } else { - var_a_err = 0f64; - } + if a_cols < a.cols() { + var_a_err = 1f64 / 12f64; + } else { + var_a_err = 0f64; + } - let a_logq: usize = a_cols * log_base2k; - let b_logq: usize = b_cols * log_base2k; - let var_msg: f64 = params.xs() as f64; - println!( - "noise_pred: {}", - params.noise_grlwe_product(var_msg, var_a_err, a_logq, b_logq) - ); + let a_logq: usize = a_cols * log_base2k; + let b_logq: usize = b_cols * log_base2k; + let var_msg: f64 = params.xs() as f64; + + let noise_pred: f64 = + params.noise_grlwe_product(var_msg, var_a_err, a_logq, b_logq); + + assert!(noise_have <= noise_pred + 1.0); + + println!("noise_pred: {}", noise_have); + println!("noise_have: {}", noise_pred); + }); + }); } } @@ -350,15 +358,14 @@ pub fn noise_grlwe_product( a_logq: usize, b_logq: usize, ) -> f64 { + let a_logq: usize = min(a_logq, b_logq); let a_cols: usize = (a_logq + log_base2k - 1) / log_base2k; - let b_cols: usize = (b_logq + log_base2k - 1) / log_base2k; let b_scale = 2.0f64.powi(b_logq as i32); let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32); let base: f64 = (1 << (log_base2k)) as f64; let var_base: f64 = base * base / 12f64; - let var_round: f64 = 1f64 / 12f64; // lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2) // rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs @@ -367,5 +374,5 @@ pub fn noise_grlwe_product( noise += var_msg * var_a_err * a_scale * a_scale; noise = noise.sqrt(); noise /= b_scale; - noise.log2() + noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}] }