From 09981b78b5166e2c27f4df042a403449c5ed9511 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 23 Apr 2025 11:32:52 +0200 Subject: [PATCH] trace working --- base2k/src/sampling.rs | 2 +- base2k/src/vec_znx.rs | 27 ++++- base2k/src/vmp.rs | 1 + rlwe/benches/gadget_product.rs | 5 +- rlwe/src/automorphism.rs | 210 ++++++++++++++++++++++----------- rlwe/src/decryptor.rs | 1 + rlwe/src/elem.rs | 2 - rlwe/src/gadget_product.rs | 9 +- rlwe/src/parameters.rs | 2 + rlwe/src/rgsw_product.rs | 4 +- rlwe/src/trace.rs | 143 +++++++++++++++++++--- 11 files changed, 301 insertions(+), 105 deletions(-) diff --git a/base2k/src/sampling.rs b/base2k/src/sampling.rs index 416d3a6..9a52359 100644 --- a/base2k/src/sampling.rs +++ b/base2k/src/sampling.rs @@ -63,7 +63,7 @@ impl Sampling for Module { while dist_f64.abs() > bound { dist_f64 = dist.sample(source) } - *a += (dist_f64.round() as i64) << log_base2k_rem + *a += (dist_f64.round() as i64) << log_base2k_rem; }); } else { a.at_mut(a.cols() - 1).iter_mut().for_each(|a| { diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index d5235b5..ca26eb5 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -347,8 +347,11 @@ pub trait VecZnxOps { /// c <- a - b. fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx); + /// b <- a - b. + fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx); + /// b <- b - a. - fn vec_znx_sub_inplace(&self, b: &mut VecZnx, a: &VecZnx); + fn vec_znx_sub_ba_inplace(&self, b: &mut VecZnx, a: &VecZnx); /// b <- -a. fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx); @@ -452,8 +455,8 @@ impl VecZnxOps for Module { } } - // b <- a + b - fn vec_znx_sub_inplace(&self, b: &mut VecZnx, a: &VecZnx) { + // b <- a - b + fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx) { unsafe { vec_znx::vec_znx_sub( self.ptr, @@ -470,6 +473,24 @@ impl VecZnxOps for Module { } } + // b <- b - a + fn vec_znx_sub_ba_inplace(&self, b: &mut VecZnx, a: &VecZnx) { + unsafe { + vec_znx::vec_znx_sub( + self.ptr, + b.as_mut_ptr(), + b.cols() as u64, + b.n() as u64, + b.as_ptr(), + b.cols() as u64, + b.n() as u64, + a.as_ptr(), + a.cols() as u64, + a.n() as u64, + ) + } + } + fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx) { unsafe { vec_znx::vec_znx_negate( diff --git a/base2k/src/vmp.rs b/base2k/src/vmp.rs index 9484db8..fb5a9d4 100644 --- a/base2k/src/vmp.rs +++ b/base2k/src/vmp.rs @@ -26,6 +26,7 @@ pub struct VmpPMat { /// The ring degree of each [VecZnxDft]. n: usize, + #[warn(dead_code)] backend: BACKEND, } diff --git a/rlwe/benches/gadget_product.rs b/rlwe/benches/gadget_product.rs index 74a7ad4..64f2aad 100644 --- a/rlwe/benches/gadget_product.rs +++ b/rlwe/benches/gadget_product.rs @@ -19,14 +19,13 @@ fn bench_gadget_product_inplace(c: &mut Criterion) { res_dft_0: &'a mut VecZnxDft, res_dft_1: &'a mut VecZnxDft, a: &'a VecZnx, - a_cols: usize, b: &'a Ciphertext, b_cols: usize, tmp_bytes: &'a mut [u8], ) -> Box { Box::new(move || { gadget_product_core( - module, res_dft_0, res_dft_1, a, a_cols, b, b_cols, tmp_bytes, + module, res_dft_0, res_dft_1, a, b, b_cols, tmp_bytes, ); }) } @@ -119,7 +118,6 @@ fn bench_gadget_product_inplace(c: &mut Criterion) { .module() .fill_uniform(params.log_base2k(), &mut a, params.cols_q(), &mut source_xa); - let a_cols: usize = a.cols(); let b_cols: usize = gadget_ct.cols(); let runners: [(String, Box); 1] = [(format!("gadget_product"), { @@ -128,7 +126,6 @@ fn bench_gadget_product_inplace(c: &mut Criterion) { &mut res_dft_0, &mut res_dft_1, &mut a, - a_cols, &gadget_ct, b_cols, &mut tmp_bytes, diff --git a/rlwe/src/automorphism.rs b/rlwe/src/automorphism.rs index b4d84a3..7083716 100644 --- a/rlwe/src/automorphism.rs +++ b/rlwe/src/automorphism.rs @@ -10,7 +10,7 @@ use base2k::{ VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, assert_alignement, }; use sampling::source::Source; -use std::cmp::min; +use std::{cmp::min, collections::HashMap}; /// Stores DFT([-A*AUTO(s, -p) + 2^{-K*i}*s + E, A]) where AUTO(X, p): X^{i} -> X^{i*p} pub struct AutomorphismKey { @@ -33,6 +33,21 @@ impl Parameters { pub fn automorphism_key_new_tmp_bytes(&self, rows: usize, log_q: usize) -> usize { automorphis_key_new_tmp_bytes(self.module(), self.log_base2k(), rows, log_q) } + + pub fn automorphism_tmp_bytes( + &self, + res_logq: usize, + in_logq: usize, + gct_logq: usize, + ) -> usize { + automorphism_tmp_bytes( + self.module(), + self.log_base2k(), + res_logq, + in_logq, + gct_logq, + ) + } } impl AutomorphismKey { @@ -48,34 +63,68 @@ impl AutomorphismKey { sigma: f64, tmp_bytes: &mut [u8], ) -> Self { + Self::new_many_core(module, &vec![p], sk, log_base2k, rows, log_q, source_xa, source_xe, sigma, tmp_bytes).into_iter().next().unwrap() + } + + pub fn new_many(module: &Module, p: &Vec, sk: &SecretKey, log_base2k: usize, rows: usize, log_q: usize, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, tmp_bytes: &mut [u8]) -> HashMap{ + Self::new_many_core( + module, + p, + sk, + log_base2k, + rows, + log_q, + source_xa, + source_xe, + sigma, + tmp_bytes, + ) + .into_iter() + .zip(p.iter().cloned()) + .map(|(key, pi)| (pi, key)) + .collect() + } + + fn new_many_core(module: &Module, p: &Vec, sk: &SecretKey, log_base2k: usize, rows: usize, log_q: usize, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, tmp_bytes: &mut [u8]) -> Vec{ let (sk_auto_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_scalar()); let (sk_out_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_svp_ppol()); let sk_auto: Scalar = module.new_scalar_from_bytes_borrow(sk_auto_bytes); let mut sk_out: SvpPPol = module.new_svp_ppol_from_bytes_borrow(sk_out_bytes); - let mut value: Ciphertext = new_gadget_ciphertext(module, log_base2k, rows, log_q); - let p_inv: i64 = module.galois_element_inv(p); + let mut keys: Vec = Vec::new(); + + p.iter().for_each(|pi|{ + let mut value: Ciphertext = new_gadget_ciphertext(module, log_base2k, rows, log_q); - module.vec_znx_automorphism(p_inv, &mut sk_auto.as_vec_znx(), &sk.0.as_vec_znx()); - module.svp_prepare(&mut sk_out, &sk_auto); - encrypt_grlwe_sk( - module, &mut value, &sk.0, &sk_out, source_xa, source_xe, sigma, tmp_bytes, - ); + let p_inv: i64 = module.galois_element_inv(*pi); + + module.vec_znx_automorphism(p_inv, &mut sk_auto.as_vec_znx(), &sk.0.as_vec_znx()); + module.svp_prepare(&mut sk_out, &sk_auto); + encrypt_grlwe_sk( + module, &mut value, &sk.0, &sk_out, source_xa, source_xe, sigma, tmp_bytes, + ); - Self { value: value, p: p } + keys.push(Self { value: value, p: *pi }) + }); + + keys } } pub fn automorphism_tmp_bytes( module: &Module, - c_cols: usize, - a_cols: usize, - b_rows: usize, - b_cols: usize, + log_base2k: usize, + res_logq: usize, + in_logq: usize, + gct_logq: usize, ) -> usize { - return module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, b_rows, b_cols) - + 2 * module.bytes_of_vec_znx_dft(std::cmp::min(c_cols, a_cols)); + let gct_cols: usize = (gct_logq + log_base2k - 1) / log_base2k; + let in_cols: usize = (in_logq + log_base2k - 1) / log_base2k; + let res_cols: usize = (res_logq + log_base2k - 1) / log_base2k; + return module.vmp_apply_dft_to_dft_tmp_bytes(res_cols, in_cols, in_cols, gct_cols) + + module.bytes_of_vec_znx_dft(std::cmp::min(res_cols, in_cols)) + + module.bytes_of_vec_znx_dft(gct_cols); } pub fn automorphism( @@ -83,12 +132,14 @@ pub fn automorphism( c: &mut Ciphertext, a: &Ciphertext, b: &AutomorphismKey, + b_cols: usize, tmp_bytes: &mut [u8], ) { let cols: usize = min(min(c.cols(), a.cols()), b.value.rows()); #[cfg(debug_assertions)] { + assert!(b_cols <= b.value.cols()); assert!( tmp_bytes.len() >= automorphism_tmp_bytes( @@ -102,11 +153,13 @@ pub fn automorphism( assert_alignement(tmp_bytes.as_ptr()); } - let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols)); - let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols)); + let (tmp_bytes_a1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols)); + let (tmp_bytes_res_dft, tmp_bytes) = + tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols)); - let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_b1_dft); - let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_res_dft); + let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_a1_dft); + let mut res_dft: VecZnxDft = + module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_res_dft); let mut res_big: VecZnxBig = res_dft.as_vec_znx_big(); // a1_dft = DFT(a[1]) @@ -151,12 +204,14 @@ pub fn automorphism_inplace( module: &Module, a: &mut Ciphertext, b: &AutomorphismKey, + b_cols: usize, tmp_bytes: &mut [u8], ) { let cols: usize = min(a.cols(), b.value.rows()); #[cfg(debug_assertions)] { + assert!(b_cols <= b.value.cols()); assert!( tmp_bytes.len() >= automorphism_inplace_tmp_bytes( @@ -174,7 +229,8 @@ pub fn automorphism_inplace( let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols)); let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_b1_dft); - let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_res_dft); + let mut res_dft: VecZnxDft = + module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_res_dft); let mut res_big: VecZnxBig = res_dft.as_vec_znx_big(); // a1_dft = DFT(a[1]) @@ -197,6 +253,11 @@ pub fn automorphism_inplace( module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.value.at(1), tmp_bytes); module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft); + (0..b_cols).for_each(|col_i| { + let raw: &[i64] = res_big.raw::(module); + println!("{:?}", &raw[col_i * module.n()..(col_i + 1) * module.n()]) + }); + // a[1] = b module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(1), &mut res_big, tmp_bytes); @@ -257,28 +318,33 @@ pub fn automorphism_big( #[cfg(test)] mod test { + use super::{AutomorphismKey, automorphism}; use crate::{ - ciphertext::{new_gadget_ciphertext, Ciphertext}, decryptor::decrypt_rlwe, elem::{Elem, ElemCommon, ElemVecZnx}, encryptor::encrypt_rlwe_sk, keys::SecretKey, parameters::{Parameters, ParametersLiteral}, plaintext::Plaintext + ciphertext::Ciphertext, + decryptor::decrypt_rlwe, + elem::ElemCommon, + encryptor::encrypt_rlwe_sk, + keys::SecretKey, + parameters::{Parameters, ParametersLiteral}, + plaintext::Plaintext, }; use base2k::{ - alloc_aligned, Encoding, Infos, Sampling, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, BACKEND + BACKEND, Encoding, Module, SvpPPol, SvpPPolOps, VecZnx, VecZnxOps, alloc_aligned, }; use sampling::source::{Source, new_seed}; - use super::{automorphis_key_new_tmp_bytes, automorphism, AutomorphismKey}; - #[test] fn test_automorphism() { let log_base2k: usize = 10; - let q_cols: usize = 4; - let p_cols: usize = 1; + let log_q: usize = 50; + let log_p: usize = 15; // Basic parameters with enough limbs to test edge cases let params_lit: ParametersLiteral = ParametersLiteral { backend: BACKEND::FFT64, log_n: 12, - log_q: q_cols * log_base2k, - log_p: p_cols * log_base2k, + log_q: log_q, + log_p: log_p, log_base2k: log_base2k, log_scale: 20, xe: 3.2, @@ -287,23 +353,18 @@ mod test { let params: Parameters = Parameters::new(¶ms_lit); - let module: &base2k::Module = params.module(); + let module: &Module = params.module(); let log_q: usize = params.log_q(); let log_qp: usize = params.log_qp(); - let rows: usize = params.cols_q(); + let gct_rows: usize = params.cols_q(); + let gct_cols: usize = params.cols_qp(); // scratch space let mut tmp_bytes: Vec = alloc_aligned( params.decrypt_rlwe_tmp_byte(log_q) | params.encrypt_rlwe_sk_tmp_bytes(log_q) - | params.gadget_product_tmp_bytes( - log_qp, - log_qp, - rows, - log_qp, - ) - | params.encrypt_grlwe_sk_tmp_bytes(rows, log_qp) - | params.automorphism_key_new_tmp_bytes(rows, log_qp), + | params.automorphism_key_new_tmp_bytes(gct_rows, log_qp) + | params.automorphism_tmp_bytes(log_q, log_q, log_qp), ); // Samplers for public and private randomness @@ -311,10 +372,9 @@ mod test { let mut source_xa: Source = Source::new(new_seed()); let mut source_xs: Source = Source::new(new_seed()); - // Two secret keys let mut sk: SecretKey = SecretKey::new(module); sk.fill_ternary_hw(params.xs(), &mut source_xs); - let mut sk_svp_ppol: base2k::SvpPPol = module.new_svp_ppol(); + let mut sk_svp_ppol: SvpPPol = module.new_svp_ppol(); module.svp_prepare(&mut sk_svp_ppol, &sk.0); let p: i64 = -5; @@ -324,7 +384,7 @@ mod test { p, &sk, log_base2k, - rows, + gct_rows, log_qp, &mut source_xa, &mut source_xe, @@ -334,50 +394,64 @@ mod test { let mut data: Vec = vec![0i64; params.n()]; - data.iter_mut().enumerate().for_each(|(i, x)|{ - *x = i as i64 - }); + data.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); - let log_k: usize = 2*log_base2k; + let log_k: usize = 2 * log_base2k; - let mut ct: Ciphertext = Ciphertext::new(module, log_base2k, log_q, 2); - let mut pt: Plaintext = Plaintext::new(module, log_base2k, log_q); + let mut ct: Ciphertext = params.new_ciphertext(log_q); + let mut pt: Plaintext = params.new_plaintext(log_q); + let mut pt_auto: Plaintext = params.new_plaintext(log_q); pt.at_mut(0).encode_vec_i64(log_base2k, log_k, &data, 32); + module.vec_znx_automorphism(p, pt_auto.at_mut(0), pt.at(0)); - encrypt_rlwe_sk(module, &mut ct.elem_mut(), Some(pt.elem()), &sk_svp_ppol, &mut source_xa, &mut source_xe, params.xe(), &mut tmp_bytes); + encrypt_rlwe_sk( + module, + &mut ct.elem_mut(), + Some(pt.elem()), + &sk_svp_ppol, + &mut source_xa, + &mut source_xe, + params.xe(), + &mut tmp_bytes, + ); - module.vec_znx_automorphism_inplace(p, pt.at_mut(0)); + let mut ct_auto: Ciphertext = params.new_ciphertext(log_q); - let mut ct_auto: Ciphertext = Ciphertext::new(module, log_base2k, log_q, 2); + // ct <- AUTO(ct) + automorphism( + module, + &mut ct_auto, + &ct, + &auto_key, + gct_cols, + &mut tmp_bytes, + ); - automorphism(module, &mut ct_auto, &ct, &auto_key, &mut tmp_bytes); + // pt = dec(auto(ct)) - auto(pt) + decrypt_rlwe( + module, + pt.elem_mut(), + ct_auto.elem(), + &sk_svp_ppol, + &mut tmp_bytes, + ); - module.vec_znx_sub_inplace(ct_auto.at_mut(0), pt.at(0)); - ct_auto.at_mut(0).normalize(log_base2k, &mut tmp_bytes); + module.vec_znx_sub_ba_inplace(pt.at_mut(0), pt_auto.at(0)); + + //pt.at(0).print(pt.cols(), 16); - decrypt_rlwe(module, pt.elem_mut(), ct_auto.elem(), &sk_svp_ppol, &mut tmp_bytes); - let noise_have: f64 = pt.at(0).std(log_base2k).log2(); - let var_a_err: f64; - if ct_auto.cols() < ct.cols() { - var_a_err = 1f64 / 12f64; - } else { - var_a_err = 0f64; - } - let var_msg: f64 = (params.xs() as f64) / params.n() as f64; + let var_a_err: f64 = 1f64 / 12f64; let noise_pred: f64 = params.noise_grlwe_product(var_msg, var_a_err, ct_auto.log_q(), auto_key.value.log_q()); - println!("noise_pred: {}", noise_have); - println!("noise_have: {}", noise_pred); - + println!("noise_pred: {}", noise_pred); + println!("noise_have: {}", noise_have); + assert!(noise_have <= noise_pred + 1.0); - - - } } diff --git a/rlwe/src/decryptor.rs b/rlwe/src/decryptor.rs index 8a1e5d7..155222c 100644 --- a/rlwe/src/decryptor.rs +++ b/rlwe/src/decryptor.rs @@ -9,6 +9,7 @@ use base2k::{Module, SvpPPol, SvpPPolOps, VecZnx, VecZnxBigOps, VecZnxDft, VecZn use std::cmp::min; pub struct Decryptor { + #[warn(dead_code)] sk: SvpPPol, } diff --git a/rlwe/src/elem.rs b/rlwe/src/elem.rs index 6be3038..128812f 100644 --- a/rlwe/src/elem.rs +++ b/rlwe/src/elem.rs @@ -1,7 +1,5 @@ use base2k::{Infos, Module, VecZnx, VecZnxOps, VmpPMat, VmpPMatOps}; -use crate::parameters::Parameters; - pub struct Elem { pub value: Vec, pub log_base2k: usize, diff --git a/rlwe/src/gadget_product.rs b/rlwe/src/gadget_product.rs index 4095383..b90ad51 100644 --- a/rlwe/src/gadget_product.rs +++ b/rlwe/src/gadget_product.rs @@ -251,13 +251,10 @@ mod test { // Iterates over all possible cols values for input/output polynomials and gadget ciphertext. (1..a.cols() + 1).for_each(|a_cols| { - let mut a_trunc: VecZnx = params.module().new_vec_znx(a_cols); a_trunc.copy_from(&a); - + (1..gadget_ct.cols() + 1).for_each(|b_cols| { - - let mut res_dft_0: VecZnxDft = params.module().new_vec_znx_dft(b_cols); let mut res_dft_1: VecZnxDft = params.module().new_vec_znx_dft(b_cols); let mut res_big_0: VecZnxBig = res_dft_0.as_vec_znx_big(); @@ -319,7 +316,7 @@ mod test { // a * sk0 + e - a*sk0 = e params .module() - .vec_znx_sub_inplace(pt.at_mut(0), &mut a_times_s); + .vec_znx_sub_ab_inplace(pt.at_mut(0), &mut a_times_s); pt.at_mut(0).normalize(log_base2k, &mut tmp_bytes); //pt.at(0).print(pt.elem().cols(), 16); @@ -347,8 +344,6 @@ mod test { println!("noise_have: {}", noise_have); //assert!(noise_have <= noise_pred + 1.0); - - }); }); } diff --git a/rlwe/src/parameters.rs b/rlwe/src/parameters.rs index 1faa5f4..3d78c50 100644 --- a/rlwe/src/parameters.rs +++ b/rlwe/src/parameters.rs @@ -1,5 +1,7 @@ use base2k::module::{BACKEND, Module}; +pub const DEFAULTSIGMA: f64 = 3.2; + pub struct ParametersLiteral { pub backend: BACKEND, pub log_n: usize, diff --git a/rlwe/src/rgsw_product.rs b/rlwe/src/rgsw_product.rs index e5632cb..352aa39 100644 --- a/rlwe/src/rgsw_product.rs +++ b/rlwe/src/rgsw_product.rs @@ -1,9 +1,9 @@ use crate::{ ciphertext::Ciphertext, - elem::{Elem, ElemCommon, ElemVecZnx}, + elem::{Elem, ElemCommon}, }; use base2k::{ - Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, + Module, VecZnx, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, }; use std::cmp::min; diff --git a/rlwe/src/trace.rs b/rlwe/src/trace.rs index 1882345..a92f70a 100644 --- a/rlwe/src/trace.rs +++ b/rlwe/src/trace.rs @@ -1,10 +1,28 @@ -use crate::{automorphism::AutomorphismKey, ciphertext::Ciphertext, elem::ElemCommon}; +use crate::{automorphism::AutomorphismKey, ciphertext::Ciphertext, elem::ElemCommon, parameters::Parameters}; use base2k::{ - Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMatOps, + Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMatOps, assert_alignement, }; use std::collections::HashMap; +pub fn trace_galois_elements(module: &Module) -> Vec{ + let mut gal_els: Vec = Vec::new(); + (0..module.log_n()).for_each(|i|{ + if i == 0 { + gal_els.push(-1); + } else { + gal_els.push(module.galois_element(1 << (i - 1))); + } + }); + gal_els +} + +impl Parameters{ + pub fn trace_tmp_bytes(&self, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize{ + self.automorphism_tmp_bytes(res_logq, in_logq, gct_logq) + } +} + pub fn trace_tmp_bytes( module: &Module, c_cols: usize, @@ -21,17 +39,20 @@ pub fn trace_inplace( a: &mut Ciphertext, start: usize, end: usize, - b: HashMap, + b: &HashMap, + b_cols: usize, tmp_bytes: &mut [u8], ) { let cols: usize = a.cols(); let b_rows: usize; - let b_cols: usize; if let Some((_, key)) = b.iter().next() { b_rows = key.value.rows(); - b_cols = key.value.cols(); + #[cfg(debug_assertions)]{ + println!("{} {}", b_cols, key.value.cols()); + assert!(b_cols <= key.value.cols()) + } } else { panic!("b: HashMap, is empty") } @@ -47,10 +68,10 @@ pub fn trace_inplace( let cols: usize = std::cmp::min(b_cols, a.cols()); let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols)); - let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols)); + let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols)); let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_b1_dft); - let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_res_dft); + let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_res_dft); let mut res_big: VecZnxBig = res_dft.as_vec_znx_big(); let log_base2k: usize = a.log_base2k(); @@ -92,21 +113,107 @@ pub fn trace_inplace( #[cfg(test)] mod test { use crate::{ - ciphertext::{Ciphertext, new_gadget_ciphertext}, - decryptor::decrypt_rlwe, - elem::{Elem, ElemCommon, ElemVecZnx}, - encryptor::encrypt_grlwe_sk, - gadget_product::gadget_product_core, - keys::SecretKey, - parameters::{Parameters, ParametersLiteral}, - plaintext::Plaintext, + automorphism::AutomorphismKey, ciphertext::{new_gadget_ciphertext, Ciphertext}, decryptor::decrypt_rlwe, elem::{Elem, ElemCommon, ElemVecZnx}, encryptor::{encrypt_grlwe_sk, encrypt_rlwe_sk}, gadget_product::gadget_product_core, keys::SecretKey, parameters::{Parameters, ParametersLiteral, DEFAULTSIGMA}, plaintext::Plaintext }; use base2k::{ - BACKEND, Infos, Sampling, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, - VecZnxDftOps, VecZnxOps, VmpPMat, alloc_aligned_u8, + BACKEND, Module, Infos, Sampling, SvpPPol, SvpPPolOps, VecZnx, + VecZnxDftOps, VecZnxOps, VmpPMat, alloc_aligned, Encoding, }; use sampling::source::{Source, new_seed}; + use std::collections::HashMap; + use super::{trace_galois_elements, trace_inplace}; #[test] - fn test_trace_inplace() {} + fn test_trace_inplace() { + let log_base2k: usize = 10; + let log_q: usize = 50; + let log_p: usize = 15; + + // Basic parameters with enough limbs to test edge cases + let params_lit: ParametersLiteral = ParametersLiteral { + backend: BACKEND::FFT64, + log_n: 12, + log_q: log_q, + log_p: log_p, + log_base2k: log_base2k, + log_scale: 20, + xe: 3.2, + xs: 1 << 11, + }; + + let params: Parameters = Parameters::new(¶ms_lit); + + let module: &Module = params.module(); + let log_q: usize = params.log_q(); + let log_qp: usize = params.log_qp(); + let gct_rows: usize = params.cols_q(); + let gct_cols: usize = params.cols_qp(); + + // scratch space + let mut tmp_bytes: Vec = alloc_aligned( + params.decrypt_rlwe_tmp_byte(log_q) + | params.encrypt_rlwe_sk_tmp_bytes(log_q) + | params.automorphism_key_new_tmp_bytes(gct_rows, log_qp) + | params.automorphism_tmp_bytes(log_q, log_q, log_qp), + ); + + // Samplers for public and private randomness + let mut source_xe: Source = Source::new(new_seed()); + let mut source_xa: Source = Source::new(new_seed()); + let mut source_xs: Source = Source::new(new_seed()); + + let mut sk: SecretKey = SecretKey::new(module); + sk.fill_ternary_hw(params.xs(), &mut source_xs); + let mut sk_svp_ppol: SvpPPol = module.new_svp_ppol(); + module.svp_prepare(&mut sk_svp_ppol, &sk.0); + + let gal_els: Vec = trace_galois_elements(module); + + let auto_keys: HashMap = AutomorphismKey::new_many(module, &gal_els, &sk, log_base2k, gct_rows, log_qp, &mut source_xa, &mut source_xe, DEFAULTSIGMA, &mut tmp_bytes); + + let mut data: Vec = vec![0i64; params.n()]; + + data.iter_mut().enumerate().for_each(|(i, x)| *x = 1+i as i64); + + let log_k: usize = 2 * log_base2k; + + let mut ct: Ciphertext = params.new_ciphertext(log_q); + let mut pt: Plaintext = params.new_plaintext(log_q); + + pt.at_mut(0).encode_vec_i64(log_base2k, log_k, &data, 32); + pt.at_mut(0).normalize(log_base2k, &mut tmp_bytes); + + pt.at(0).decode_vec_i64(log_base2k, log_k, &mut data); + + pt.at(0).print(pt.cols(), 16); + + encrypt_rlwe_sk( + module, + &mut ct.elem_mut(), + Some(pt.elem()), + &sk_svp_ppol, + &mut source_xa, + &mut source_xe, + params.xe(), + &mut tmp_bytes, + ); + + trace_inplace(module, &mut ct, module.log_n()-2, module.log_n(), &auto_keys, gct_cols, & mut tmp_bytes); + + // pt = dec(auto(ct)) - auto(pt) + decrypt_rlwe( + module, + pt.elem_mut(), + ct.elem(), + &sk_svp_ppol, + &mut tmp_bytes, + ); + + pt.at(0).print(pt.cols(), 16); + + pt.at(0).decode_vec_i64(log_base2k, log_k, &mut data); + + println!("trace: {:?}", &data[..16]); + + } }