diff --git a/base2k/src/svp.rs b/base2k/src/svp.rs index c8b6139..b2f0166 100644 --- a/base2k/src/svp.rs +++ b/base2k/src/svp.rs @@ -90,6 +90,10 @@ impl Scalar { } pub fn raw(&self) -> &[i64] { + unsafe { std::slice::from_raw_parts(self.ptr, self.n) } + } + + pub fn raw_mut(&self) -> &mut [i64] { unsafe { std::slice::from_raw_parts_mut(self.ptr, self.n) } } diff --git a/rlwe/benches/gadget_product.rs b/rlwe/benches/gadget_product.rs index 64f2aad..856f508 100644 --- a/rlwe/benches/gadget_product.rs +++ b/rlwe/benches/gadget_product.rs @@ -24,9 +24,7 @@ fn bench_gadget_product_inplace(c: &mut Criterion) { tmp_bytes: &'a mut [u8], ) -> Box { Box::new(move || { - gadget_product_core( - module, res_dft_0, res_dft_1, a, b, b_cols, tmp_bytes, - ); + gadget_product_core(module, res_dft_0, res_dft_1, a, b, b_cols, tmp_bytes); }) } diff --git a/rlwe/src/automorphism.rs b/rlwe/src/automorphism.rs index 7083716..7352b4c 100644 --- a/rlwe/src/automorphism.rs +++ b/rlwe/src/automorphism.rs @@ -63,13 +63,9 @@ impl AutomorphismKey { sigma: f64, tmp_bytes: &mut [u8], ) -> Self { - Self::new_many_core(module, &vec![p], sk, log_base2k, rows, log_q, source_xa, source_xe, sigma, tmp_bytes).into_iter().next().unwrap() - } - - pub fn new_many(module: &Module, p: &Vec, sk: &SecretKey, log_base2k: usize, rows: usize, log_q: usize, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, tmp_bytes: &mut [u8]) -> HashMap{ Self::new_many_core( module, - p, + &vec![p], sk, log_base2k, rows, @@ -80,12 +76,43 @@ impl AutomorphismKey { tmp_bytes, ) .into_iter() + .next() + .unwrap() + } + + pub fn new_many( + module: &Module, + p: &Vec, + sk: &SecretKey, + log_base2k: usize, + rows: usize, + log_q: usize, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + tmp_bytes: &mut [u8], + ) -> HashMap { + Self::new_many_core( + module, p, sk, log_base2k, rows, log_q, source_xa, source_xe, sigma, tmp_bytes, + ) + .into_iter() .zip(p.iter().cloned()) .map(|(key, pi)| (pi, key)) .collect() } - fn new_many_core(module: &Module, p: &Vec, sk: &SecretKey, log_base2k: usize, rows: usize, log_q: usize, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, tmp_bytes: &mut [u8]) -> Vec{ + fn new_many_core( + module: &Module, + p: &Vec, + sk: &SecretKey, + log_base2k: usize, + rows: usize, + log_q: usize, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + tmp_bytes: &mut [u8], + ) -> Vec { let (sk_auto_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_scalar()); let (sk_out_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_svp_ppol()); @@ -93,19 +120,23 @@ impl AutomorphismKey { let mut sk_out: SvpPPol = module.new_svp_ppol_from_bytes_borrow(sk_out_bytes); let mut keys: Vec = Vec::new(); - - p.iter().for_each(|pi|{ - let mut value: Ciphertext = new_gadget_ciphertext(module, log_base2k, rows, log_q); + + p.iter().for_each(|pi| { + let mut value: Ciphertext = + new_gadget_ciphertext(module, log_base2k, rows, log_q); let p_inv: i64 = module.galois_element_inv(*pi); - + module.vec_znx_automorphism(p_inv, &mut sk_auto.as_vec_znx(), &sk.0.as_vec_znx()); module.svp_prepare(&mut sk_out, &sk_auto); encrypt_grlwe_sk( module, &mut value, &sk.0, &sk_out, source_xa, source_xe, sigma, tmp_bytes, ); - keys.push(Self { value: value, p: *pi }) + keys.push(Self { + value: value, + p: *pi, + }) }); keys @@ -408,7 +439,7 @@ mod test { encrypt_rlwe_sk( module, &mut ct.elem_mut(), - Some(pt.elem()), + Some(pt.at(0)), &sk_svp_ppol, &mut source_xa, &mut source_xe, diff --git a/rlwe/src/encryptor.rs b/rlwe/src/encryptor.rs index 8d8f3c4..a10da7f 100644 --- a/rlwe/src/encryptor.rs +++ b/rlwe/src/encryptor.rs @@ -5,12 +5,38 @@ use crate::parameters::Parameters; use crate::plaintext::Plaintext; use base2k::sampling::Sampling; use base2k::{ - Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, - VecZnxOps, VmpPMat, VmpPMatOps, + Infos, Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, + VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, }; use sampling::source::{Source, new_seed}; +impl Parameters { + pub fn encrypt_rlwe_sk_tmp_bytes(&self, log_q: usize) -> usize { + encrypt_rlwe_sk_tmp_bytes(self.module(), self.log_base2k(), log_q) + } + pub fn encrypt_rlwe_sk( + &self, + ct: &mut Ciphertext, + pt: Option<&Plaintext>, + sk: &SvpPPol, + source_xa: &mut Source, + source_xe: &mut Source, + tmp_bytes: &mut [u8], + ) { + encrypt_rlwe_sk( + self.module(), + &mut ct.0, + pt.map(|pt| pt.at(0)), + sk, + source_xa, + source_xe, + self.xe(), + tmp_bytes, + ) + } +} + pub struct EncryptorSk { sk: SvpPPol, source_xa: Source, @@ -86,42 +112,27 @@ impl EncryptorSk { } } -impl Parameters { - pub fn encrypt_rlwe_sk_tmp_bytes(&self, log_q: usize) -> usize { - encrypt_rlwe_sk_tmp_bytes(self.module(), self.log_base2k(), log_q) - } - - pub fn encrypt_rlwe_sk( - &self, - ct: &mut Ciphertext, - pt: Option<&Plaintext>, - sk: &SvpPPol, - source_xa: &mut Source, - source_xe: &mut Source, - tmp_bytes: &mut [u8], - ) { - encrypt_rlwe_sk( - self.module(), - &mut ct.0, - pt.map(|pt| &pt.0), - sk, - source_xa, - source_xe, - self.xe(), - tmp_bytes, - ) - } -} - pub fn encrypt_rlwe_sk_tmp_bytes(module: &Module, log_base2k: usize, log_q: usize) -> usize { module.bytes_of_vec_znx_dft((log_q + log_base2k - 1) / log_base2k) + module.vec_znx_big_normalize_tmp_bytes() } - pub fn encrypt_rlwe_sk( module: &Module, ct: &mut Elem, - pt: Option<&Elem>, + pt: Option<&VecZnx>, + sk: &SvpPPol, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + tmp_bytes: &mut [u8], +) { + encrypt_rlwe_sk_core::<0>(module, ct, pt, sk, source_xa, source_xe, sigma, tmp_bytes) +} + +fn encrypt_rlwe_sk_core( + module: &Module, + ct: &mut Elem, + pt: Option<&VecZnx>, sk: &SvpPPol, source_xa: &mut Source, source_xe: &mut Source, @@ -161,21 +172,35 @@ pub fn encrypt_rlwe_sk( // buf_big = s x c1 module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft); - // c0 <- -s x c1 + m - let c0: &mut VecZnx = ct.at_mut(0); - - if let Some(pt) = pt { - module.vec_znx_big_sub_small_a_inplace(&mut buf_big, pt.at(0)); - module.vec_znx_big_normalize(log_base2k, c0, &buf_big, tmp_bytes_normalize); - } else { - module.vec_znx_big_normalize(log_base2k, c0, &buf_big, tmp_bytes_normalize); - module.vec_znx_negate_inplace(c0); + match PT_POS { + // c0 <- -s x c1 + m + 0 => { + let c0: &mut VecZnx = ct.at_mut(0); + if let Some(pt) = pt { + module.vec_znx_big_sub_small_a_inplace(&mut buf_big, pt); + module.vec_znx_big_normalize(log_base2k, c0, &buf_big, tmp_bytes_normalize); + } else { + module.vec_znx_big_normalize(log_base2k, c0, &buf_big, tmp_bytes_normalize); + module.vec_znx_negate_inplace(c0); + } + } + // c1 <- c1 + m + 1 => { + if let Some(pt) = pt { + module.vec_znx_add_inplace(c1, pt); + c1.normalize(log_base2k, tmp_bytes_normalize); + } + let c0: &mut VecZnx = ct.at_mut(0); + module.vec_znx_big_normalize(log_base2k, c0, &buf_big, tmp_bytes_normalize); + module.vec_znx_negate_inplace(c0); + } + _ => panic!("PT_POS must be 1 or 2"), } // c0 <- -s x c1 + m + e module.add_normal( log_base2k, - c0, + ct.at_mut(0), log_q, source_xe, sigma, @@ -212,10 +237,98 @@ pub fn encrypt_grlwe_sk( sigma: f64, tmp_bytes: &mut [u8], ) { - let rows: usize = ct.rows(); let log_q: usize = ct.log_q(); - //let cols: usize = (log_q + ct.log_base2k() - 1) / ct.log_base2k(); let log_base2k: usize = ct.log_base2k(); + let (left, right) = ct.0.value.split_at_mut(1); + encrypt_grlwe_sk_core::<0>( + module, + log_base2k, + [&mut left[0], &mut right[0]], + log_q, + m, + sk, + source_xa, + source_xe, + sigma, + tmp_bytes, + ) +} + +impl Parameters { + pub fn encrypt_rgsw_sk_tmp_bytes(&self, rows: usize, log_q: usize) -> usize { + encrypt_rgsw_sk_tmp_bytes(self.module(), self.log_base2k(), rows, log_q) + } +} + +pub fn encrypt_rgsw_sk_tmp_bytes( + module: &Module, + log_base2k: usize, + rows: usize, + log_q: usize, +) -> usize { + let cols = (log_q + log_base2k - 1) / log_base2k; + Elem::::bytes_of(module, log_base2k, log_q, 2) + + Plaintext::bytes_of(module, log_base2k, log_q) + + encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q) + + module.vmp_prepare_tmp_bytes(rows, cols) +} + +pub fn encrypt_rgsw_sk( + module: &Module, + ct: &mut Ciphertext, + m: &Scalar, + sk: &SvpPPol, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + tmp_bytes: &mut [u8], +) { + let log_q: usize = ct.log_q(); + let log_base2k: usize = ct.log_base2k(); + + let (left, right) = ct.0.value.split_at_mut(2); + let (ll, lr) = left.split_at_mut(1); + let (rl, rr) = right.split_at_mut(1); + + encrypt_grlwe_sk_core::<0>( + module, + log_base2k, + [&mut ll[0], &mut lr[0]], + log_q, + m, + sk, + source_xa, + source_xe, + sigma, + tmp_bytes, + ); + encrypt_grlwe_sk_core::<1>( + module, + log_base2k, + [&mut rl[0], &mut rr[0]], + log_q, + m, + sk, + source_xa, + source_xe, + sigma, + tmp_bytes, + ); +} + +fn encrypt_grlwe_sk_core( + module: &Module, + log_base2k: usize, + mut ct: [&mut VmpPMat; 2], + log_q: usize, + m: &Scalar, + sk: &SvpPPol, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + tmp_bytes: &mut [u8], +) { + let rows: usize = ct[0].rows(); let min_tmp_bytes_len = encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q); @@ -235,7 +348,7 @@ pub fn encrypt_grlwe_sk( let (tmp_bytes_elem, tmp_bytes_vmp_prepare_row) = tmp_bytes.split_at_mut(bytes_of_elem); let mut tmp_elem: Elem = - Elem::::from_bytes_borrow(module, log_base2k, ct.log_q(), 2, tmp_bytes_elem); + Elem::::from_bytes_borrow(module, log_base2k, log_q, 2, tmp_bytes_elem); let mut tmp_pt: Plaintext = Plaintext::from_bytes_borrow(module, log_base2k, log_q, tmp_bytes_pt); @@ -244,10 +357,10 @@ pub fn encrypt_grlwe_sk( tmp_pt.at_mut(0).at_mut(row_i).copy_from_slice(&m.raw()); // Encrypts RLWE(m * 2^{-log_base2k*i}) - encrypt_rlwe_sk( + encrypt_rlwe_sk_core::( module, &mut tmp_elem, - Some(&tmp_pt.0), + Some(tmp_pt.at(0)), sk, source_xa, source_xe, @@ -255,31 +368,21 @@ pub fn encrypt_grlwe_sk( tmp_bytes_enc_sk, ); - //tmp_pt.at(0).print(tmp_pt.cols(), 16); - //println!(); - // Zeroes the ith-row of tmp_pt tmp_pt.at_mut(0).at_mut(row_i).fill(0); - //println!("row:{}/{}", row_i, rows); - //tmp_elem.at(0).print(tmp_elem.cols(), tmp_elem.n()); - //tmp_elem.at(1).print(tmp_elem.cols(), tmp_elem.n()); - //println!(); - //println!(">>>"); - // GRLWE[row_i][0||1] = [-as + m * 2^{-i*log_base2k} + e*2^{-log_q} || a] module.vmp_prepare_row( - &mut ct.at_mut(0), + ct[0], tmp_elem.at(0).raw(), row_i, tmp_bytes_vmp_prepare_row, ); module.vmp_prepare_row( - &mut ct.at_mut(1), + &mut ct[1], tmp_elem.at(1).raw(), row_i, tmp_bytes_vmp_prepare_row, ); }); - //println!("DONE"); } diff --git a/rlwe/src/parameters.rs b/rlwe/src/parameters.rs index 3d78c50..cd3a91d 100644 --- a/rlwe/src/parameters.rs +++ b/rlwe/src/parameters.rs @@ -1,6 +1,6 @@ use base2k::module::{BACKEND, Module}; -pub const DEFAULTSIGMA: f64 = 3.2; +pub const DEFAULT_SIGMA: f64 = 3.2; pub struct ParametersLiteral { pub backend: BACKEND, diff --git a/rlwe/src/rgsw_product.rs b/rlwe/src/rgsw_product.rs index 352aa39..13d2a90 100644 --- a/rlwe/src/rgsw_product.rs +++ b/rlwe/src/rgsw_product.rs @@ -1,54 +1,349 @@ -use crate::{ - ciphertext::Ciphertext, - elem::{Elem, ElemCommon}, -}; +use crate::{ciphertext::Ciphertext, elem::ElemCommon, parameters::Parameters}; use base2k::{ - Module, VecZnx, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, + Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, + assert_alignement, }; use std::cmp::min; +impl Parameters { + pub fn rgsw_product_tmp_bytes( + &self, + res_logq: usize, + in_logq: usize, + gct_logq: usize, + ) -> usize { + rgsw_product_tmp_bytes( + self.module(), + self.log_base2k(), + res_logq, + in_logq, + gct_logq, + ) + } +} +pub fn rgsw_product_tmp_bytes( + module: &Module, + log_base2k: usize, + res_logq: usize, + in_logq: usize, + gct_logq: usize, +) -> usize { + let gct_cols: usize = (gct_logq + log_base2k - 1) / log_base2k; + let in_cols: usize = (in_logq + log_base2k - 1) / log_base2k; + let res_cols: usize = (res_logq + log_base2k - 1) / log_base2k; + return module.vmp_apply_dft_to_dft_tmp_bytes(res_cols, in_cols, in_cols, gct_cols) + + module.bytes_of_vec_znx_dft(std::cmp::min(res_cols, in_cols)) + + 2 * module.bytes_of_vec_znx_dft(gct_cols); +} + pub fn rgsw_product( module: &Module, - _res: &mut Elem, + c: &mut Ciphertext, a: &Ciphertext, b: &Ciphertext, + b_cols: usize, tmp_bytes: &mut [u8], ) { - let _log_base2k: usize = b.log_base2k(); - let rows: usize = min(b.rows(), a.cols()); - let cols: usize = b.cols(); - let in_cols = a.cols(); - let out_cols: usize = a.cols(); + #[cfg(debug_assertions)] + { + assert!(b_cols <= b.cols()); + assert_eq!(c.size(), 2); + assert_eq!(a.size(), 2); + assert_eq!(b.size(), 4); + assert!( + tmp_bytes.len() + >= rgsw_product_tmp_bytes( + module, + c.cols(), + a.cols(), + min(b.rows(), a.cols()), + b_cols + ) + ); + assert_alignement(tmp_bytes.as_ptr()); + } - let bytes_of_vec_znx_dft = module.bytes_of_vec_znx_dft(cols); - let bytes_of_vmp_apply_dft_to_dft = - module.vmp_apply_dft_to_dft_tmp_bytes(out_cols, in_cols, rows, cols); + let (tmp_bytes_ai_dft, tmp_bytes) = + tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(a.cols())); + let (tmp_bytes_c0_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols)); + let (tmp_bytes_c1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols)); - let (tmp_bytes_c0_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft); - let (tmp_bytes_c1_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft); - let (tmp_bytes_tmp_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft); - let (tmp_bytes_r1_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft); - let (tmp_bytes_r2_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft); - let (bytes_of_vmp_apply_dft_to_dft, tmp_bytes) = - tmp_bytes.split_at_mut(bytes_of_vmp_apply_dft_to_dft); + let mut ai_dft: VecZnxDft = + module.new_vec_znx_dft_from_bytes_borrow(a.cols(), tmp_bytes_ai_dft); + let mut c0_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_c0_dft); + let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_c1_dft); - let mut c0_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c0_dft); - let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c1_dft); - let mut _tmp_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_tmp_dft); - let mut r1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_r1_dft); - let mut _r2_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_r2_dft); + let mut c0_big: VecZnxBig = c0_dft.as_vec_znx_big(); + let mut c1_big: VecZnxBig = c1_dft.as_vec_znx_big(); - // c0_dft <- DFT(a[0]) - module.vec_znx_dft(&mut c0_dft, a.at(0)); + module.vec_znx_dft(&mut ai_dft, a.at(0)); + module.vmp_apply_dft_to_dft(&mut c0_dft, &ai_dft, b.at(0), tmp_bytes); + module.vmp_apply_dft_to_dft(&mut c1_dft, &ai_dft, b.at(1), tmp_bytes); - // r_dft <- sum[rows] c0_dft[cols] x RGSW[0][cols] - module.vmp_apply_dft_to_dft( - &mut r1_dft, - &c1_dft, - &b.0.value[0], - bytes_of_vmp_apply_dft_to_dft, - ); + module.vec_znx_dft(&mut ai_dft, a.at(1)); + module.vmp_apply_dft_to_dft_add(&mut c0_dft, &ai_dft, b.at(2), tmp_bytes); + module.vmp_apply_dft_to_dft_add(&mut c1_dft, &ai_dft, b.at(3), tmp_bytes); - // c1_dft <- DFT(a[1]) - module.vec_znx_dft(&mut c1_dft, a.at(1)); + module.vec_znx_idft_tmp_a(&mut c0_big, &mut c0_dft); + module.vec_znx_idft_tmp_a(&mut c1_big, &mut c1_dft); + + module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(0), &mut c0_big, tmp_bytes); + module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(1), &mut c1_big, tmp_bytes); +} + +pub fn rgsw_product_inplace( + module: &Module, + a: &mut Ciphertext, + b: &Ciphertext, + b_cols: usize, + tmp_bytes: &mut [u8], +) { + #[cfg(debug_assertions)] + { + assert!(b_cols <= b.cols()); + assert_eq!(a.size(), 2); + assert_eq!(b.size(), 4); + assert!( + tmp_bytes.len() + >= rgsw_product_tmp_bytes( + module, + a.cols(), + a.cols(), + min(b.rows(), a.cols()), + b_cols + ) + ); + assert_alignement(tmp_bytes.as_ptr()); + } + + let (tmp_bytes_ai_dft, tmp_bytes) = + tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(a.cols())); + let (tmp_bytes_c0_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols)); + let (tmp_bytes_c1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols)); + + let mut ai_dft: VecZnxDft = + module.new_vec_znx_dft_from_bytes_borrow(a.cols(), tmp_bytes_ai_dft); + let mut c0_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_c0_dft); + let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_c1_dft); + + let mut c0_big: VecZnxBig = c0_dft.as_vec_znx_big(); + let mut c1_big: VecZnxBig = c1_dft.as_vec_znx_big(); + + module.vec_znx_dft(&mut ai_dft, a.at(0)); + module.vmp_apply_dft_to_dft(&mut c0_dft, &ai_dft, b.at(0), tmp_bytes); + module.vmp_apply_dft_to_dft(&mut c1_dft, &ai_dft, b.at(1), tmp_bytes); + + module.vec_znx_dft(&mut ai_dft, a.at(1)); + module.vmp_apply_dft_to_dft_add(&mut c0_dft, &ai_dft, b.at(2), tmp_bytes); + module.vmp_apply_dft_to_dft_add(&mut c1_dft, &ai_dft, b.at(3), tmp_bytes); + + module.vec_znx_idft_tmp_a(&mut c0_big, &mut c0_dft); + module.vec_znx_idft_tmp_a(&mut c1_big, &mut c1_dft); + + module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(0), &mut c0_big, tmp_bytes); + module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(1), &mut c1_big, tmp_bytes); +} + +#[cfg(test)] +mod test { + use crate::{ + ciphertext::{Ciphertext, new_rgsw_ciphertext}, + decryptor::decrypt_rlwe, + elem::ElemCommon, + encryptor::{encrypt_rgsw_sk, encrypt_rlwe_sk}, + keys::SecretKey, + parameters::{DEFAULT_SIGMA, Parameters, ParametersLiteral}, + plaintext::Plaintext, + rgsw_product::rgsw_product_inplace, + }; + use base2k::{ + BACKEND, Encoding, Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxOps, VmpPMat, + alloc_aligned, + }; + use sampling::source::{Source, new_seed}; + + #[test] + fn test_rgsw_product() { + let log_base2k: usize = 10; + let log_q: usize = 50; + let log_p: usize = 15; + + // Basic parameters with enough limbs to test edge cases + let params_lit: ParametersLiteral = ParametersLiteral { + backend: BACKEND::FFT64, + log_n: 12, + log_q: log_q, + log_p: log_p, + log_base2k: log_base2k, + log_scale: 20, + xe: 3.2, + xs: 1 << 11, + }; + + let params: Parameters = Parameters::new(¶ms_lit); + + let module: &Module = params.module(); + let log_q: usize = params.log_q(); + let log_qp: usize = params.log_qp(); + let gct_rows: usize = params.cols_q(); + let gct_cols: usize = params.cols_qp(); + + // scratch space + let mut tmp_bytes: Vec = alloc_aligned( + params.decrypt_rlwe_tmp_byte(log_q) + | params.encrypt_rlwe_sk_tmp_bytes(log_q) + | params.rgsw_product_tmp_bytes(log_q, log_q, log_qp) + | params.encrypt_rgsw_sk_tmp_bytes(gct_rows, log_qp), + ); + + // Samplers for public and private randomness + let mut source_xe: Source = Source::new(new_seed()); + let mut source_xa: Source = Source::new(new_seed()); + let mut source_xs: Source = Source::new(new_seed()); + + let mut sk: SecretKey = SecretKey::new(module); + sk.fill_ternary_hw(params.xs(), &mut source_xs); + let mut sk_svp_ppol: SvpPPol = module.new_svp_ppol(); + module.svp_prepare(&mut sk_svp_ppol, &sk.0); + + let mut ct_rgsw: Ciphertext = + new_rgsw_ciphertext(module, log_base2k, gct_rows, log_qp); + + let k: i64 = 3; + + // X^k + let m: Scalar = module.new_scalar(); + let data: &mut [i64] = m.raw_mut(); + data[k as usize] = 1; + + encrypt_rgsw_sk( + module, + &mut ct_rgsw, + &m, + &sk_svp_ppol, + &mut source_xa, + &mut source_xe, + DEFAULT_SIGMA, + &mut tmp_bytes, + ); + + let log_k: usize = 2 * log_base2k; + + let mut ct: Ciphertext = params.new_ciphertext(log_q); + let mut pt: Plaintext = params.new_plaintext(log_q); + let mut pt_rotate: Plaintext = params.new_plaintext(log_q); + + pt.at_mut(0).encode_vec_i64(log_base2k, log_k, &data, 32); + + module.vec_znx_rotate(k, pt_rotate.at_mut(0), pt.at_mut(0)); + + encrypt_rlwe_sk( + module, + &mut ct.elem_mut(), + Some(pt.at(0)), + &sk_svp_ppol, + &mut source_xa, + &mut source_xe, + params.xe(), + &mut tmp_bytes, + ); + + rgsw_product_inplace(module, &mut ct, &ct_rgsw, gct_cols, &mut tmp_bytes); + + decrypt_rlwe( + module, + pt.elem_mut(), + ct.elem(), + &sk_svp_ppol, + &mut tmp_bytes, + ); + + module.vec_znx_sub_ba_inplace(pt.at_mut(0), pt_rotate.at(0)); + + //pt.at(0).print(pt.cols(), 16); + + let noise_have: f64 = pt.at(0).std(log_base2k).log2(); + + let var_msg: f64 = 1f64 / params.n() as f64; // X^{k} + let var_a0_err: f64 = params.xe() * params.xe(); + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_pred: f64 = + params.noise_rgsw_product(var_msg, var_a0_err, var_a1_err, ct.log_q(), ct_rgsw.log_q()); + + println!("noise_pred: {}", noise_pred); + println!("noise_have: {}", noise_have); + + assert!(noise_have <= noise_pred + 1.0); + } +} + +impl Parameters { + pub fn noise_rgsw_product( + &self, + var_msg: f64, + var_a0_err: f64, + var_a1_err: f64, + a_logq: usize, + b_logq: usize, + ) -> f64 { + let n: f64 = self.n() as f64; + let var_xs: f64 = self.xs() as f64; + + let var_gct_err_lhs: f64; + let var_gct_err_rhs: f64; + if b_logq < self.log_qp() { + let var_round: f64 = 1f64 / 12f64; + var_gct_err_lhs = var_round; + var_gct_err_rhs = var_round; + } else { + var_gct_err_lhs = self.xe() * self.xe(); + var_gct_err_rhs = 0f64; + } + + noise_rgsw_product( + n, + self.log_base2k(), + var_xs, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + a_logq, + b_logq, + ) + } +} + +pub fn noise_rgsw_product( + n: f64, + log_base2k: usize, + var_xs: f64, + var_msg: f64, + var_a0_err: f64, + var_a1_err: f64, + var_gct_err_lhs: f64, + var_gct_err_rhs: f64, + a_logq: usize, + b_logq: usize, +) -> f64 { + let a_logq: usize = min(a_logq, b_logq); + let a_cols: usize = (a_logq + log_base2k - 1) / log_base2k; + + let b_scale = 2.0f64.powi(b_logq as i32); + let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32); + + let base: f64 = (1 << (log_base2k)) as f64; + let var_base: f64 = base * base / 12f64; + + // lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2) + // rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs + let mut noise: f64 = + 2.0 * (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs); + noise += var_msg * var_a0_err * a_scale * a_scale * n; + noise += var_msg * var_a1_err * a_scale * a_scale * n * var_xs; + noise = noise.sqrt(); + noise /= b_scale; + noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}] } diff --git a/rlwe/src/test.rs b/rlwe/src/test.rs new file mode 100644 index 0000000..2a7e9d0 --- /dev/null +++ b/rlwe/src/test.rs @@ -0,0 +1,113 @@ +use base2k::{alloc_aligned, SvpPPol, SvpPPolOps, VecZnx, BACKEND}; +use sampling::source::{Source, new_seed}; +use crate::{ciphertext::Ciphertext, decryptor::decrypt_rlwe, elem::ElemCommon, encryptor::encrypt_rlwe_sk, keys::SecretKey, parameters::{Parameters, ParametersLiteral, DEFAULT_SIGMA}, plaintext::Plaintext}; + + + +pub struct Context{ + pub params: Parameters, + pub sk0: SecretKey, + pub sk0_ppol:SvpPPol, + pub sk1: SecretKey, + pub sk1_ppol: SvpPPol, + pub tmp_bytes: Vec, +} + +impl Context{ + pub fn new(log_n: usize, log_base2k: usize, log_q: usize, log_p: usize) -> Self{ + + let params_lit: ParametersLiteral = ParametersLiteral { + backend: BACKEND::FFT64, + log_n: log_n, + log_q: log_q, + log_p: log_p, + log_base2k: log_base2k, + log_scale: 20, + xe: DEFAULT_SIGMA, + xs: 1 << (log_n-1), + }; + + let params: Parameters =Parameters::new(¶ms_lit); + let module = params.module(); + + let log_q: usize = params.log_q(); + + let mut source_xs: Source = Source::new(new_seed()); + + let mut sk0: SecretKey = SecretKey::new(module); + sk0.fill_ternary_hw(params.xs(), &mut source_xs); + let mut sk0_ppol: base2k::SvpPPol = module.new_svp_ppol(); + module.svp_prepare(&mut sk0_ppol, &sk0.0); + + let mut sk1: SecretKey = SecretKey::new(module); + sk1.fill_ternary_hw(params.xs(), &mut source_xs); + let mut sk1_ppol: base2k::SvpPPol = module.new_svp_ppol(); + module.svp_prepare(&mut sk1_ppol, &sk1.0); + + let tmp_bytes: Vec = alloc_aligned(params.decrypt_rlwe_tmp_byte(log_q)| params.encrypt_rlwe_sk_tmp_bytes(log_q)); + + Context{ + params: params, + sk0: sk0, + sk0_ppol: sk0_ppol, + sk1: sk1, + sk1_ppol: sk1_ppol, + tmp_bytes: tmp_bytes, + + } + } + + pub fn encrypt_rlwe_sk0(&mut self, pt: &Plaintext, ct: &mut Ciphertext){ + + let mut source_xe: Source = Source::new(new_seed()); + let mut source_xa: Source = Source::new(new_seed()); + + encrypt_rlwe_sk( + self.params.module(), + ct.elem_mut(), + Some(pt.elem()), + &self.sk0_ppol, + &mut source_xa, + &mut source_xe, + self.params.xe(), + &mut self.tmp_bytes, + ); + } + + pub fn encrypt_rlwe_sk1(&mut self, ct: &mut Ciphertext, pt: &Plaintext){ + + let mut source_xe: Source = Source::new(new_seed()); + let mut source_xa: Source = Source::new(new_seed()); + + encrypt_rlwe_sk( + self.params.module(), + ct.elem_mut(), + Some(pt.elem()), + &self.sk1_ppol, + &mut source_xa, + &mut source_xe, + self.params.xe(), + &mut self.tmp_bytes, + ); + } + + pub fn decrypt_sk0(&mut self, pt: &mut Plaintext, ct: &Ciphertext){ + decrypt_rlwe( + self.params.module(), + pt.elem_mut(), + ct.elem(), + &self.sk0_ppol, + &mut self.tmp_bytes, + ); + } + + pub fn decrypt_sk1(&mut self, pt: &mut Plaintext, ct: &Ciphertext){ + decrypt_rlwe( + self.params.module(), + pt.elem_mut(), + ct.elem(), + &self.sk1_ppol, + &mut self.tmp_bytes, + ); + } +} \ No newline at end of file diff --git a/rlwe/src/trace.rs b/rlwe/src/trace.rs index a92f70a..85a8212 100644 --- a/rlwe/src/trace.rs +++ b/rlwe/src/trace.rs @@ -1,13 +1,14 @@ -use crate::{automorphism::AutomorphismKey, ciphertext::Ciphertext, elem::ElemCommon, parameters::Parameters}; +use crate::{ + automorphism::AutomorphismKey, ciphertext::Ciphertext, elem::ElemCommon, parameters::Parameters, +}; use base2k::{ - Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMatOps, - assert_alignement, + Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMatOps, assert_alignement, }; use std::collections::HashMap; -pub fn trace_galois_elements(module: &Module) -> Vec{ +pub fn trace_galois_elements(module: &Module) -> Vec { let mut gal_els: Vec = Vec::new(); - (0..module.log_n()).for_each(|i|{ + (0..module.log_n()).for_each(|i| { if i == 0 { gal_els.push(-1); } else { @@ -17,8 +18,8 @@ pub fn trace_galois_elements(module: &Module) -> Vec{ gal_els } -impl Parameters{ - pub fn trace_tmp_bytes(&self, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize{ +impl Parameters { + pub fn trace_tmp_bytes(&self, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize { self.automorphism_tmp_bytes(res_logq, in_logq, gct_logq) } } @@ -49,7 +50,8 @@ pub fn trace_inplace( if let Some((_, key)) = b.iter().next() { b_rows = key.value.rows(); - #[cfg(debug_assertions)]{ + #[cfg(debug_assertions)] + { println!("{} {}", b_cols, key.value.cols()); assert!(b_cols <= key.value.cols()) } @@ -68,10 +70,12 @@ pub fn trace_inplace( let cols: usize = std::cmp::min(b_cols, a.cols()); let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols)); - let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols)); + let (tmp_bytes_res_dft, tmp_bytes) = + tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols)); let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_b1_dft); - let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_res_dft); + let mut res_dft: VecZnxDft = + module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_res_dft); let mut res_big: VecZnxBig = res_dft.as_vec_znx_big(); let log_base2k: usize = a.log_base2k(); @@ -112,16 +116,20 @@ pub fn trace_inplace( #[cfg(test)] mod test { + use super::{trace_galois_elements, trace_inplace}; use crate::{ - automorphism::AutomorphismKey, ciphertext::{new_gadget_ciphertext, Ciphertext}, decryptor::decrypt_rlwe, elem::{Elem, ElemCommon, ElemVecZnx}, encryptor::{encrypt_grlwe_sk, encrypt_rlwe_sk}, gadget_product::gadget_product_core, keys::SecretKey, parameters::{Parameters, ParametersLiteral, DEFAULTSIGMA}, plaintext::Plaintext - }; - use base2k::{ - BACKEND, Module, Infos, Sampling, SvpPPol, SvpPPolOps, VecZnx, - VecZnxDftOps, VecZnxOps, VmpPMat, alloc_aligned, Encoding, + automorphism::AutomorphismKey, + ciphertext::Ciphertext, + decryptor::decrypt_rlwe, + elem::ElemCommon, + encryptor::encrypt_rlwe_sk, + keys::SecretKey, + parameters::{DEFAULT_SIGMA, Parameters, ParametersLiteral}, + plaintext::Plaintext, }; + use base2k::{BACKEND, Encoding, Module, SvpPPol, SvpPPolOps, VecZnx, alloc_aligned}; use sampling::source::{Source, new_seed}; use std::collections::HashMap; - use super::{trace_galois_elements, trace_inplace}; #[test] fn test_trace_inplace() { @@ -169,11 +177,24 @@ mod test { let gal_els: Vec = trace_galois_elements(module); - let auto_keys: HashMap = AutomorphismKey::new_many(module, &gal_els, &sk, log_base2k, gct_rows, log_qp, &mut source_xa, &mut source_xe, DEFAULTSIGMA, &mut tmp_bytes); + let auto_keys: HashMap = AutomorphismKey::new_many( + module, + &gal_els, + &sk, + log_base2k, + gct_rows, + log_qp, + &mut source_xa, + &mut source_xe, + DEFAULT_SIGMA, + &mut tmp_bytes, + ); let mut data: Vec = vec![0i64; params.n()]; - data.iter_mut().enumerate().for_each(|(i, x)| *x = 1+i as i64); + data.iter_mut() + .enumerate() + .for_each(|(i, x)| *x = 1 + i as i64); let log_k: usize = 2 * log_base2k; @@ -190,7 +211,7 @@ mod test { encrypt_rlwe_sk( module, &mut ct.elem_mut(), - Some(pt.elem()), + Some(pt.at(0)), &sk_svp_ppol, &mut source_xa, &mut source_xe, @@ -198,7 +219,16 @@ mod test { &mut tmp_bytes, ); - trace_inplace(module, &mut ct, module.log_n()-2, module.log_n(), &auto_keys, gct_cols, & mut tmp_bytes); + trace_inplace(module, &mut ct, 0, 4, &auto_keys, gct_cols, &mut tmp_bytes); + trace_inplace( + module, + &mut ct, + 4, + module.log_n(), + &auto_keys, + gct_cols, + &mut tmp_bytes, + ); // pt = dec(auto(ct)) - auto(pt) decrypt_rlwe( @@ -214,6 +244,5 @@ mod test { pt.at(0).decode_vec_i64(log_base2k, log_k, &mut data); println!("trace: {:?}", &data[..16]); - } }