diff --git a/rlwe/benches/gadget_product.rs b/rlwe/benches/gadget_product.rs index df62299..033e2a1 100644 --- a/rlwe/benches/gadget_product.rs +++ b/rlwe/benches/gadget_product.rs @@ -1,33 +1,51 @@ use base2k::{ - Encoding, FFT64, Infos, Module, Sampling, Scalar, SvpPPolOps, VecZnx, VecZnxBig, VecZnxDft, - VecZnxOps, + FFT64, Module, Sampling, SvpPPolOps, VecZnx, VecZnxApi, VecZnxBig, VecZnxDft, VecZnxDftOps, + VmpPMat, VmpPMatOps, alloc_aligned_u8, }; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; use rlwe::{ ciphertext::{Ciphertext, GadgetCiphertext}, - decryptor::{Decryptor, decrypt_rlwe_thread_safe, decrypt_rlwe_thread_safe_tmp_byte}, elem::Elem, - encryptor::{ - EncryptorSk, encrypt_grlwe_sk_thread_safe, encrypt_grlwe_sk_tmp_bytes, - encrypt_rlwe_sk_tmp_bytes, - }, - evaluator::{gadget_product_inplace_thread_safe, gadget_product_tmp_bytes}, - key_generator::{gen_switching_key_thread_safe, gen_switching_key_thread_safe_tmp_bytes}, - keys::{SecretKey, SwitchingKey}, + encryptor::{encrypt_grlwe_sk_thread_safe, encrypt_grlwe_sk_tmp_bytes}, + evaluator::gadget_product_tmp_bytes, + key_generator::gen_switching_key_thread_safe_tmp_bytes, + keys::SecretKey, parameters::{Parameters, ParametersLiteral}, - plaintext::Plaintext, }; use sampling::source::{Source, new_seed}; fn gadget_product_inplace(c: &mut Criterion) { fn gadget_product<'a>( module: &'a Module, - elem: &'a mut Elem, + elem: &'a mut Elem, gadget_ct: &'a GadgetCiphertext, tmp_bytes: &'a mut [u8], ) -> Box { + let factor: usize = 2; + + let log_base2k: usize = 32; + let limbs: usize = 2; + let rows: usize = factor * limbs; + let cols: usize = factor * limbs + 1; + + let pmat: VmpPMat = module.new_vmp_pmat(rows, cols); + + let mut tmp_bytes: Vec = + alloc_aligned_u8(module.vmp_apply_dft_tmp_bytes(cols, rows, rows, cols), 64); + + let mut a_dft: VecZnxDft = module.new_vec_znx_dft(rows); + let mut res_dft: VecZnxDft = module.new_vec_znx_dft(cols); + let mut res_big: VecZnxBig = res_dft.as_vec_znx_big(); + let mut a: VecZnx = VecZnx::new(module.n(), rows); + let mut source = Source::new(new_seed()); + module.fill_uniform(log_base2k, &mut a, limbs, &mut source); + Box::new(move || { - gadget_product_inplace_thread_safe::(module, elem, gadget_ct, tmp_bytes) + module.vec_znx_dft(&mut a_dft, &a, rows); + module.vmp_apply_dft_to_dft(&mut res_dft, &mut a_dft, &pmat, &mut tmp_bytes); + module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft, cols); + + //gadget_product_inplace_thread_safe::(module, elem, gadget_ct, tmp_bytes) }) } @@ -47,15 +65,14 @@ fn gadget_product_inplace(c: &mut Criterion) { let params: Parameters = Parameters::new::(¶ms_lit); - let mut tmp_bytes: Vec = vec![ - 0u8; + let mut tmp_bytes: Vec = alloc_aligned_u8( params.decrypt_rlwe_thread_safe_tmp_byte(params.log_q()) | params.encrypt_rlwe_sk_tmp_bytes(params.log_q()) | gen_switching_key_thread_safe_tmp_bytes( params.module(), params.log_base2k(), params.limbs_q(), - params.log_q() + params.log_q(), ) | gadget_product_tmp_bytes( params.module(), @@ -63,15 +80,16 @@ fn gadget_product_inplace(c: &mut Criterion) { params.log_q(), params.log_q(), params.limbs_q(), - params.log_qp() + params.log_qp(), ) | encrypt_grlwe_sk_tmp_bytes( params.module(), params.log_base2k(), params.limbs_qp(), - params.log_qp() - ) - ]; + params.log_qp(), + ), + 64, + ); let mut source: Source = Source::new([3; 32]); diff --git a/rlwe/examples/encryption.rs b/rlwe/examples/encryption.rs index f3dceca..8b692ad 100644 --- a/rlwe/examples/encryption.rs +++ b/rlwe/examples/encryption.rs @@ -1,4 +1,4 @@ -use base2k::{Encoding, FFT64, SvpPPolOps}; +use base2k::{Encoding, FFT64, SvpPPolOps, VecZnxApi, VecZnx}; use rlwe::{ ciphertext::Ciphertext, decryptor::{Decryptor, decrypt_rlwe_thread_safe_tmp_byte}, @@ -37,7 +37,7 @@ fn main() { want.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); - let mut pt: Plaintext = params.new_plaintext(params.log_q()); + let mut pt: Plaintext = params.new_plaintext(params.log_q()); let log_base2k = pt.log_base2k(); diff --git a/rlwe/examples/gadget_product.rs b/rlwe/examples/gadget_product.rs index 8ebd0f1..f7e4da5 100644 --- a/rlwe/examples/gadget_product.rs +++ b/rlwe/examples/gadget_product.rs @@ -1,5 +1,6 @@ use base2k::{ - Encoding, FFT64, Infos, Sampling, Scalar, SvpPPolOps, VecZnx, VecZnxBig, VecZnxDft, VecZnxOps, + Encoding, FFT64, Infos, Sampling, Scalar, SvpPPolOps, VecZnx, VecZnxApi, VecZnxBig, VecZnxDft, + VecZnxOps, }; use rlwe::{ ciphertext::{Ciphertext, GadgetCiphertext}, @@ -19,13 +20,13 @@ use sampling::source::{Source, new_seed}; fn main() { let params_lit: ParametersLiteral = ParametersLiteral { - log_n: 10, + log_n: 4, log_q: 68, log_p: 17, log_base2k: 17, log_scale: 20, xe: 3.2, - xs: 128, + xs: 8, }; let params: Parameters = Parameters::new::(¶ms_lit); @@ -99,7 +100,9 @@ fn main() { &mut tmp_bytes, ); - let mut pt: Plaintext = Plaintext::new( + println!("DONE?"); + + let mut pt: Plaintext = Plaintext::::new( params.module(), params.log_base2k(), params.log_q(), @@ -122,7 +125,7 @@ fn main() { &mut tmp_bytes, ); - gadget_product_inplace_thread_safe::( + gadget_product_inplace_thread_safe::( params.module(), &mut ct.0, &gadget_ct, @@ -145,7 +148,7 @@ fn main() { pt.0.value[0].print_limbs(pt.limbs(), 16); - let mut have = vec![i64::default(); params.n()]; + let mut have: Vec = vec![i64::default(); params.n()]; println!("pt: {}", log_k); pt.0.value[0].decode_vec_i64(pt.log_base2k(), log_k, &mut have); diff --git a/rlwe/src/ciphertext.rs b/rlwe/src/ciphertext.rs index 815e560..9875744 100644 --- a/rlwe/src/ciphertext.rs +++ b/rlwe/src/ciphertext.rs @@ -1,9 +1,9 @@ -use crate::elem::Elem; +use crate::elem::{Elem, ElemBasics}; use crate::parameters::Parameters; use crate::plaintext::Plaintext; -use base2k::{Module, VecZnx, VmpPMat, VmpPMatOps}; +use base2k::{Infos, Module, VecZnx, VecZnxApi, VmpPMat, VmpPMatOps}; -pub struct Ciphertext(pub Elem); +pub struct Ciphertext(pub Elem); impl Ciphertext { pub fn new( @@ -32,11 +32,11 @@ impl Ciphertext { self.0.limbs() } - pub fn at(&self, i: usize) -> &VecZnx { + pub fn at(&self, i: usize) -> &(impl VecZnxApi + Infos) { self.0.at(i) } - pub fn at_mut(&mut self, i: usize) -> &mut VecZnx { + pub fn at_mut(&mut self, i: usize) -> &mut (impl VecZnxApi + Infos) { self.0.at_mut(i) } @@ -52,7 +52,7 @@ impl Ciphertext { self.0.zero() } - pub fn as_plaintext(&self) -> Plaintext { + pub fn as_plaintext(&self) -> Plaintext { unsafe { Plaintext(std::ptr::read(&self.0)) } } } @@ -64,7 +64,7 @@ impl Parameters { } pub struct GadgetCiphertext { - pub value: Vec, + pub value: VmpPMat, pub log_base2k: usize, pub log_q: usize, } @@ -72,29 +72,23 @@ pub struct GadgetCiphertext { impl GadgetCiphertext { pub fn new(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> Self { let cols: usize = (log_q + log_base2k - 1) / log_base2k; - let mut value: Vec = Vec::new(); - (0..rows).for_each(|_| value.push(module.new_vmp_pmat(rows, cols))); Self { - value, + value: module.new_vmp_pmat(rows, cols * 2), log_base2k, log_q, } } pub fn n(&self) -> usize { - self.value[0].n + self.value.n } pub fn rows(&self) -> usize { - self.value[0].rows + self.value.rows } pub fn cols(&self) -> usize { - self.value[0].cols - } - - pub fn degree(&self) -> usize { - self.value.len() - 1 + self.value.cols } pub fn log_q(&self) -> usize { @@ -107,8 +101,38 @@ impl GadgetCiphertext { } pub struct RGSWCiphertext { - pub value: [GadgetCiphertext; 2], + pub value: VmpPMat, pub log_base2k: usize, pub log_q: usize, - pub log_p: usize, +} + +impl RGSWCiphertext { + pub fn new(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> Self { + let cols: usize = (log_q + log_base2k - 1) / log_base2k; + Self { + value: module.new_vmp_pmat(rows * 2, cols * 2), + log_base2k, + log_q, + } + } + + pub fn n(&self) -> usize { + self.value.n + } + + pub fn rows(&self) -> usize { + self.value.rows + } + + pub fn cols(&self) -> usize { + self.value.cols + } + + pub fn log_q(&self) -> usize { + self.log_q + } + + pub fn log_base2k(&self) -> usize { + self.log_base2k + } } diff --git a/rlwe/src/decryptor.rs b/rlwe/src/decryptor.rs index 6b74059..bab6dc7 100644 --- a/rlwe/src/decryptor.rs +++ b/rlwe/src/decryptor.rs @@ -1,11 +1,13 @@ use crate::{ - ciphertext::{Ciphertext, GadgetCiphertext}, - elem::Elem, + ciphertext::Ciphertext, + elem::{Elem, ElemBasics}, keys::SecretKey, parameters::Parameters, plaintext::Plaintext, }; -use base2k::{Module, SvpPPol, SvpPPolOps, VecZnxDft}; +use base2k::{ + Infos, VecZnx, Module, SvpPPol, SvpPPolOps, VecZnxApi, VecZnxBigOps, VecZnxDft, VecZnxDftOps, +}; use std::cmp::min; pub struct Decryptor { @@ -34,7 +36,7 @@ impl Parameters { pub fn decrypt_rlwe_thread_safe( &self, - res: &mut Plaintext, + res: &mut Plaintext, ct: &Ciphertext, sk: &SvpPPol, tmp_bytes: &mut [u8], @@ -43,13 +45,15 @@ impl Parameters { } } -pub fn decrypt_rlwe_thread_safe( +pub fn decrypt_rlwe_thread_safe( module: &Module, - res: &mut Elem, - a: &Elem, + res: &mut Elem, + a: &Elem, sk: &SvpPPol, tmp_bytes: &mut [u8], -) { +) where + T: VecZnxApi + Infos, +{ assert!( tmp_bytes.len() >= decrypt_rlwe_thread_safe_tmp_byte(module, a.limbs()), "invalid tmp_bytes: tmp_bytes.len()={} < decrypt_rlwe_thread_safe_tmp_byte={}", diff --git a/rlwe/src/elem.rs b/rlwe/src/elem.rs index c57b53c..7a3ed63 100644 --- a/rlwe/src/elem.rs +++ b/rlwe/src/elem.rs @@ -1,14 +1,44 @@ use crate::parameters::Parameters; -use base2k::{Infos, Module, VecZnx, VecZnxOps}; +use base2k::{Infos, Module, VecZnx, VecZnxApi, VecZnxBorrow, VecZnxOps}; -pub struct Elem { - pub value: Vec, + +impl Parameters { + pub fn bytes_of_elem(&self, log_q: usize, degree: usize) -> usize { + Elem::::bytes_of(self.module(), self.log_base2k(), log_q, degree) + } + + pub fn elem_from_bytes(&self, log_q: usize, degree: usize, bytes: &mut [u8]) -> Elem { + Elem::::from_bytes(self.module(), self.log_base2k(), log_q, degree, bytes) + } + + pub fn elem_borrow_from_bytes(&self, log_q: usize, degree: usize, bytes: &mut [u8]) -> Elem { + Elem::::from_bytes(self.module(), self.log_base2k(), log_q, degree, bytes) + } +} + +pub struct Elem { + pub value: Vec, pub log_base2k: usize, pub log_q: usize, pub log_scale: usize, } -impl Elem { +pub trait ElemBasics +where + T: VecZnxApi + Infos, +{ + fn n(&self) -> usize; + fn degree(&self) -> usize; + fn limbs(&self) -> usize; + fn log_base2k(&self) -> usize; + fn log_scale(&self) -> usize; + fn log_q(&self) -> usize; + fn at(&self, i: usize) -> &T; + fn at_mut(&mut self, i: usize) -> &mut T; + fn zero(&mut self); +} + +impl Elem { pub fn new( module: &Module, log_base2k: usize, @@ -43,7 +73,7 @@ impl Elem { assert!(bytes.len() >= Self::bytes_of(module, log_base2k, log_q, degree)); let mut value: Vec = Vec::new(); let limbs: usize = (log_q + log_base2k - 1) / log_base2k; - let size = VecZnx::bytes(n, limbs); + let size = VecZnx::bytes_of(n, limbs); let mut ptr: usize = 0; (0..degree + 1).for_each(|_| { value.push(VecZnx::from_bytes(n, limbs, &mut bytes[ptr..])); @@ -56,52 +86,78 @@ impl Elem { log_scale: 0, } } +} - pub fn n(&self) -> usize { +impl Elem { + + pub fn bytes_of(module: &Module, log_base2k: usize, log_q: usize, degree: usize) -> usize { + let cols = (log_q + log_base2k - 1) / log_base2k; + module.n() * cols * (degree + 1) * 8 + } + + pub fn from_bytes( + module: &Module, + log_base2k: usize, + log_q: usize, + degree: usize, + bytes: &mut [u8], + ) -> Self { + let n: usize = module.n(); + assert!(bytes.len() >= Self::bytes_of(module, log_base2k, log_q, degree)); + let mut value: Vec = Vec::new(); + let limbs: usize = (log_q + log_base2k - 1) / log_base2k; + let size = VecZnxBorrow::bytes_of(n, limbs); + let mut ptr: usize = 0; + (0..degree + 1).for_each(|_| { + value.push(VecZnxBorrow::from_bytes(n, limbs, &mut bytes[ptr..])); + ptr += size + }); + Self { + value, + log_q, + log_base2k, + log_scale: 0, + } + } +} + + +impl ElemBasics for Elem { + fn n(&self) -> usize { self.value[0].n() } - pub fn degree(&self) -> usize { + fn degree(&self) -> usize { self.value.len() } - pub fn limbs(&self) -> usize { + fn limbs(&self) -> usize { self.value[0].limbs() } - pub fn log_base2k(&self) -> usize { + fn log_base2k(&self) -> usize { self.log_base2k } - pub fn log_scale(&self) -> usize { + fn log_scale(&self) -> usize { self.log_scale } - pub fn log_q(&self) -> usize { + fn log_q(&self) -> usize { self.log_q } - pub fn at(&self, i: usize) -> &VecZnx { + fn at(&self, i: usize) -> &T { assert!(i <= self.degree()); &self.value[i] } - pub fn at_mut(&mut self, i: usize) -> &mut VecZnx { + fn at_mut(&mut self, i: usize) -> &mut T { assert!(i <= self.degree()); &mut self.value[i] } - pub fn zero(&mut self) { + fn zero(&mut self) { self.value.iter_mut().for_each(|i| i.zero()); } } - -impl Parameters { - pub fn bytes_of_elem(&self, log_q: usize, degree: usize) -> usize { - Elem::bytes_of(self.module(), self.log_base2k(), log_q, degree) - } - - pub fn elem_from_bytes(&self, log_q: usize, degree: usize, bytes: &mut [u8]) -> Elem { - Elem::from_bytes(self.module(), self.log_base2k(), log_q, degree, bytes) - } -} diff --git a/rlwe/src/encryptor.rs b/rlwe/src/encryptor.rs index 0e668fc..265959a 100644 --- a/rlwe/src/encryptor.rs +++ b/rlwe/src/encryptor.rs @@ -1,14 +1,14 @@ use crate::ciphertext::{Ciphertext, GadgetCiphertext}; -use crate::decryptor::decrypt_rlwe_thread_safe; -use crate::elem::Elem; +use crate::elem::{Elem, ElemBasics}; use crate::keys::SecretKey; use crate::parameters::Parameters; use crate::plaintext::Plaintext; -use base2k::ffi::znx::znx_zero_i64_ref; use base2k::sampling::Sampling; use base2k::{ - Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxDft, VecZnxOps, VmpPMatOps, + cast_mut, Infos, VecZnxBorrow, Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxApi, VecZnxBig, VecZnxBigOps, + VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMatOps, alloc_aligned_u8, cast, }; +use rand_distr::num_traits::ops::bytes; use sampling::source::{Source, new_seed}; pub struct EncryptorSk { @@ -53,7 +53,7 @@ impl EncryptorSk { &mut self, params: &Parameters, ct: &mut Ciphertext, - pt: Option<&Plaintext>, + pt: Option<&Plaintext>, ) { assert!( self.initialized == true, @@ -73,7 +73,7 @@ impl EncryptorSk { &self, params: &Parameters, ct: &mut Ciphertext, - pt: Option<&Plaintext>, + pt: Option<&Plaintext>, source_xa: &mut Source, source_xe: &mut Source, tmp_bytes: &mut [u8], @@ -94,13 +94,13 @@ impl Parameters { pub fn encrypt_rlwe_sk_thread_safe( &self, ct: &mut Ciphertext, - pt: Option<&Plaintext>, + pt: Option<&Plaintext>, sk: &SvpPPol, source_xa: &mut Source, source_xe: &mut Source, tmp_bytes: &mut [u8], ) { - encrypt_rlwe_sk_thread_safe( + encrypt_rlwe_sk_thread_safe::( self.module(), &mut ct.0, pt.map(|pt| &pt.0), @@ -118,16 +118,18 @@ pub fn encrypt_rlwe_sk_tmp_bytes(module: &Module, log_base2k: usize, log_q: usiz + module.vec_znx_big_normalize_tmp_bytes() } -pub fn encrypt_rlwe_sk_thread_safe( +pub fn encrypt_rlwe_sk_thread_safe( module: &Module, - ct: &mut Elem, - pt: Option<&Elem>, + ct: &mut Elem, + pt: Option<&Elem>, sk: &SvpPPol, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, tmp_bytes: &mut [u8], -) { +) where + T: VecZnxApi + Infos, +{ let limbs: usize = ct.limbs(); let log_base2k: usize = ct.log_base2k(); let log_q: usize = ct.log_q(); @@ -141,10 +143,10 @@ pub fn encrypt_rlwe_sk_thread_safe( let log_q: usize = ct.log_q(); let log_base2k: usize = ct.log_base2k(); - let c1: &mut VecZnx = ct.at_mut(1); + let c1: &mut T = ct.at_mut(1); // c1 <- Z_{2^prec}[X]/(X^{N}+1) - c1.fill_uniform(log_base2k, limbs, source_xa); + module.fill_uniform(log_base2k, c1, limbs, source_xa); let bytes_of_vec_znx_dft: usize = module.bytes_of_vec_znx_dft(limbs); @@ -164,7 +166,7 @@ pub fn encrypt_rlwe_sk_thread_safe( let carry: &mut [u8] = &mut tmp_bytes[bytes_of_vec_znx_dft..]; // c0 <- -s x c1 + m - let c0: &mut VecZnx = ct.at_mut(0); + let c0: &mut T = ct.at_mut(0); if let Some(pt) = pt { module.vec_znx_big_sub_small_a_inplace(&mut buf_big, pt.at(0)); @@ -175,7 +177,14 @@ pub fn encrypt_rlwe_sk_thread_safe( } // c0 <- -s x c1 + m + e - c0.add_normal(log_base2k, log_q, source_xe, sigma, (sigma * 6.0).ceil()); + module.add_normal( + log_base2k, + c0, + log_q, + source_xe, + sigma, + (sigma * 6.0).ceil(), + ); } pub fn encrypt_grlwe_sk_tmp_bytes( @@ -185,10 +194,10 @@ pub fn encrypt_grlwe_sk_tmp_bytes( log_q: usize, ) -> usize { let cols = (log_q + log_base2k - 1) / log_base2k; - Elem::bytes_of(module, log_base2k, log_q, 1) - + Plaintext::bytes_of(module, log_base2k, log_q) + Elem::::bytes_of(module, log_base2k, log_q, 1) + + Plaintext::::bytes_of(module, log_base2k, log_q) + encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q) - + module.vmp_prepare_tmp_bytes(rows, cols) + + module.vmp_prepare_tmp_bytes(rows, 2 * cols) } pub fn encrypt_grlwe_sk_thread_safe( @@ -203,6 +212,7 @@ pub fn encrypt_grlwe_sk_thread_safe( ) { let rows: usize = ct.rows(); let log_q: usize = ct.log_q(); + let cols: usize = (log_q + ct.log_base2k() - 1) / ct.log_base2k(); let log_base2k: usize = ct.log_base2k(); let min_tmp_bytes_len = encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q); @@ -214,23 +224,24 @@ pub fn encrypt_grlwe_sk_thread_safe( min_tmp_bytes_len ); - let mut ptr: usize = 0; - let mut tmp_elem: Elem = Elem::from_bytes(module, log_base2k, ct.log_q(), 1, tmp_bytes); - let bytes_of_elem: usize = Elem::bytes_of(module, log_base2k, log_q, 1); - ptr += bytes_of_elem; + let bytes_of_elem: usize = Elem::::bytes_of(module, log_base2k, log_q, 1); + let bytes_of_pt: usize = Plaintext::::bytes_of(module, log_base2k, log_q); + let bytes_of_enc_sk: usize = encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q); + let bytes_of_vmp_prepare_row: usize = module.vmp_prepare_tmp_bytes(rows, 2 * cols); - let mut tmp_pt: Plaintext = - Plaintext::from_bytes(module, log_base2k, log_q, &mut tmp_bytes[ptr..]); - ptr += Plaintext::bytes_of(module, log_base2k, log_q); + let (tmp_bytes_pt, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_pt); + let (tmp_bytes_enc_sk, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_enc_sk); + let (tmp_bytes_elem, tmp_bytes_vmp_prepare_row) = tmp_bytes.split_at_mut(bytes_of_elem); - let (tmp_bytes_encrypt_sk, tmp_bytes_vmp_prepare_row) = - tmp_bytes[ptr..].split_at_mut(encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q)); + let mut tmp_elem: Elem = Elem::::from_bytes(module, log_base2k, ct.log_q(), 1, tmp_bytes_elem); + let mut tmp_pt: Plaintext = Plaintext::::from_bytes(module, log_base2k, log_q, tmp_bytes_pt); (0..rows).for_each(|row_i| { // Sets the i-th row of the RLWE sample to m (i.e. m * 2^{-log_base2k*i}) - tmp_pt.0.value[0].at_mut(row_i).copy_from_slice(&m.0); + tmp_pt.0.value[0].at_mut(row_i).copy_from_slice(&m.0); // Encrypts RLWE(m * 2^{-log_base2k*i}) + encrypt_rlwe_sk_thread_safe( module, &mut tmp_elem, @@ -239,36 +250,26 @@ pub fn encrypt_grlwe_sk_thread_safe( source_xa, source_xe, sigma, - tmp_bytes_encrypt_sk, + tmp_bytes_enc_sk, ); + // Zeroes the ith-row of tmp_pt tmp_pt.0.value[0].at_mut(row_i).fill(0); - /* - let mut res: Elem = Elem::new(module, log_base2k, log_q, 0, tmp_elem.log_scale); - - decrypt_rlwe_thread_safe(module, &mut res, &tmp_elem, sk, tmp_bytes_encrypt_sk); - - println!("row:{}", row_i); - res.value[0].print_limbs(res.limbs(), 16); + println!("row:{}/{}", row_i, rows); + tmp_elem.at(0).print_limbs(tmp_elem.limbs(), tmp_elem.n()); + tmp_elem.at(1).print_limbs(tmp_elem.limbs(), tmp_elem.n()); println!(); - */ + println!(">>>"); - // GRLWE[row_i][0] = -as + m * 2^{-i*log_base2k} + e*2^{-log_q} + // GRLWE[row_i][0||1] = [-as + m * 2^{-i*log_base2k} + e*2^{-log_q} || a] module.vmp_prepare_row( - &mut ct.value[0], - tmp_elem.at(0), + &mut ct.value, + cast_mut::(tmp_bytes_elem), row_i, tmp_bytes_vmp_prepare_row, ); - - // GRLWE[row_i][1] = a - module.vmp_prepare_row( - &mut ct.value[1], - tmp_elem.at(1), - row_i, - tmp_bytes_vmp_prepare_row, - ); - }) + }); + println!("DONE"); } diff --git a/rlwe/src/evaluator.rs b/rlwe/src/evaluator.rs index e09cf1d..0eee16b 100644 --- a/rlwe/src/evaluator.rs +++ b/rlwe/src/evaluator.rs @@ -1,9 +1,11 @@ use crate::{ - ciphertext::{Ciphertext, GadgetCiphertext}, - elem::Elem, - keys::SwitchingKey, + ciphertext::{Ciphertext, GadgetCiphertext, RGSWCiphertext}, + elem::{Elem, ElemBasics}, }; -use base2k::{Infos, Module, VecZnx, VecZnxBig, VecZnxDft, VecZnxOps, VmpPMatOps}; +use base2k::{ + Infos, Module, VecZnx, VecZnxApi, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMatOps, +}; +use std::cmp::min; pub fn gadget_product_tmp_bytes( module: &Module, @@ -20,15 +22,17 @@ pub fn gadget_product_tmp_bytes( + 2 * module.bytes_of_vec_znx_dft(gct_cols) } -pub fn gadget_product_inplace_thread_safe( +pub fn gadget_product_inplace_thread_safe( module: &Module, - res: &mut Elem, + res: &mut Elem, b: &GadgetCiphertext, tmp_bytes: &mut [u8], -) { +) where + T: VecZnxApi + Infos, +{ unsafe { - let a_ptr: *const VecZnx = res.at(1) as *const VecZnx; - gadget_product_thread_safe::(module, res, &*a_ptr, b, tmp_bytes); + let a_ptr: *const T = res.at(1) as *const T; + gadget_product_thread_safe::(module, res, &*a_ptr, b, tmp_bytes); } } @@ -46,54 +50,105 @@ pub fn gadget_product_inplace_thread_safe( /// /// res = sum[min(a_ncols, b_nrows)] decomp(a, i) * (-B[i]s + m * 2^{-k*i} + E[i], B[i]) /// = (cs + m * a + e, c) with min(res_limbs, b_cols) limbs. -pub fn gadget_product_thread_safe( +pub fn gadget_product_thread_safe( module: &Module, - res: &mut Elem, - a: &VecZnx, + res: &mut Elem, + a: &T, b: &GadgetCiphertext, tmp_bytes: &mut [u8], -) { +) where + T: VecZnxApi + Infos, +{ let log_base2k: usize = b.log_base2k(); + let rows: usize = min(b.rows(), a.limbs()); let cols: usize = b.cols(); - let (tmp_bytes_vmp_apply_dft, tmp_bytes) = - tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols)); + let bytes_vmp_apply_dft: usize = + module.vmp_apply_dft_to_dft_tmp_bytes(cols, a.limbs(), rows, cols); + let bytes_vec_znx_dft: usize = module.bytes_of_vec_znx_dft(cols); - let (tmp_bytes_c1_dft, tmp_bytes_res_dft) = tmp_bytes.split_at_mut(tmp_bytes.len() >> 1); + let (tmp_bytes_vmp_apply_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_vmp_apply_dft); + let (tmp_bytes_c1_dft, tmp_bytes_res_dft) = tmp_bytes.split_at_mut(bytes_vec_znx_dft); - let mut c1_dft: VecZnxDft = module.new_vec_znx_from_bytes(cols, tmp_bytes_c1_dft); - let mut res_dft: VecZnxDft = module.new_vec_znx_from_bytes(cols, tmp_bytes_res_dft); + let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c1_dft); + let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_res_dft); let mut res_big: VecZnxBig = res_dft.as_vec_znx_big(); - // a_dft <- DFT(a) [cols] + // Alias c0 and c1 part of res_big + let (tmp_bytes_res_dft_c0, tmp_bytes_res_dft_c1) = + tmp_bytes_res_dft.split_at_mut(bytes_vec_znx_dft >> 1); + let res_big_c0: VecZnxBig = module.new_vec_znx_big_from_bytes(cols >> 1, tmp_bytes_res_dft_c0); + let mut res_big_c1: VecZnxBig = + module.new_vec_znx_big_from_bytes(cols >> 1, tmp_bytes_res_dft_c1); + + // a_dft <- DFT(a) module.vec_znx_dft(&mut c1_dft, a, a.limbs()); - // >>>>>>>> RES[0] + // (n x cols) <- (n x limbs=rows) x (rows x cols) + // res_dft[a * (G0|G1)] <- sum[rows] DFT(a) x (DFT(G0)|DFT(G1)) + module.vmp_apply_dft_to_dft(&mut res_dft, &c1_dft, &b.value, tmp_bytes_vmp_apply_dft); - // res_dft <- sum[rows] DFT(a)[cols] x GadgetCiphertext[0][cols] - module.vmp_apply_dft_to_dft(&mut res_dft, &c1_dft, &b.value[0], tmp_bytes_vmp_apply_dft); - - // res_big <- IDFT(DFT(a) x GadgetCiphertext[0]) + // res_big[a * (G0|G1)] <- IDFT(res_dft[a * (G0|G1)]) module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft, cols); - // res_big <- res[0] + a_dft x GadgetCiphertext[0] + // res_big <- res[0] + res_big[a*G0] module.vec_znx_big_add_small_inplace(&mut res_big, res.at(0)); - module.vec_znx_big_normalize(log_base2k, res.at_mut(0), &res_big, tmp_bytes_vmp_apply_dft); - - // >>>>>>>> RES[1] - - // res_dft <- DFT(c1) x GadgetCiphertext[1] - module.vmp_apply_dft_to_dft(&mut res_dft, &c1_dft, &b.value[1], tmp_bytes_vmp_apply_dft); - - // res_big <- IDFT(DFT(c1) x GadgetCiphertext[1]) - module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft, cols); + module.vec_znx_big_normalize(log_base2k, res.at_mut(0), &res_big_c0, tmp_bytes_c1_dft); if OVERWRITE { - // res[1] = normalize(a_dft x GadgetCiphertext[1]) - module.vec_znx_big_normalize(log_base2k, res.at_mut(1), &res_big, tmp_bytes_vmp_apply_dft); + // res[1] = normalize(res_big[a*G1]) + module.vec_znx_big_normalize(log_base2k, res.at_mut(1), &res_big_c1, tmp_bytes_c1_dft); } else { - // res[1] = normalize(a_dft x GadgetCiphertext[1] + res[1]) - module.vec_znx_big_add_small_inplace(&mut res_big, res.at(0)); - module.vec_znx_big_normalize(log_base2k, res.at_mut(0), &res_big, tmp_bytes_vmp_apply_dft); + // res[1] = normalize(res_big[a*G1] + res[1]) + module.vec_znx_big_add_small_inplace(&mut res_big_c1, res.at(1)); + module.vec_znx_big_normalize(log_base2k, res.at_mut(1), &res_big_c1, tmp_bytes_c1_dft); } } + +pub fn rgsw_product_thread_safe( + module: &Module, + res: &mut Elem, + a: &Ciphertext, + b: &RGSWCiphertext, + tmp_bytes: &mut [u8], +) where + T: VecZnxApi + Infos, +{ + let log_base2k: usize = b.log_base2k(); + let rows: usize = a.limbs(); + let cols: usize = b.cols(); + let in_limbs = a.limbs(); + let out_limbs: usize = a.limbs(); + + let bytes_of_vec_znx_dft = module.bytes_of_vec_znx_dft(cols); + let bytes_of_vmp_apply_dft_to_dft = + module.vmp_apply_dft_to_dft_tmp_bytes(out_limbs, in_limbs, rows, cols); + + let (tmp_bytes_c0_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft); + let (tmp_bytes_c1_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft); + let (tmp_bytes_tmp_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft); + let (tmp_bytes_r1_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft); + let (tmp_bytes_r2_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_vec_znx_dft); + let (bytes_of_vmp_apply_dft_to_dft, tmp_bytes) = + tmp_bytes.split_at_mut(bytes_of_vmp_apply_dft_to_dft); + + let mut c0_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c0_dft); + let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c1_dft); + let mut tmp_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_tmp_dft); + let mut r1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_r1_dft); + let mut r2_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_r2_dft); + + // c0_dft <- DFT(a[0]) + module.vec_znx_dft(&mut c0_dft, a.at(0), a.limbs()); + + // r_dft <- sum[rows] c0_dft[cols] x RGSW[0][cols] + module.vmp_apply_dft_to_dft( + &mut r1_dft, + &c1_dft, + &b.value, + bytes_of_vmp_apply_dft_to_dft, + ); + + // c1_dft <- DFT(a[1]) + module.vec_znx_dft(&mut c1_dft, a.at(1), a.limbs()); +} diff --git a/rlwe/src/key_generator.rs b/rlwe/src/key_generator.rs index 06c09ce..6546317 100644 --- a/rlwe/src/key_generator.rs +++ b/rlwe/src/key_generator.rs @@ -1,7 +1,7 @@ use crate::encryptor::{encrypt_grlwe_sk_thread_safe, encrypt_grlwe_sk_tmp_bytes}; use crate::keys::{PublicKey, SecretKey, SwitchingKey}; use crate::parameters::Parameters; -use base2k::{Module, SvpPPol, SvpPPolOps}; +use base2k::{Module, SvpPPol}; use sampling::source::Source; pub struct KeyGenerator {} diff --git a/rlwe/src/keys.rs b/rlwe/src/keys.rs index 93032f3..3561e9e 100644 --- a/rlwe/src/keys.rs +++ b/rlwe/src/keys.rs @@ -1,8 +1,7 @@ use crate::ciphertext::GadgetCiphertext; use crate::elem::Elem; use crate::encryptor::{encrypt_rlwe_sk_thread_safe, encrypt_rlwe_sk_tmp_bytes}; -use crate::parameters::Parameters; -use base2k::{Module, Scalar, SvpPPol, SvpPPolOps, VmpPMat, VmpPMatOps}; +use base2k::{Module, Scalar, SvpPPol, SvpPPolOps, VecZnx}; use sampling::source::Source; pub struct SecretKey(pub Scalar); @@ -25,7 +24,7 @@ impl SecretKey { } } -pub struct PublicKey(pub Elem); +pub struct PublicKey(pub Elem); impl PublicKey { pub fn new(module: &Module, log_base2k: usize, log_q: usize) -> PublicKey { diff --git a/rlwe/src/plaintext.rs b/rlwe/src/plaintext.rs index b31e297..651feab 100644 --- a/rlwe/src/plaintext.rs +++ b/rlwe/src/plaintext.rs @@ -1,35 +1,39 @@ use crate::ciphertext::Ciphertext; -use crate::elem::Elem; +use crate::elem::{Elem, ElemBasics}; use crate::parameters::Parameters; -use base2k::{Module, VecZnx}; +use base2k::{Infos, Module, VecZnx, VecZnxApi, VecZnxBorrow}; -pub struct Plaintext(pub Elem); +pub struct Plaintext(pub Elem); impl Parameters { - pub fn new_plaintext(&self, log_q: usize) -> Plaintext { + pub fn new_plaintext(&self, log_q: usize) -> Plaintext { Plaintext::new(self.module(), self.log_base2k(), log_q, self.log_scale()) } pub fn bytes_of_plaintext(&self, log_q: usize) -> usize { - Elem::bytes_of(self.module(), self.log_base2k(), log_q, 0) + Elem::::bytes_of(self.module(), self.log_base2k(), log_q, 0) } - pub fn plaintext_from_bytes(&self, log_q: usize, bytes: &mut [u8]) -> Plaintext { + pub fn plaintext_from_bytes(&self, log_q: usize, bytes: &mut [u8]) -> Plaintext { Plaintext(self.elem_from_bytes(log_q, 0, bytes)) } + + pub fn plaintext_borrow_from_bytes(&self, log_q: usize, bytes: &mut [u8]) -> Plaintext { + Plaintext(self.elem_borrow_from_bytes(log_q, 0, bytes)) + } } -impl Plaintext { +impl Plaintext { pub fn new(module: &Module, log_base2k: usize, log_q: usize, log_scale: usize) -> Self { - Self(Elem::new(module, log_base2k, log_q, 0, log_scale)) + Self(Elem::::new(module, log_base2k, log_q, 0, log_scale)) } pub fn bytes_of(module: &Module, log_base2k: usize, log_q: usize) -> usize { - Elem::bytes_of(module, log_base2k, log_q, 0) + Elem::::bytes_of(module, log_base2k, log_q, 0) } pub fn from_bytes(module: &Module, log_base2k: usize, log_q: usize, bytes: &mut [u8]) -> Self { - Self(Elem::from_bytes(module, log_base2k, log_q, 0, bytes)) + Self(Elem::::from_bytes(module, log_base2k, log_q, 0, bytes)) } pub fn n(&self) -> usize { @@ -71,4 +75,58 @@ impl Plaintext { pub fn as_ciphertext(&self) -> Ciphertext { unsafe { Ciphertext(std::ptr::read(&self.0)) } } + +} + +impl Plaintext { + + pub fn bytes_of(module: &Module, log_base2k: usize, log_q: usize) -> usize { + Elem::::bytes_of(module, log_base2k, log_q, 0) + } + + pub fn from_bytes(module: &Module, log_base2k: usize, log_q: usize, bytes: &mut [u8]) -> Self { + Self(Elem::::from_bytes(module, log_base2k, log_q, 0, bytes)) + } + + pub fn n(&self) -> usize { + self.0.n() + } + + pub fn degree(&self) -> usize { + self.0.degree() + } + + pub fn log_q(&self) -> usize { + self.0.log_q() + } + + pub fn limbs(&self) -> usize { + self.0.limbs() + } + + pub fn at(&self, i: usize) -> &VecZnxBorrow { + self.0.at(i) + } + + pub fn at_mut(&mut self, i: usize) -> &mut VecZnxBorrow { + self.0.at_mut(i) + } + + pub fn log_base2k(&self) -> usize { + self.0.log_base2k() + } + + pub fn log_scale(&self) -> usize { + self.0.log_scale() + } + + pub fn zero(&mut self) { + self.0.zero() + } + + /* + pub fn as_ciphertext(&self) -> Ciphertext { + unsafe { Ciphertext(std::ptr::read(&self.0)) } + } + */ }