From 7c25ad2eba1a58f876452e11122f9c0acac571f5 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 12 Feb 2025 16:49:53 +0100 Subject: [PATCH] fixed gadget product & related example --- base2k/examples/rlwe_encrypt.rs | 2 +- base2k/src/vec_znx_big.rs | 68 +++------------- rlwe/examples/gadget_product.rs | 139 +++++++++++++++++--------------- rlwe/src/ciphertext.rs | 2 +- rlwe/src/decryptor.rs | 22 ++--- rlwe/src/elem.rs | 2 +- rlwe/src/encryptor.rs | 4 +- rlwe/src/evaluator.rs | 51 +++++------- rlwe/src/keys.rs | 8 +- rlwe/src/plaintext.rs | 2 +- 10 files changed, 128 insertions(+), 172 deletions(-) diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index 085a379..be432fb 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -71,7 +71,7 @@ fn main() { module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft, b.limbs()); // buf_big <- a * s + b - module.vec_znx_big_add_small_inplace(&mut buf_big, &b, b.limbs()); + module.vec_znx_big_add_small_inplace(&mut buf_big, &b); // res <- normalize(buf_big) module.vec_znx_big_normalize(log_base2k, &mut res, &buf_big, &mut carry); diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index 461254d..a80a0fc 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -37,20 +37,13 @@ impl Module { // b <- b - a pub fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) { - let limbs: usize = a.limbs(); - assert!( - b.limbs() >= limbs, - "invalid c_vector: b.limbs()={} < a.limbs()={}", - b.limbs(), - limbs - ); unsafe { vec_znx_big::vec_znx_big_sub_small_a( self.0, b.0, b.limbs() as u64, a.as_ptr(), - limbs as u64, + a.limbs() as u64, a.n() as u64, b.0, b.limbs() as u64, @@ -60,26 +53,13 @@ impl Module { // c <- b - a pub fn vec_znx_big_sub_small_a(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig) { - let limbs: usize = a.limbs(); - assert!( - b.limbs() >= limbs, - "invalid c: b.limbs()={} < a.limbs()={}", - b.limbs(), - limbs - ); - assert!( - c.limbs() >= limbs, - "invalid c: c.limbs()={} < a.limbs()={}", - c.limbs(), - limbs - ); unsafe { vec_znx_big::vec_znx_big_sub_small_a( self.0, c.0, c.limbs() as u64, a.as_ptr(), - limbs as u64, + a.limbs() as u64, a.n() as u64, b.0, b.limbs() as u64, @@ -89,50 +69,31 @@ impl Module { // c <- b + a pub fn vec_znx_big_add_small(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig) { - let limbs: usize = a.limbs(); - assert!( - b.limbs() >= limbs, - "invalid c: b.limbs()={} < a.limbs()={}", - b.limbs(), - limbs - ); - assert!( - c.limbs() >= limbs, - "invalid c: c.limbs()={} < a.limbs()={}", - c.limbs(), - limbs - ); unsafe { vec_znx_big::vec_znx_big_add_small( self.0, c.0, - limbs as u64, + c.limbs() as u64, b.0, - limbs as u64, + a.limbs() as u64, a.as_ptr(), - limbs as u64, + b.limbs() as u64, a.n() as u64, ) } } // b <- b + a - pub fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx, a_limbs: usize) { - assert!( - b.limbs() >= a_limbs, - "invalid c_vector: b.limbs()={} < a.limbs()={}", - b.limbs(), - a_limbs - ); + pub fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) { unsafe { vec_znx_big::vec_znx_big_add_small( self.0, b.0, - a_limbs as u64, + b.limbs() as u64, b.0, - a_limbs as u64, + a.limbs() as u64, a.as_ptr(), - a_limbs as u64, + a.limbs() as u64, a.n() as u64, ) } @@ -150,13 +111,6 @@ impl Module { a: &VecZnxBig, tmp_bytes: &mut [u8], ) { - let limbs: usize = b.limbs(); - assert!( - b.limbs() >= limbs, - "invalid c_vector: b.limbs()={} < a.limbs()={}", - b.limbs(), - limbs - ); assert!( tmp_bytes.len() >= self.vec_znx_big_normalize_tmp_bytes(), "invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_normalize_tmp_bytes()={}", @@ -168,10 +122,10 @@ impl Module { self.0, log_base2k as u64, b.as_mut_ptr(), - limbs as u64, + b.limbs() as u64, b.n() as u64, a.0, - limbs as u64, + a.limbs() as u64, tmp_bytes.as_mut_ptr(), ) } diff --git a/rlwe/examples/gadget_product.rs b/rlwe/examples/gadget_product.rs index 928cf09..b597a50 100644 --- a/rlwe/examples/gadget_product.rs +++ b/rlwe/examples/gadget_product.rs @@ -1,9 +1,13 @@ -use base2k::{Encoding, FFT64, SvpPPolOps}; +use base2k::{FFT64, Infos, Sampling, Scalar, SvpPPolOps, VecZnx, VecZnxBig, VecZnxDft, VecZnxOps}; use rlwe::{ - ciphertext::Ciphertext, - decryptor::{decrypt_rlwe_thread_safe_tmp_byte, Decryptor}, - encryptor::{encrypt_grlwe_sk_tmp_bytes, encrypt_rlwe_sk_tmp_bytes, EncryptorSk}, - evaluator::{gadget_product_thread_safe, gadget_product_inplace_thread_safe, gadget_product_tmp_bytes}, + ciphertext::{Ciphertext, GadgetCiphertext}, + decryptor::{Decryptor, decrypt_rlwe_thread_safe, decrypt_rlwe_thread_safe_tmp_byte}, + elem::Elem, + encryptor::{ + EncryptorSk, encrypt_grlwe_sk_thread_safe, encrypt_grlwe_sk_tmp_bytes, + encrypt_rlwe_sk_tmp_bytes, + }, + evaluator::{gadget_product_thread_safe, gadget_product_tmp_bytes}, key_generator::{gen_switching_key_thread_safe, gen_switching_key_thread_safe_tmp_bytes}, keys::{SecretKey, SwitchingKey}, parameters::{Parameters, ParametersLiteral}, @@ -41,19 +45,20 @@ fn main() { params.log_q(), params.limbs_q(), params.log_qp() - ) | encrypt_grlwe_sk_tmp_bytes(params.module(), params.log_base2k(), params.limbs_qp(), params.log_qp()) + ) + | encrypt_grlwe_sk_tmp_bytes( + params.module(), + params.log_base2k(), + params.limbs_qp(), + params.log_qp() + ) ]; - println!("limbsQP: {}", params.limbs_qp()); - let mut source: Source = Source::new([3; 32]); - let mut sk0: SecretKey = SecretKey::new(params.module()); - let mut sk1: SecretKey = SecretKey::new(params.module()); - - sk0.fill_ternary_hw(params.xs(), &mut source); - sk1.fill_ternary_hw(params.xs(), &mut source); + let mut sk: SecretKey = SecretKey::new(params.module()); + sk.fill_ternary_hw(params.xs(), &mut source); let mut want = vec![i64::default(); params.n()]; @@ -61,82 +66,82 @@ fn main() { let log_base2k = params.log_base2k(); - let log_k: usize = params.log_q() - 2*log_base2k; - - let mut ct: Ciphertext = params.new_ciphertext(params.log_qp()); + let log_k: usize = params.log_q() - 2 * log_base2k; let mut source_xe: Source = Source::new([4; 32]); let mut source_xa: Source = Source::new([5; 32]); - let mut sk0_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol(); - params.module().svp_prepare(&mut sk0_svp_ppol, &sk0.0); + let mut sk_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol(); + params.module().svp_prepare(&mut sk_svp_ppol, &sk.0); - let mut sk1_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol(); - params.module().svp_prepare(&mut sk1_svp_ppol, &sk1.0); - - let mut pt_out: Plaintext = Plaintext::new(params.module(), params.log_base2k(), params.log_qp(), ct.log_scale()); - - //pt_out.0.value[0].encode_vec_i64(log_base2k, log_k, &want, 32); - //pt_out.0.value[0].normalize(log_base2k, &mut tmp_bytes); - - params.encrypt_rlwe_sk_thread_safe( - &mut ct, - Some(&pt_out), - &sk0_svp_ppol, - &mut source_xa, - &mut source_xe, - &mut tmp_bytes, - ); - - params.decrypt_rlwe_thread_safe(&mut pt_out, &ct, &sk0_svp_ppol, &mut tmp_bytes); - - println!("DECRYPT"); - pt_out.0.value[0].print_limbs(pt_out.limbs(), 16); - - - let mut swk: SwitchingKey = SwitchingKey::new( + let mut gadget_ct: GadgetCiphertext = GadgetCiphertext::new( params.module(), - params.log_base2k(), - params.limbs_qp(), + log_base2k, + params.limbs_q(), params.log_qp(), ); - gen_switching_key_thread_safe( + let mut m: Scalar = Scalar::new(params.n()); + m.fill_ternary_prob(0.5, &mut source_xa); + + encrypt_grlwe_sk_thread_safe( params.module(), - &mut swk, - &sk0, - &sk1_svp_ppol, + &mut gadget_ct, + &m, + &sk_svp_ppol, &mut source_xa, &mut source_xe, params.xe(), &mut tmp_bytes, ); - println!("{}", swk.cols()); + let mut res: Elem = Elem::new(params.module(), log_base2k, params.log_q(), 1, 0); + let mut a: VecZnx = VecZnx::new(params.module().n(), params.limbs_q()); + a.fill_uniform(params.log_base2k(), a.limbs(), &mut source_xa); + gadget_product_thread_safe(params.module(), &mut res, &a, &gadget_ct, &mut tmp_bytes); - let mut ct_out: Ciphertext = Ciphertext::new(params.module(), params.log_base2k(), params.log_q()+17, 1, ct.log_scale()); + println!("a.limbs()={}", a.limbs()); + println!("gadget_ct.rows()={}", gadget_ct.rows()); + println!("gadget_ct.cols()={}", gadget_ct.cols()); + println!("res.limbs()={}", res.limbs()); + println!(); - gadget_product_thread_safe(params.module(), &mut ct_out, &ct, &swk.0, &mut tmp_bytes); + println!("a:"); + a.print_limbs(a.limbs(), 16); + println!(); - pt_out.zero(); + println!("m:"); + println!("{:?}", &m.0[..16]); + println!(); - params.decrypt_rlwe_thread_safe(&mut pt_out, &ct_out, &sk1_svp_ppol, &mut tmp_bytes); + let mut a_res: Elem = Elem::new(params.module(), params.log_base2k(), params.log_q(), 0, 0); - pt_out.0.value[0].print_limbs(pt_out.limbs(), 16); + decrypt_rlwe_thread_safe( + params.module(), + &mut a_res, + &res, + &sk_svp_ppol, + &mut tmp_bytes, + ); - let mut have = vec![i64::default(); params.n()]; + let mut m_svp_ppol = params.module().new_svp_ppol(); + params.module().svp_prepare(&mut m_svp_ppol, &m); - //println!("pt_out: {}", log_k); - //pt_out.0.value[0].decode_vec_i64(pt_out.log_base2k(), log_k, &mut have); + let mut a_dft: VecZnxDft = params.module().new_vec_znx_dft(a.limbs()); + let mut a_big: VecZnxBig = a_dft.as_vec_znx_big(); - //println!("want: {:?}", &want[..16]); - //println!("have: {:?}", &have[..16]); - + params + .module() + .svp_apply_dft(&mut a_dft, &m_svp_ppol, &a, a.limbs()); + params + .module() + .vec_znx_idft_tmp_a(&mut a_big, &mut a_dft, a.limbs()); + params + .module() + .vec_znx_big_normalize(params.log_base2k(), &mut a, &a_big, &mut tmp_bytes); + + params.module().vec_znx_sub_inplace(&mut a, &a_res.value[0]); + + println!("a*m - dec(a * GRLWE(m))"); + a.print_limbs(a.limbs(), 16); } - - -pub fn cast_mut_u64_to_mut_u8_slice(data: &mut [u64]) -> &mut [u8] { - let ptr: *mut u8 = data.as_mut_ptr() as *mut u8; - let len: usize = data.len() * std::mem::size_of::(); - unsafe { std::slice::from_raw_parts_mut(ptr, len) } -} \ No newline at end of file diff --git a/rlwe/src/ciphertext.rs b/rlwe/src/ciphertext.rs index 5d72cf6..815e560 100644 --- a/rlwe/src/ciphertext.rs +++ b/rlwe/src/ciphertext.rs @@ -48,7 +48,7 @@ impl Ciphertext { self.0.log_scale } - pub fn zero(&mut self){ + pub fn zero(&mut self) { self.0.zero() } diff --git a/rlwe/src/decryptor.rs b/rlwe/src/decryptor.rs index 027b8bf..6b74059 100644 --- a/rlwe/src/decryptor.rs +++ b/rlwe/src/decryptor.rs @@ -1,5 +1,9 @@ use crate::{ - ciphertext::{Ciphertext, GadgetCiphertext}, elem::Elem, keys::SecretKey, parameters::Parameters, plaintext::Plaintext + ciphertext::{Ciphertext, GadgetCiphertext}, + elem::Elem, + keys::SecretKey, + parameters::Parameters, + plaintext::Plaintext, }; use base2k::{Module, SvpPPol, SvpPPolOps, VecZnxDft}; use std::cmp::min; @@ -46,26 +50,24 @@ pub fn decrypt_rlwe_thread_safe( sk: &SvpPPol, tmp_bytes: &mut [u8], ) { - let limbs: usize = min(res.limbs(), a.limbs()); - assert!( - tmp_bytes.len() >= decrypt_rlwe_thread_safe_tmp_byte(module, limbs), + tmp_bytes.len() >= decrypt_rlwe_thread_safe_tmp_byte(module, a.limbs()), "invalid tmp_bytes: tmp_bytes.len()={} < decrypt_rlwe_thread_safe_tmp_byte={}", tmp_bytes.len(), - decrypt_rlwe_thread_safe_tmp_byte(module, limbs) + decrypt_rlwe_thread_safe_tmp_byte(module, a.limbs()) ); - let res_dft_bytes: usize = module.bytes_of_vec_znx_dft(limbs); + let res_dft_bytes: usize = module.bytes_of_vec_znx_dft(a.limbs()); - let mut res_dft: VecZnxDft = VecZnxDft::from_bytes(limbs, tmp_bytes); + let mut res_dft: VecZnxDft = VecZnxDft::from_bytes(a.limbs(), tmp_bytes); let mut res_big: base2k::VecZnxBig = res_dft.as_vec_znx_big(); // res_dft <- DFT(ct[1]) * DFT(sk) - module.svp_apply_dft(&mut res_dft, sk, &a.value[1], limbs); + module.svp_apply_dft(&mut res_dft, sk, &a.value[1], a.limbs()); // res_big <- ct[1] x sk - module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft, limbs); + module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft, a.limbs()); // res_big <- ct[1] x sk + ct[0] - module.vec_znx_big_add_small_inplace(&mut res_big, &a.value[0], limbs); + module.vec_znx_big_add_small_inplace(&mut res_big, &a.value[0]); // res <- normalize(ct[1] x sk + ct[0]) module.vec_znx_big_normalize( a.log_base2k(), diff --git a/rlwe/src/elem.rs b/rlwe/src/elem.rs index 53fb0c5..c57b53c 100644 --- a/rlwe/src/elem.rs +++ b/rlwe/src/elem.rs @@ -91,7 +91,7 @@ impl Elem { &mut self.value[i] } - pub fn zero(&mut self){ + pub fn zero(&mut self) { self.value.iter_mut().for_each(|i| i.zero()); } } diff --git a/rlwe/src/encryptor.rs b/rlwe/src/encryptor.rs index c5043d0..0e668fc 100644 --- a/rlwe/src/encryptor.rs +++ b/rlwe/src/encryptor.rs @@ -175,7 +175,7 @@ pub fn encrypt_rlwe_sk_thread_safe( } // c0 <- -s x c1 + m + e - //c0.add_normal(log_base2k, log_q, source_xe, sigma, (sigma * 6.0).ceil()); + c0.add_normal(log_base2k, log_q, source_xe, sigma, (sigma * 6.0).ceil()); } pub fn encrypt_grlwe_sk_tmp_bytes( @@ -245,6 +245,7 @@ pub fn encrypt_grlwe_sk_thread_safe( // Zeroes the ith-row of tmp_pt tmp_pt.0.value[0].at_mut(row_i).fill(0); + /* let mut res: Elem = Elem::new(module, log_base2k, log_q, 0, tmp_elem.log_scale); decrypt_rlwe_thread_safe(module, &mut res, &tmp_elem, sk, tmp_bytes_encrypt_sk); @@ -252,6 +253,7 @@ pub fn encrypt_grlwe_sk_thread_safe( println!("row:{}", row_i); res.value[0].print_limbs(res.limbs(), 16); println!(); + */ // GRLWE[row_i][0] = -as + m * 2^{-i*log_base2k} + e*2^{-log_q} module.vmp_prepare_row( diff --git a/rlwe/src/evaluator.rs b/rlwe/src/evaluator.rs index 9e27bdb..ba07d9d 100644 --- a/rlwe/src/evaluator.rs +++ b/rlwe/src/evaluator.rs @@ -1,5 +1,8 @@ -use crate::ciphertext::{Ciphertext, GadgetCiphertext}; -use base2k::{Module, VecZnxBig, VecZnxDft, VmpPMatOps}; +use crate::{ + ciphertext::{Ciphertext, GadgetCiphertext}, + elem::Elem, +}; +use base2k::{Infos, Module, VecZnx, VecZnxBig, VecZnxDft, VmpPMatOps}; pub fn gadget_product_tmp_bytes( module: &Module, @@ -16,34 +19,27 @@ pub fn gadget_product_tmp_bytes( + 2 * module.bytes_of_vec_znx_dft(gct_cols) } -pub fn gadget_product_inplace_thread_safe( - module: &Module, - a: &mut Ciphertext, - b: &GadgetCiphertext, - tmp_bytes: &mut [u8], -) { - // This is safe to do because the relevant values of a are copied to a buffer before being - // overwritten. - unsafe { - let a_ptr: *mut Ciphertext = a; - gadget_product_thread_safe(module, a, &*a_ptr, b, tmp_bytes) - } -} - +/// Evaluates the gadget product res <- a x b. +/// +/// # Arguments +/// +/// * `module`: backend support for operations mod (X^N + 1) +/// * `res`: an [Elem] to store (-cs + m * a + e, c) with res_ncols limbs. +/// * `a`: a [VecZnx] of a_ncols limbs. +/// * `b`: a [GadgetCiphertext] as a vector of (-Bs + m * 2^{-k} + E, B) +/// containing b_nrows [VecZnx], each of b_ncols limbs. +/// +/// # Computation +/// +/// res = sum[min(a_ncols, b_nrows)] decomp(a, i) * (-B[i]s + m * 2^{-k*i} + E[i], B[i]) +/// = (cs + m * a + e, c) with min(res_limbs, b_cols) limbs. pub fn gadget_product_thread_safe( module: &Module, - res: &mut Ciphertext, - a: &Ciphertext, + res: &mut Elem, + a: &VecZnx, b: &GadgetCiphertext, tmp_bytes: &mut [u8], ) { - assert!( - a.log_base2k() == b.log_base2k(), - "invalid inputs: a.log_base2k={} != b.log_base2k={}", - a.log_base2k(), - b.log_base2k() - ); - let log_base2k: usize = b.log_base2k(); let cols: usize = b.cols(); @@ -57,7 +53,7 @@ pub fn gadget_product_thread_safe( let mut res_big: VecZnxBig = res_dft.as_vec_znx_big(); // c1_dft <- DFT(c1) [cols] - module.vec_znx_dft(&mut c1_dft, a.at(1), a.limbs()); + module.vec_znx_dft(&mut c1_dft, a, a.limbs()); // res_dft <- sum[rows] DFT(c1)[cols] x GadgetCiphertext[0][cols] module.vmp_apply_dft_to_dft(&mut res_dft, &c1_dft, &b.value[0], tmp_bytes_vmp_apply_dft); @@ -65,9 +61,6 @@ pub fn gadget_product_thread_safe( // res_big <- IDFT(DFT(c1) x GadgetCiphertext[0]) module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft, cols); - // res_big <- c0 + c1_dft x GadgetCiphertext[0] - module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0), cols); - // res[0] = normalize(c0 + c1_dft x GadgetCiphertext[0]) module.vec_znx_big_normalize(log_base2k, res.at_mut(0), &res_big, tmp_bytes_vmp_apply_dft); diff --git a/rlwe/src/keys.rs b/rlwe/src/keys.rs index 6675028..93032f3 100644 --- a/rlwe/src/keys.rs +++ b/rlwe/src/keys.rs @@ -65,19 +65,19 @@ impl SwitchingKey { SwitchingKey(GadgetCiphertext::new(module, log_base2k, rows, log_q)) } - pub fn n(&self) -> usize{ + pub fn n(&self) -> usize { self.0.n() } - pub fn rows(&self) -> usize{ + pub fn rows(&self) -> usize { self.0.rows() } - pub fn cols(&self) -> usize{ + pub fn cols(&self) -> usize { self.0.cols() } - pub fn log_base2k(&self) -> usize{ + pub fn log_base2k(&self) -> usize { self.0.log_base2k() } } diff --git a/rlwe/src/plaintext.rs b/rlwe/src/plaintext.rs index 30fd85a..b31e297 100644 --- a/rlwe/src/plaintext.rs +++ b/rlwe/src/plaintext.rs @@ -64,7 +64,7 @@ impl Plaintext { self.0.log_scale() } - pub fn zero(&mut self){ + pub fn zero(&mut self) { self.0.zero() }