From dc2fab9a04e6b4b6dcc0ee4edd804c16a3f8a433 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 12 Feb 2025 11:40:36 +0100 Subject: [PATCH] wip on gadget product --- rlwe/examples/gadget_product.rs | 80 ++++++++++++++++++++------------- rlwe/src/ciphertext.rs | 4 ++ rlwe/src/decryptor.rs | 22 ++++----- rlwe/src/elem.rs | 4 ++ rlwe/src/encryptor.rs | 11 ++++- rlwe/src/evaluator.rs | 10 ++--- rlwe/src/keys.rs | 23 ++++++---- rlwe/src/plaintext.rs | 4 ++ 8 files changed, 102 insertions(+), 56 deletions(-) diff --git a/rlwe/examples/gadget_product.rs b/rlwe/examples/gadget_product.rs index 1a5e61d..928cf09 100644 --- a/rlwe/examples/gadget_product.rs +++ b/rlwe/examples/gadget_product.rs @@ -1,9 +1,9 @@ use base2k::{Encoding, FFT64, SvpPPolOps}; use rlwe::{ ciphertext::Ciphertext, - decryptor::{Decryptor, decrypt_rlwe_thread_safe_tmp_byte}, - encryptor::{EncryptorSk, encrypt_rlwe_sk_tmp_bytes}, - evaluator::{gadget_product_inplace, gadget_product_tmp_bytes}, + decryptor::{decrypt_rlwe_thread_safe_tmp_byte, Decryptor}, + encryptor::{encrypt_grlwe_sk_tmp_bytes, encrypt_rlwe_sk_tmp_bytes, EncryptorSk}, + evaluator::{gadget_product_thread_safe, gadget_product_inplace_thread_safe, gadget_product_tmp_bytes}, key_generator::{gen_switching_key_thread_safe, gen_switching_key_thread_safe_tmp_bytes}, keys::{SecretKey, SwitchingKey}, parameters::{Parameters, ParametersLiteral}, @@ -14,7 +14,7 @@ use sampling::source::{Source, new_seed}; fn main() { let params_lit: ParametersLiteral = ParametersLiteral { log_n: 10, - log_q: 54, + log_q: 68, log_p: 17, log_base2k: 17, log_scale: 20, @@ -40,11 +40,13 @@ fn main() { params.log_q(), params.log_q(), params.limbs_q(), - params.limbs_qp() - ) + params.log_qp() + ) | encrypt_grlwe_sk_tmp_bytes(params.module(), params.log_base2k(), params.limbs_qp(), params.log_qp()) ]; - let mut source: Source = Source::new([0; 32]); + println!("limbsQP: {}", params.limbs_qp()); + + let mut source: Source = Source::new([3; 32]); let mut sk0: SecretKey = SecretKey::new(params.module()); let mut sk1: SecretKey = SecretKey::new(params.module()); @@ -52,26 +54,19 @@ fn main() { sk0.fill_ternary_hw(params.xs(), &mut source); sk1.fill_ternary_hw(params.xs(), &mut source); + let mut want = vec![i64::default(); params.n()]; want.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); - let mut pt: Plaintext = params.new_plaintext(params.log_q()); + let log_base2k = params.log_base2k(); - let log_base2k = pt.log_base2k(); + let log_k: usize = params.log_q() - 2*log_base2k; + + let mut ct: Ciphertext = params.new_ciphertext(params.log_qp()); - let log_k: usize = params.log_q() - 20; - - pt.0.value[0].encode_vec_i64(log_base2k, log_k, &want, 32); - pt.0.value[0].normalize(log_base2k, &mut tmp_bytes); - - println!("log_k: {}", log_k); - pt.0.value[0].print_limbs(pt.limbs(), 16); - - let mut ct: Ciphertext = params.new_ciphertext(params.log_q()); - - let mut source_xe: Source = Source::new([1; 32]); - let mut source_xa: Source = Source::new([2; 32]); + let mut source_xe: Source = Source::new([4; 32]); + let mut source_xa: Source = Source::new([5; 32]); let mut sk0_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol(); params.module().svp_prepare(&mut sk0_svp_ppol, &sk0.0); @@ -79,19 +74,30 @@ fn main() { let mut sk1_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol(); params.module().svp_prepare(&mut sk1_svp_ppol, &sk1.0); + let mut pt_out: Plaintext = Plaintext::new(params.module(), params.log_base2k(), params.log_qp(), ct.log_scale()); + + //pt_out.0.value[0].encode_vec_i64(log_base2k, log_k, &want, 32); + //pt_out.0.value[0].normalize(log_base2k, &mut tmp_bytes); + params.encrypt_rlwe_sk_thread_safe( &mut ct, - Some(&pt), + Some(&pt_out), &sk0_svp_ppol, &mut source_xa, &mut source_xe, &mut tmp_bytes, ); + params.decrypt_rlwe_thread_safe(&mut pt_out, &ct, &sk0_svp_ppol, &mut tmp_bytes); + + println!("DECRYPT"); + pt_out.0.value[0].print_limbs(pt_out.limbs(), 16); + + let mut swk: SwitchingKey = SwitchingKey::new( params.module(), params.log_base2k(), - params.limbs_q(), + params.limbs_qp(), params.log_qp(), ); @@ -106,17 +112,31 @@ fn main() { &mut tmp_bytes, ); - gadget_product_inplace(params.module(), &mut ct, &swk.0, &mut tmp_bytes); + println!("{}", swk.cols()); - params.decrypt_rlwe_thread_safe(&mut pt, &ct, &sk1_svp_ppol, &mut tmp_bytes); + let mut ct_out: Ciphertext = Ciphertext::new(params.module(), params.log_base2k(), params.log_q()+17, 1, ct.log_scale()); - pt.0.value[0].print_limbs(pt.limbs(), 16); + gadget_product_thread_safe(params.module(), &mut ct_out, &ct, &swk.0, &mut tmp_bytes); + + pt_out.zero(); + + params.decrypt_rlwe_thread_safe(&mut pt_out, &ct_out, &sk1_svp_ppol, &mut tmp_bytes); + + pt_out.0.value[0].print_limbs(pt_out.limbs(), 16); let mut have = vec![i64::default(); params.n()]; - println!("pt: {}", log_k); - pt.0.value[0].decode_vec_i64(pt.log_base2k(), log_k, &mut have); + //println!("pt_out: {}", log_k); + //pt_out.0.value[0].decode_vec_i64(pt_out.log_base2k(), log_k, &mut have); - println!("want: {:?}", &want[..16]); - println!("have: {:?}", &have[..16]); + //println!("want: {:?}", &want[..16]); + //println!("have: {:?}", &have[..16]); + } + + +pub fn cast_mut_u64_to_mut_u8_slice(data: &mut [u64]) -> &mut [u8] { + let ptr: *mut u8 = data.as_mut_ptr() as *mut u8; + let len: usize = data.len() * std::mem::size_of::(); + unsafe { std::slice::from_raw_parts_mut(ptr, len) } +} \ No newline at end of file diff --git a/rlwe/src/ciphertext.rs b/rlwe/src/ciphertext.rs index 3daa076..5d72cf6 100644 --- a/rlwe/src/ciphertext.rs +++ b/rlwe/src/ciphertext.rs @@ -48,6 +48,10 @@ impl Ciphertext { self.0.log_scale } + pub fn zero(&mut self){ + self.0.zero() + } + pub fn as_plaintext(&self) -> Plaintext { unsafe { Plaintext(std::ptr::read(&self.0)) } } diff --git a/rlwe/src/decryptor.rs b/rlwe/src/decryptor.rs index 727dba5..027b8bf 100644 --- a/rlwe/src/decryptor.rs +++ b/rlwe/src/decryptor.rs @@ -1,5 +1,5 @@ use crate::{ - ciphertext::Ciphertext, keys::SecretKey, parameters::Parameters, plaintext::Plaintext, + ciphertext::{Ciphertext, GadgetCiphertext}, elem::Elem, keys::SecretKey, parameters::Parameters, plaintext::Plaintext }; use base2k::{Module, SvpPPol, SvpPPolOps, VecZnxDft}; use std::cmp::min; @@ -35,18 +35,18 @@ impl Parameters { sk: &SvpPPol, tmp_bytes: &mut [u8], ) { - decrypt_rlwe_thread_safe(self.module(), res, ct, sk, tmp_bytes) + decrypt_rlwe_thread_safe(self.module(), &mut res.0, &ct.0, sk, tmp_bytes) } } pub fn decrypt_rlwe_thread_safe( module: &Module, - res: &mut Plaintext, - ct: &Ciphertext, + res: &mut Elem, + a: &Elem, sk: &SvpPPol, tmp_bytes: &mut [u8], ) { - let limbs: usize = min(res.limbs(), ct.limbs()); + let limbs: usize = min(res.limbs(), a.limbs()); assert!( tmp_bytes.len() >= decrypt_rlwe_thread_safe_tmp_byte(module, limbs), @@ -61,20 +61,20 @@ pub fn decrypt_rlwe_thread_safe( let mut res_big: base2k::VecZnxBig = res_dft.as_vec_znx_big(); // res_dft <- DFT(ct[1]) * DFT(sk) - module.svp_apply_dft(&mut res_dft, sk, &ct.0.value[1], limbs); + module.svp_apply_dft(&mut res_dft, sk, &a.value[1], limbs); // res_big <- ct[1] x sk module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft, limbs); // res_big <- ct[1] x sk + ct[0] - module.vec_znx_big_add_small_inplace(&mut res_big, &ct.0.value[0], limbs); + module.vec_znx_big_add_small_inplace(&mut res_big, &a.value[0], limbs); // res <- normalize(ct[1] x sk + ct[0]) module.vec_znx_big_normalize( - ct.log_base2k(), + a.log_base2k(), res.at_mut(0), &res_big, &mut tmp_bytes[res_dft_bytes..], ); - res.0.log_base2k = ct.log_base2k(); - res.0.log_q = min(res.log_q(), ct.log_q()); - res.0.log_scale = ct.log_scale(); + res.log_base2k = a.log_base2k(); + res.log_q = min(res.log_q(), a.log_q()); + res.log_scale = a.log_scale(); } diff --git a/rlwe/src/elem.rs b/rlwe/src/elem.rs index 570833c..53fb0c5 100644 --- a/rlwe/src/elem.rs +++ b/rlwe/src/elem.rs @@ -90,6 +90,10 @@ impl Elem { assert!(i <= self.degree()); &mut self.value[i] } + + pub fn zero(&mut self){ + self.value.iter_mut().for_each(|i| i.zero()); + } } impl Parameters { diff --git a/rlwe/src/encryptor.rs b/rlwe/src/encryptor.rs index 9e57c15..c5043d0 100644 --- a/rlwe/src/encryptor.rs +++ b/rlwe/src/encryptor.rs @@ -1,4 +1,5 @@ use crate::ciphertext::{Ciphertext, GadgetCiphertext}; +use crate::decryptor::decrypt_rlwe_thread_safe; use crate::elem::Elem; use crate::keys::SecretKey; use crate::parameters::Parameters; @@ -174,7 +175,7 @@ pub fn encrypt_rlwe_sk_thread_safe( } // c0 <- -s x c1 + m + e - c0.add_normal(log_base2k, log_q, source_xe, sigma, (sigma * 6.0).ceil()); + //c0.add_normal(log_base2k, log_q, source_xe, sigma, (sigma * 6.0).ceil()); } pub fn encrypt_grlwe_sk_tmp_bytes( @@ -244,6 +245,14 @@ pub fn encrypt_grlwe_sk_thread_safe( // Zeroes the ith-row of tmp_pt tmp_pt.0.value[0].at_mut(row_i).fill(0); + let mut res: Elem = Elem::new(module, log_base2k, log_q, 0, tmp_elem.log_scale); + + decrypt_rlwe_thread_safe(module, &mut res, &tmp_elem, sk, tmp_bytes_encrypt_sk); + + println!("row:{}", row_i); + res.value[0].print_limbs(res.limbs(), 16); + println!(); + // GRLWE[row_i][0] = -as + m * 2^{-i*log_base2k} + e*2^{-log_q} module.vmp_prepare_row( &mut ct.value[0], diff --git a/rlwe/src/evaluator.rs b/rlwe/src/evaluator.rs index af9c041..9e27bdb 100644 --- a/rlwe/src/evaluator.rs +++ b/rlwe/src/evaluator.rs @@ -16,7 +16,7 @@ pub fn gadget_product_tmp_bytes( + 2 * module.bytes_of_vec_znx_dft(gct_cols) } -pub fn gadget_product_inplace( +pub fn gadget_product_inplace_thread_safe( module: &Module, a: &mut Ciphertext, b: &GadgetCiphertext, @@ -26,11 +26,11 @@ pub fn gadget_product_inplace( // overwritten. unsafe { let a_ptr: *mut Ciphertext = a; - gadget_product(module, a, &*a_ptr, b, tmp_bytes) + gadget_product_thread_safe(module, a, &*a_ptr, b, tmp_bytes) } } -pub fn gadget_product( +pub fn gadget_product_thread_safe( module: &Module, res: &mut Ciphertext, a: &Ciphertext, @@ -56,10 +56,10 @@ pub fn gadget_product( let mut res_dft: VecZnxDft = module.new_vec_znx_from_bytes(cols, tmp_bytes_res_dft); let mut res_big: VecZnxBig = res_dft.as_vec_znx_big(); - // c1_dft <- DFT(b[1]) + // c1_dft <- DFT(c1) [cols] module.vec_znx_dft(&mut c1_dft, a.at(1), a.limbs()); - // res_dft <- DFT(c1) x GadgetCiphertext[0] + // res_dft <- sum[rows] DFT(c1)[cols] x GadgetCiphertext[0][cols] module.vmp_apply_dft_to_dft(&mut res_dft, &c1_dft, &b.value[0], tmp_bytes_vmp_apply_dft); // res_big <- IDFT(DFT(c1) x GadgetCiphertext[0]) diff --git a/rlwe/src/keys.rs b/rlwe/src/keys.rs index 9955596..6675028 100644 --- a/rlwe/src/keys.rs +++ b/rlwe/src/keys.rs @@ -65,14 +65,19 @@ impl SwitchingKey { SwitchingKey(GadgetCiphertext::new(module, log_base2k, rows, log_q)) } - pub fn gen_thread_safe( - &mut self, - params: &mut Parameters, - sk_in: &SvpPPol, - sk_out: &SvpPPol, - xa_source: &mut Source, - xe_source: &mut Source, - tmp_bytes: &mut [u8], - ) { + pub fn n(&self) -> usize{ + self.0.n() + } + + pub fn rows(&self) -> usize{ + self.0.rows() + } + + pub fn cols(&self) -> usize{ + self.0.cols() + } + + pub fn log_base2k(&self) -> usize{ + self.0.log_base2k() } } diff --git a/rlwe/src/plaintext.rs b/rlwe/src/plaintext.rs index 174a07d..30fd85a 100644 --- a/rlwe/src/plaintext.rs +++ b/rlwe/src/plaintext.rs @@ -64,6 +64,10 @@ impl Plaintext { self.0.log_scale() } + pub fn zero(&mut self){ + self.0.zero() + } + pub fn as_ciphertext(&self) -> Ciphertext { unsafe { Ciphertext(std::ptr::read(&self.0)) } }