From 1f52a3d266870e419ab09c95654003321eb1f75b Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 12 Feb 2025 08:25:38 +0100 Subject: [PATCH] fixed sampling & rlwe encryption --- base2k/src/sampling.rs | 2 +- rlwe/examples/encryption.rs | 10 +++++--- rlwe/examples/gadget_product.rs | 10 ++++++-- rlwe/src/encryptor.rs | 23 +++++++++++------ rlwe/src/key_generator.rs | 45 +++++++++++++++++++-------------- rlwe/src/keys.rs | 6 ++--- 6 files changed, 60 insertions(+), 36 deletions(-) diff --git a/base2k/src/sampling.rs b/base2k/src/sampling.rs index 2498825..14e5dd6 100644 --- a/base2k/src/sampling.rs +++ b/base2k/src/sampling.rs @@ -33,7 +33,7 @@ impl Sampling for VecZnx { let mask: u64 = base2k - 1; let base2k_half: i64 = (base2k >> 1) as i64; - let size: usize = self.n() * (limbs - 1); + let size: usize = self.n() * limbs; self.data[..size] .iter_mut() diff --git a/rlwe/examples/encryption.rs b/rlwe/examples/encryption.rs index 40397f4..c215adc 100644 --- a/rlwe/examples/encryption.rs +++ b/rlwe/examples/encryption.rs @@ -28,7 +28,10 @@ fn main() { | params.encrypt_rlwe_sk_tmp_bytes(params.log_q()) ]; - let sk: SecretKey = SecretKey::new(params.module()); + let mut source: Source = Source::new([0; 32]); + let mut sk: SecretKey = SecretKey::new(params.module()); + //sk.fill_ternary_hw(params.xs(), &mut source); + sk.0.0[0] = 1; let mut want = vec![i64::default(); params.n()]; @@ -45,11 +48,12 @@ fn main() { println!("log_k: {}", log_k); pt.0.value[0].print_limbs(pt.limbs(), 16); + println!(); let mut ct: Ciphertext = params.new_ciphertext(params.log_q()); - let mut source_xe: Source = Source::new(new_seed()); - let mut source_xa: Source = Source::new(new_seed()); + let mut source_xe: Source = Source::new([1; 32]); + let mut source_xa: Source = Source::new([2; 32]); let mut sk_svp_ppol: base2k::SvpPPol = params.module().svp_new_ppol(); params.module().svp_prepare(&mut sk_svp_ppol, &sk.0); diff --git a/rlwe/examples/gadget_product.rs b/rlwe/examples/gadget_product.rs index 40397f4..04b6aa5 100644 --- a/rlwe/examples/gadget_product.rs +++ b/rlwe/examples/gadget_product.rs @@ -28,7 +28,13 @@ fn main() { | params.encrypt_rlwe_sk_tmp_bytes(params.log_q()) ]; - let sk: SecretKey = SecretKey::new(params.module()); + let mut source: Source = Source::new([0; 32]); + + let mut sk0: SecretKey = SecretKey::new(params.module()); + let mut sk1: SecretKey = SecretKey::new(params.module()); + + sk0.fill_ternary_hw(params.xs(), &mut source); + sk1.fill_ternary_hw(params.xs(), &mut source); let mut want = vec![i64::default(); params.n()]; @@ -52,7 +58,7 @@ fn main() { let mut source_xa: Source = Source::new(new_seed()); let mut sk_svp_ppol: base2k::SvpPPol = params.module().svp_new_ppol(); - params.module().svp_prepare(&mut sk_svp_ppol, &sk.0); + params.module().svp_prepare(&mut sk_svp_ppol, &sk0.0); params.encrypt_rlwe_sk_thread_safe( &mut ct, diff --git a/rlwe/src/encryptor.rs b/rlwe/src/encryptor.rs index 06ecf0f..5b21c7f 100644 --- a/rlwe/src/encryptor.rs +++ b/rlwe/src/encryptor.rs @@ -145,28 +145,35 @@ pub fn encrypt_rlwe_sk_thread_safe( // c1 <- Z_{2^prec}[X]/(X^{N}+1) c1.fill_uniform(log_base2k, limbs, source_xa); - let bytes_of_vec_znx_dft = module.bytes_of_vec_znx_dft(limbs); + let bytes_of_vec_znx_dft: usize = module.bytes_of_vec_znx_dft(limbs); // Scratch space for DFT values let mut buf_dft: VecZnxDft = VecZnxDft::from_bytes(limbs, &mut tmp_bytes[..bytes_of_vec_znx_dft]); - // Applies buf_dft <- s * c1 + // Applies buf_dft <- DFT(s) * DFT(c1) module.svp_apply_dft(&mut buf_dft, sk, c1, limbs); // Alias scratch space let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big(); - if let Some(pt) = pt { - // buf_big <- m - buf_big - module.vec_znx_big_sub_small_a_inplace(&mut buf_big, pt.at(0)); - }; + // buf_big = s x c1 + module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft, limbs); let carry: &mut [u8] = &mut tmp_bytes[bytes_of_vec_znx_dft..]; - // c0 <- normalize(buf_big) + e + // c0 <- -s x c1 + m let c0: &mut VecZnx = ct.at_mut(0); - module.vec_znx_big_normalize(log_base2k, c0, &buf_big, carry); + + if let Some(pt) = pt { + module.vec_znx_big_sub_small_a_inplace(&mut buf_big, pt.at(0)); + module.vec_znx_big_normalize(log_base2k, c0, &buf_big, carry); + } else { + module.vec_znx_big_normalize(log_base2k, c0, &buf_big, carry); + module.vec_znx_negate_inplace(c0); + } + + // c0 <- -s x c1 + m + e c0.add_normal(log_base2k, log_q, source_xe, sigma, (sigma * 6.0).ceil()); } diff --git a/rlwe/src/key_generator.rs b/rlwe/src/key_generator.rs index 00710ba..06c09ce 100644 --- a/rlwe/src/key_generator.rs +++ b/rlwe/src/key_generator.rs @@ -1,7 +1,7 @@ -use crate::encryptor::encrypt_grlwe_sk_thread_safe; +use crate::encryptor::{encrypt_grlwe_sk_thread_safe, encrypt_grlwe_sk_tmp_bytes}; use crate::keys::{PublicKey, SecretKey, SwitchingKey}; use crate::parameters::Parameters; -use base2k::SvpPPol; +use base2k::{Module, SvpPPol, SvpPPolOps}; use sampling::source::Source; pub struct KeyGenerator {} @@ -38,21 +38,28 @@ impl KeyGenerator { ); pk } - - pub fn gen_switching_key_thread_safe( - &self, - params: &Parameters, - sk_in: &SecretKey, - sk_out: &SecretKey, - rows: usize, - log_q: usize, - tmp_bytes: &mut [u8], - ) -> SwitchingKey { - let swk: SwitchingKey = SwitchingKey::new(params.module(), params.log_base2k(), rows, log_q, 0); - - let module: &base2k::Module = params.module(); - - encrypt_grlwe_sk_thread_safe(module, swk.0, &sk_in.0, sk_out, source_xa, source_xe, sigma, tmp_bytes); - swk - } +} + +pub fn gen_switching_key_thread_safe_tmp_bytes( + module: &Module, + log_base2k: usize, + rows: usize, + log_q: usize, +) -> usize { + encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q) +} + +pub fn gen_switching_key_thread_safe( + module: &Module, + swk: &mut SwitchingKey, + sk_in: &SecretKey, + sk_out: &SvpPPol, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + tmp_bytes: &mut [u8], +) { + encrypt_grlwe_sk_thread_safe( + module, &mut swk.0, &sk_in.0, sk_out, source_xa, source_xe, sigma, tmp_bytes, + ); } diff --git a/rlwe/src/keys.rs b/rlwe/src/keys.rs index cf4edc0..9247eb4 100644 --- a/rlwe/src/keys.rs +++ b/rlwe/src/keys.rs @@ -8,8 +8,8 @@ use sampling::source::Source; pub struct SecretKey(pub Scalar); impl SecretKey { - pub fn new(params: &Module) -> Self { - SecretKey(Scalar::new(params.n())) + pub fn new(module: &Module) -> Self { + SecretKey(Scalar::new(module.n())) } pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) { @@ -58,7 +58,7 @@ impl PublicKey { } } -pub struct SwitchingKey(GadgetCiphertext); +pub struct SwitchingKey(pub GadgetCiphertext); impl SwitchingKey { pub fn new(