From 9695761ff11d087fdd8a76d2b88a89a323cfe5cb Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 22 Apr 2025 23:13:06 +0200 Subject: [PATCH] added automorphism & fixed gadget product noise estimation --- rlwe/src/automorphism.rs | 153 ++++++++++++++++++++++++++++++++----- rlwe/src/gadget_product.rs | 31 +++++--- 2 files changed, 153 insertions(+), 31 deletions(-) diff --git a/rlwe/src/automorphism.rs b/rlwe/src/automorphism.rs index e57cc3a..b4d84a3 100644 --- a/rlwe/src/automorphism.rs +++ b/rlwe/src/automorphism.rs @@ -10,6 +10,7 @@ use base2k::{ VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, assert_alignement, }; use sampling::source::Source; +use std::cmp::min; /// Stores DFT([-A*AUTO(s, -p) + 2^{-K*i}*s + E, A]) where AUTO(X, p): X^{i} -> X^{i*p} pub struct AutomorphismKey { @@ -55,6 +56,7 @@ impl AutomorphismKey { let mut value: Ciphertext = new_gadget_ciphertext(module, log_base2k, rows, log_q); let p_inv: i64 = module.galois_element_inv(p); + module.vec_znx_automorphism(p_inv, &mut sk_auto.as_vec_znx(), &sk.0.as_vec_znx()); module.svp_prepare(&mut sk_out, &sk_auto); encrypt_grlwe_sk( @@ -83,7 +85,7 @@ pub fn automorphism( b: &AutomorphismKey, tmp_bytes: &mut [u8], ) { - let cols = std::cmp::min(c.cols(), a.cols()); + let cols: usize = min(min(c.cols(), a.cols()), b.value.rows()); #[cfg(debug_assertions)] { @@ -134,6 +136,74 @@ pub fn automorphism( module.vec_znx_automorphism_inplace(b.p, c.at_mut(1)); } +pub fn automorphism_inplace_tmp_bytes( + module: &Module, + c_cols: usize, + a_cols: usize, + b_rows: usize, + b_cols: usize, +) -> usize { + return module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, b_rows, b_cols) + + 2 * module.bytes_of_vec_znx_dft(std::cmp::min(c_cols, a_cols)); +} + +pub fn automorphism_inplace( + module: &Module, + a: &mut Ciphertext, + b: &AutomorphismKey, + tmp_bytes: &mut [u8], +) { + let cols: usize = min(a.cols(), b.value.rows()); + + #[cfg(debug_assertions)] + { + assert!( + tmp_bytes.len() + >= automorphism_inplace_tmp_bytes( + module, + a.cols(), + a.cols(), + b.value.rows(), + b.value.cols() + ) + ); + assert_alignement(tmp_bytes.as_ptr()); + } + + let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols)); + let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols)); + + let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_b1_dft); + let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_res_dft); + let mut res_big: VecZnxBig = res_dft.as_vec_znx_big(); + + // a1_dft = DFT(a[1]) + module.vec_znx_dft(&mut a1_dft, a.at(1)); + + // res_dft = IDFT() = [-b*AUTO(s, -p) + a * s + e] + module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.value.at(0), tmp_bytes); + module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft); + + // res_dft = [-b*AUTO(s, -p) + a * s + e] + [-a * s + m + e] = [-b*AUTO(s, -p) + m + e] + module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0)); + + // a[0] = NORMALIZE([-b*AUTO(s, -p) + m + e]) + module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(0), &mut res_big, tmp_bytes); + + // a[0] = AUTO([-b*AUTO(s, -p) + m + e], p) = [-AUTO(b, p)*s + AUTO(m, p) + AUTO(b, e)] + module.vec_znx_automorphism_inplace(b.p, a.at_mut(0)); + + // res_dft = IDFT() = [b] + module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.value.at(1), tmp_bytes); + module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft); + + // a[1] = b + module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(1), &mut res_big, tmp_bytes); + + // a[1] = AUTO(b, p) + module.vec_znx_automorphism_inplace(b.p, a.at_mut(1)); +} + pub fn automorphism_big( module: &Module, c: &mut Ciphertext, @@ -195,7 +265,7 @@ mod test { }; use sampling::source::{Source, new_seed}; - use super::{AutomorphismKey, automorphis_key_new_tmp_bytes}; + use super::{automorphis_key_new_tmp_bytes, automorphism, AutomorphismKey}; #[test] fn test_automorphism() { @@ -217,20 +287,23 @@ mod test { let params: Parameters = Parameters::new(¶ms_lit); + let module: &base2k::Module = params.module(); + let log_q: usize = params.log_q(); + let log_qp: usize = params.log_qp(); let rows: usize = params.cols_q(); // scratch space let mut tmp_bytes: Vec = alloc_aligned( - params.decrypt_rlwe_tmp_byte(params.log_q()) - | params.encrypt_rlwe_sk_tmp_bytes(params.log_q()) + params.decrypt_rlwe_tmp_byte(log_q) + | params.encrypt_rlwe_sk_tmp_bytes(log_q) | params.gadget_product_tmp_bytes( - params.log_qp(), - params.log_qp(), - params.cols_qp(), - params.log_qp(), + log_qp, + log_qp, + rows, + log_qp, ) - | params.encrypt_grlwe_sk_tmp_bytes(rows, params.log_qp()) - | params.automorphism_key_new_tmp_bytes(rows, params.log_qp()), + | params.encrypt_grlwe_sk_tmp_bytes(rows, log_qp) + | params.automorphism_key_new_tmp_bytes(rows, log_qp), ); // Samplers for public and private randomness @@ -239,34 +312,72 @@ mod test { let mut source_xs: Source = Source::new(new_seed()); // Two secret keys - let mut sk: SecretKey = SecretKey::new(params.module()); + let mut sk: SecretKey = SecretKey::new(module); sk.fill_ternary_hw(params.xs(), &mut source_xs); - let mut sk_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol(); - params.module().svp_prepare(&mut sk_svp_ppol, &sk.0); + let mut sk_svp_ppol: base2k::SvpPPol = module.new_svp_ppol(); + module.svp_prepare(&mut sk_svp_ppol, &sk.0); let p: i64 = -5; let auto_key: AutomorphismKey = AutomorphismKey::new( - params.module(), + module, p, &sk, - params.log_base2k(), + log_base2k, rows, - params.log_qp(), + log_qp, &mut source_xa, &mut source_xe, params.xe(), &mut tmp_bytes, ); - let data: Vec = vec![0i64; params.n()]; + let mut data: Vec = vec![0i64; params.n()]; - let mut ct: Ciphertext = Ciphertext::new(params.module(), params.log_base2k(), params.log_q(), 2); - let mut pt: Plaintext = Plaintext::new(params.module(), params.log_base2k(), params.log_q()); + data.iter_mut().enumerate().for_each(|(i, x)|{ + *x = i as i64 + }); - pt.at_mut(0).encode_vec_i64(params.log_base2k(), 2*params.log_base2k(), &data, 32); + let log_k: usize = 2*log_base2k; - encrypt_rlwe_sk(params.module(), &mut ct.elem_mut(), Some(&pt.elem()), &sk_svp_ppol, &mut source_xa, &mut source_xe, params.xe(), &mut tmp_bytes); + let mut ct: Ciphertext = Ciphertext::new(module, log_base2k, log_q, 2); + let mut pt: Plaintext = Plaintext::new(module, log_base2k, log_q); + + pt.at_mut(0).encode_vec_i64(log_base2k, log_k, &data, 32); + + encrypt_rlwe_sk(module, &mut ct.elem_mut(), Some(pt.elem()), &sk_svp_ppol, &mut source_xa, &mut source_xe, params.xe(), &mut tmp_bytes); + + module.vec_znx_automorphism_inplace(p, pt.at_mut(0)); + + let mut ct_auto: Ciphertext = Ciphertext::new(module, log_base2k, log_q, 2); + + automorphism(module, &mut ct_auto, &ct, &auto_key, &mut tmp_bytes); + + module.vec_znx_sub_inplace(ct_auto.at_mut(0), pt.at(0)); + ct_auto.at_mut(0).normalize(log_base2k, &mut tmp_bytes); + + decrypt_rlwe(module, pt.elem_mut(), ct_auto.elem(), &sk_svp_ppol, &mut tmp_bytes); + + let noise_have: f64 = pt.at(0).std(log_base2k).log2(); + + let var_a_err: f64; + if ct_auto.cols() < ct.cols() { + var_a_err = 1f64 / 12f64; + } else { + var_a_err = 0f64; + } + + let var_msg: f64 = (params.xs() as f64) / params.n() as f64; + + let noise_pred: f64 = + params.noise_grlwe_product(var_msg, var_a_err, ct_auto.log_q(), auto_key.value.log_q()); + + println!("noise_pred: {}", noise_have); + println!("noise_have: {}", noise_pred); + + assert!(noise_have <= noise_pred + 1.0); + + } } diff --git a/rlwe/src/gadget_product.rs b/rlwe/src/gadget_product.rs index ad4bf69..4095383 100644 --- a/rlwe/src/gadget_product.rs +++ b/rlwe/src/gadget_product.rs @@ -218,10 +218,6 @@ mod test { ); // Intermediate buffers - let mut res_dft_0: VecZnxDft = params.module().new_vec_znx_dft(gadget_ct.cols()); - let mut res_dft_1: VecZnxDft = params.module().new_vec_znx_dft(gadget_ct.cols()); - let mut res_big_0: VecZnxBig = res_dft_0.as_vec_znx_big(); - let mut res_big_1: VecZnxBig = res_dft_1.as_vec_znx_big(); // Input polynopmial, uniformly distributed let mut a: VecZnx = params.module().new_vec_znx(params.cols_q()); @@ -255,7 +251,18 @@ mod test { // Iterates over all possible cols values for input/output polynomials and gadget ciphertext. (1..a.cols() + 1).for_each(|a_cols| { + + let mut a_trunc: VecZnx = params.module().new_vec_znx(a_cols); + a_trunc.copy_from(&a); + (1..gadget_ct.cols() + 1).for_each(|b_cols| { + + + let mut res_dft_0: VecZnxDft = params.module().new_vec_znx_dft(b_cols); + let mut res_dft_1: VecZnxDft = params.module().new_vec_znx_dft(b_cols); + let mut res_big_0: VecZnxBig = res_dft_0.as_vec_znx_big(); + let mut res_big_1: VecZnxBig = res_dft_1.as_vec_znx_big(); + pt.elem_mut().zero(); elem_res.zero(); @@ -269,7 +276,7 @@ mod test { params.module(), &mut res_dft_0, &mut res_dft_1, - &a, + &a_trunc, &gadget_ct, b_cols, &mut tmp_bytes, @@ -329,15 +336,19 @@ mod test { let a_logq: usize = a_cols * log_base2k; let b_logq: usize = b_cols * log_base2k; - let var_msg: f64 = params.xs() as f64; + let var_msg: f64 = (params.xs() as f64) / params.n() as f64; + + println!("{} {} {} {}", var_msg, var_a_err, a_logq, b_logq); let noise_pred: f64 = params.noise_grlwe_product(var_msg, var_a_err, a_logq, b_logq); - assert!(noise_have <= noise_pred + 1.0); + println!("noise_pred: {}", noise_pred); + println!("noise_have: {}", noise_have); - println!("noise_pred: {}", noise_have); - println!("noise_have: {}", noise_pred); + //assert!(noise_have <= noise_pred + 1.0); + + }); }); } @@ -403,7 +414,7 @@ pub fn noise_grlwe_product( // rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs let mut noise: f64 = (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs); - noise += var_msg * var_a_err * a_scale * a_scale; + noise += var_msg * var_a_err * a_scale * a_scale * n; noise = noise.sqrt(); noise /= b_scale; noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}]