diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index e9f902e..b084d80 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -38,7 +38,7 @@ fn main() { let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(a.limbs()); // Applies buf_dft <- s * a - module.svp_apply_dft(&mut buf_dft, &s_ppol, &a); + module.svp_apply_dft(&mut buf_dft, &s_ppol, &a, a.limbs()); // Alias scratch space let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big(); @@ -67,11 +67,11 @@ fn main() { //Decrypt // buf_big <- a * s - module.svp_apply_dft(&mut buf_dft, &s_ppol, &a); + module.svp_apply_dft(&mut buf_dft, &s_ppol, &a, a.limbs()); module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft, b.limbs()); // buf_big <- a * s + b - module.vec_znx_big_add_small_inplace(&mut buf_big, &b); + module.vec_znx_big_add_small_inplace(&mut buf_big, &b, b.limbs()); // res <- normalize(buf_big) module.vec_znx_big_normalize(log_base2k, &mut res, &buf_big, &mut carry); diff --git a/base2k/spqlios-arithmetic b/base2k/spqlios-arithmetic index 83555cc..5461131 160000 --- a/base2k/spqlios-arithmetic +++ b/base2k/spqlios-arithmetic @@ -1 +1 @@ -Subproject commit 83555cc664b4ebedd9b82d35120c80605b895b87 +Subproject commit 546113166e0e204cdfcd7a78ed96b6df7c457e40 diff --git a/base2k/src/encoding.rs b/base2k/src/encoding.rs index 4631cf5..1808fd4 100644 --- a/base2k/src/encoding.rs +++ b/base2k/src/encoding.rs @@ -56,6 +56,8 @@ impl Encoding for VecZnx { fn encode_vec_i64(&mut self, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) { let limbs: usize = (log_k + log_base2k - 1) / log_base2k; + println!("limbs: {}", limbs); + assert!(limbs <= self.limbs(), "invalid argument log_k: (log_k + self.log_base2k - 1)/self.log_base2k={} > self.limbs()={}", limbs, self.limbs()); let size: usize = min(data.len(), self.n()); @@ -65,10 +67,10 @@ impl Encoding for VecZnx { // values on the last limb. // Else we decompose values base2k. if log_max + log_k_rem < 63 || log_k_rem == log_base2k { - (0..limbs - 1).for_each(|i| unsafe { + (0..self.limbs()).for_each(|i| unsafe { znx_zero_i64_ref(size as u64, self.at_mut(i).as_mut_ptr()); }); - self.at_mut(self.limbs() - 1)[..size].copy_from_slice(&data[..size]); + self.at_mut(limbs - 1)[..size].copy_from_slice(&data[..size]); } else { let mask: i64 = (1 << log_base2k) - 1; let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); diff --git a/base2k/src/ffi/vmp.rs b/base2k/src/ffi/vmp.rs index 154555c..a0e6a92 100644 --- a/base2k/src/ffi/vmp.rs +++ b/base2k/src/ffi/vmp.rs @@ -91,6 +91,18 @@ unsafe extern "C" { ); } +unsafe extern "C" { + pub unsafe fn vmp_prepare_row( + module: *const MODULE, + pmat: *mut VMP_PMAT, + row: *const i64, + row_i: u64, + nrows: u64, + ncols: u64, + tmp_space: *mut u8, + ); +} + unsafe extern "C" { pub unsafe fn vmp_prepare_tmp_bytes(module: *const MODULE, nrows: u64, ncols: u64) -> u64; } diff --git a/base2k/src/svp.rs b/base2k/src/svp.rs index cc07247..85b71b8 100644 --- a/base2k/src/svp.rs +++ b/base2k/src/svp.rs @@ -91,7 +91,7 @@ pub trait SvpPPolOps { /// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of /// the [VecZnxDft] is multiplied with [SvpPPol]. - fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx); + fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx, b_limbs: usize); } impl SvpPPolOps for Module { @@ -107,14 +107,13 @@ impl SvpPPolOps for Module { unsafe { svp::svp_prepare(self.0, svp_ppol.0, a.as_ptr()) } } - fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx) { - let limbs: u64 = b.limbs() as u64; + fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx, b_limbs: usize) { assert!( - c.limbs() as u64 >= limbs, + c.limbs() >= b_limbs, "invalid c_vector: c_vector.limbs()={} < b.limbs()={}", c.limbs(), - limbs + b_limbs ); - unsafe { svp::svp_apply_dft(self.0, c.0, limbs, a.0, b.as_ptr(), limbs, b.n() as u64) } + unsafe { svp::svp_apply_dft(self.0, c.0, b_limbs as u64, a.0, b.as_ptr(), b_limbs as u64, b.n() as u64) } } } diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index f8a6ac5..461254d 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -117,23 +117,22 @@ impl Module { } // b <- b + a - pub fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) { - let limbs: usize = a.limbs(); + pub fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx, a_limbs: usize) { assert!( - b.limbs() >= limbs, + b.limbs() >= a_limbs, "invalid c_vector: b.limbs()={} < a.limbs()={}", b.limbs(), - limbs + a_limbs ); unsafe { vec_znx_big::vec_znx_big_add_small( self.0, b.0, - limbs as u64, + a_limbs as u64, b.0, - limbs as u64, + a_limbs as u64, a.as_ptr(), - limbs as u64, + a_limbs as u64, a.n() as u64, ) } diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index fd4f4f7..64b61a5 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -1,7 +1,7 @@ use crate::ffi::vec_znx_big; use crate::ffi::vec_znx_dft; use crate::ffi::vec_znx_dft::bytes_of_vec_znx_dft; -use crate::{Module, VecZnxBig}; +use crate::{Module, VecZnx, VecZnxBig}; pub struct VecZnxDft(pub *mut vec_znx_dft::vec_znx_dft_t, pub usize); @@ -30,6 +30,25 @@ impl Module { unsafe { VecZnxDft(vec_znx_dft::new_vec_znx_dft(self.0, limbs as u64), limbs) } } + /// Returns a new [VecZnxDft] with the provided bytes array as backing array. + /// + /// # Arguments + /// + /// * `limbs`: the number of limbs of the [VecZnxDft]. + /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft]. + /// + /// # Panics + /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. + pub fn new_vec_znx_from_bytes(&self, limbs: usize, bytes: &mut [u8]) -> VecZnxDft { + assert!( + bytes.len() >= self.bytes_of_vec_znx_dft(limbs), + "invalid bytes: bytes.len()={} < bytes_of_vec_znx_dft={}", + bytes.len(), + self.bytes_of_vec_znx_dft(limbs) + ); + VecZnxDft::from_bytes(limbs, bytes) + } + /// Returns the minimum number of bytes necessary to allocate /// a new [VecZnxDft] through [VecZnxDft::from_bytes]. pub fn bytes_of_vec_znx_dft(&self, limbs: usize) -> usize { @@ -52,6 +71,29 @@ impl Module { unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.0) as usize } } + /// b <- DFT(a) + /// + /// # Panics + /// If b.limbs < a_limbs + pub fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx, a_limbs: usize) { + assert!( + b.limbs() >= a_limbs, + "invalid a_limbs: b.limbs()={} < a_limbs={}", + b.limbs(), + a_limbs + ); + unsafe { + vec_znx_dft::vec_znx_dft( + self.0, + b.0, + a_limbs as u64, + a.as_ptr(), + a_limbs as u64, + a.n as u64, + ) + } + } + // b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes]. pub fn vec_znx_idft( &self, diff --git a/base2k/src/vmp.rs b/base2k/src/vmp.rs index c850e6a..3a6ce09 100644 --- a/base2k/src/vmp.rs +++ b/base2k/src/vmp.rs @@ -169,6 +169,38 @@ pub trait VmpPMatOps { /// ``` fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &Vec, buf: &mut [u8]); + /// Prepares the ith-row of [VmpPMat] from a vector of [VecZnx]. + /// + /// # Arguments + /// + /// * `b`: [VmpPMat] on which the values are encoded. + /// * `a`: the vector of [VecZnx] to encode on the [VmpPMat]. + /// * `row_i`: the index of the row to prepare. + /// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. + /// + /// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. + /// /// # Example + /// ``` + /// use base2k::{Module, FFT64, Matrix3D, VmpPMat, VmpPMatOps, VecZnx, VecZnxOps, Free}; + /// use std::cmp::min; + /// + /// let n: usize = 1024; + /// let module: Module = Module::new::(n); + /// let rows: usize = 5; + /// let cols: usize = 6; + /// + /// let vecznx: module.new_vec_znx(cols); + /// + /// let mut buf: Vec = vec![u8::default(); module.vmp_prepare_tmp_bytes(rows, cols)]; + /// + /// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols); + /// module.vmp_prepare_row(&mut vmp_pmat, &vecznx, 0, &mut buf); + /// + /// vmp_pmat.free(); + /// module.free(); + /// ``` + fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &VecZnx, row_i: usize, tmp_bytes: &mut [u8]); + /// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft]. /// /// # Arguments @@ -404,6 +436,20 @@ impl VmpPMatOps for Module { } } + fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &VecZnx, row_i: usize, buf: &mut [u8]) { + unsafe { + vmp::vmp_prepare_row( + self.0, + b.data(), + a.data.as_ptr(), + row_i as u64, + b.rows() as u64, + b.cols() as u64, + buf.as_mut_ptr(), + ); + } + } + fn vmp_apply_dft_tmp_bytes( &self, c_limbs: usize, diff --git a/rlwe/examples/encryption.rs b/rlwe/examples/encryption.rs new file mode 100644 index 0000000..eb4e323 --- /dev/null +++ b/rlwe/examples/encryption.rs @@ -0,0 +1,77 @@ +use base2k::{Encoding, FFT64, SvpPPolOps}; +use rlwe::{ + ciphertext::Ciphertext, + decryptor::{Decryptor, decrypt_rlwe_thread_safe_tmp_byte}, + encryptor::{EncryptorSk, encrypt_rlwe_sk_tmp_bytes}, + keys::SecretKey, + parameters::{Parameters, ParametersLiteral}, + plaintext::Plaintext, +}; +use sampling::source::{Source, new_seed}; + +fn main() { + let params_lit: ParametersLiteral = ParametersLiteral { + log_n: 10, + log_q: 54, + log_p: 0, + log_base2k: 17, + log_scale: 20, + xe: 3.2, + xs: 128, + }; + + let params: Parameters = Parameters::new::(¶ms_lit); + + let mut tmp_bytes: Vec = vec![ + 0u8; + params.decrypt_rlwe_thread_safe_tmp_byte(params.log_q()) + | params.encrypt_rlwe_sk_tmp_bytes(params.log_q()) + ]; + + let sk: SecretKey = SecretKey::new(params.module()); + + let mut want = vec![i64::default(); params.n()]; + + want.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); + + let mut pt: Plaintext = params.new_plaintext(params.log_q() - 20); + + let log_base2k = pt.log_base2k(); + + let log_k: usize = 17; + + pt.0.value[0].encode_vec_i64(log_base2k, log_k, &want, 32); + pt.0.value[0].normalize(log_base2k, &mut tmp_bytes); + + println!("log_k: {}", log_k); + pt.0.value[0].print_limbs(pt.limbs(), 16); + + let mut ct: Ciphertext = params.new_ciphertext(params.log_q()); + + let mut source_xe: Source = Source::new(new_seed()); + let mut source_xa: Source = Source::new(new_seed()); + + let mut sk_svp_ppol: base2k::SvpPPol = params.module().svp_new_ppol(); + params.module().svp_prepare(&mut sk_svp_ppol, &sk.0); + + params.encrypt_rlwe_sk_thread_safe( + &mut ct, + Some(&pt), + &sk_svp_ppol, + &mut source_xa, + &mut source_xe, + &mut tmp_bytes, + ); + + params.decrypt_rlwe_thread_safe(&mut pt, &ct, &sk_svp_ppol, &mut tmp_bytes); + + pt.0.value[0].print_limbs(pt.limbs(), 16); + + let mut have = vec![i64::default(); params.n()]; + + println!("pt: {}", log_k); + pt.0.value[0].decode_vec_i64(pt.log_base2k(), log_k, &mut have); + + println!("want: {:?}", &want[..16]); + println!("have: {:?}", &have[..16]); +} diff --git a/rlwe/src/ciphertext.rs b/rlwe/src/ciphertext.rs index 1870f4b..91281ba 100644 --- a/rlwe/src/ciphertext.rs +++ b/rlwe/src/ciphertext.rs @@ -1,20 +1,19 @@ use crate::elem::Elem; +use crate::parameters::Parameters; use crate::plaintext::Plaintext; -use base2k::VecZnx; +use base2k::{Module, VecZnx, VmpPMat, VmpPMatOps}; pub struct Ciphertext(pub Elem); -/* -impl Parameters { - pub fn new_ciphertext(&self, degree: usize, log_base2k: usize, log_q: usize) -> Ciphertext { - Ciphertext(self.new_elem(degree, log_base2k, log_q)) - } -} - */ - impl Ciphertext { - pub fn new(n: usize, log_base2k: usize, log_q: usize, degree: usize) -> Self { - Self(Elem::new(n, log_base2k, log_q, degree)) + pub fn new( + module: &Module, + log_base2k: usize, + log_q: usize, + degree: usize, + log_scale: usize, + ) -> Self { + Self(Elem::new(module, log_base2k, log_q, degree, log_scale)) } pub fn n(&self) -> usize { @@ -45,7 +44,75 @@ impl Ciphertext { self.0.log_base2k() } + pub fn log_scale(&self) -> usize { + self.0.log_scale + } + pub fn as_plaintext(&self) -> Plaintext { unsafe { Plaintext(std::ptr::read(&self.0)) } } } + +impl Parameters { + pub fn new_ciphertext(&self, log_q: usize) -> Ciphertext { + Ciphertext::new(self.module(), self.log_base2k(), log_q, self.log_scale(), 1) + } +} + +pub struct GadgetCiphertext { + pub value: Vec, + pub log_base2k: usize, + pub log_q: usize, + pub log_scale: usize, +} + +impl GadgetCiphertext { + pub fn new( + module: &Module, + log_base2k: usize, + rows: usize, + log_q: usize, + log_scale: usize, + ) -> Self { + let cols: usize = (log_q + log_base2k - 1) / log_base2k; + let mut value: Vec = Vec::new(); + (0..rows).for_each(|_| value.push(module.new_vmp_pmat(rows, cols))); + Self { + value, + log_base2k, + log_q, + log_scale, + } + } + + pub fn n(&self) -> usize { + self.value[0].n + } + + pub fn rows(&self) -> usize { + self.value[0].rows + } + + pub fn cols(&self) -> usize { + self.value[0].cols + } + + pub fn degree(&self) -> usize { + self.value.len() - 1 + } + + pub fn log_q(&self) -> usize { + self.log_q + } + + pub fn log_base2k(&self) -> usize { + self.log_base2k + } +} + +pub struct RGSWCiphertext { + pub value: [GadgetCiphertext; 2], + pub log_base2k: usize, + pub log_q: usize, + pub log_p: usize, +} diff --git a/rlwe/src/decryptor.rs b/rlwe/src/decryptor.rs new file mode 100644 index 0000000..750a346 --- /dev/null +++ b/rlwe/src/decryptor.rs @@ -0,0 +1,80 @@ +use crate::{ + ciphertext::Ciphertext, keys::SecretKey, parameters::Parameters, plaintext::Plaintext, +}; +use base2k::{Module, SvpPPol, SvpPPolOps, VecZnxDft}; +use std::cmp::min; + +pub struct Decryptor { + sk: SvpPPol, +} + +impl Decryptor { + pub fn new(params: &Parameters, sk: &SecretKey) -> Self { + let mut sk_svp_ppol: SvpPPol = params.module().svp_new_ppol(); + sk.prepare(params.module(), &mut sk_svp_ppol); + Self { sk: sk_svp_ppol } + } +} + +pub fn decrypt_rlwe_thread_safe_tmp_byte(module: &Module, limbs: usize) -> usize { + module.bytes_of_vec_znx_dft(limbs) + module.vec_znx_big_normalize_tmp_bytes() +} + +impl Parameters { + pub fn decrypt_rlwe_thread_safe_tmp_byte(&self, log_q: usize) -> usize { + decrypt_rlwe_thread_safe_tmp_byte( + self.module(), + (log_q + self.log_base2k() - 1) / self.log_base2k(), + ) + } + + pub fn decrypt_rlwe_thread_safe( + &self, + res: &mut Plaintext, + ct: &Ciphertext, + sk: &SvpPPol, + tmp_bytes: &mut [u8], + ) { + decrypt_rlwe_thread_safe(self.module(), res, ct, sk, tmp_bytes) + } +} + +pub fn decrypt_rlwe_thread_safe( + module: &Module, + res: &mut Plaintext, + ct: &Ciphertext, + sk: &SvpPPol, + tmp_bytes: &mut [u8], +) { + let limbs: usize = min(res.limbs(), ct.limbs()); + + assert!( + tmp_bytes.len() >= decrypt_rlwe_thread_safe_tmp_byte(module, limbs), + "invalid tmp_bytes: tmp_bytes.len()={} < decrypt_rlwe_thread_safe_tmp_byte={}", + tmp_bytes.len(), + decrypt_rlwe_thread_safe_tmp_byte(module, limbs) + ); + + let res_dft_bytes: usize = module.bytes_of_vec_znx_dft(limbs); + + let mut res_dft: VecZnxDft = VecZnxDft::from_bytes(limbs, tmp_bytes); + let mut res_big: base2k::VecZnxBig = res_dft.as_vec_znx_big(); + + // res_dft <- DFT(ct[1]) * DFT(sk) + module.svp_apply_dft(&mut res_dft, sk, &ct.0.value[1], limbs); + // res_big <- ct[1] x sk + module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft, limbs); + // res_big <- ct[1] x sk + ct[0] + module.vec_znx_big_add_small_inplace(&mut res_big, &ct.0.value[0], limbs); + // res <- normalize(ct[1] x sk + ct[0]) + module.vec_znx_big_normalize( + ct.log_base2k(), + res.at_mut(0), + &res_big, + &mut tmp_bytes[res_dft_bytes..], + ); + + res.0.log_base2k = ct.log_base2k(); + res.0.log_q = min(res.log_q(), ct.log_q()); + res.0.log_scale = ct.log_scale(); +} diff --git a/rlwe/src/elem.rs b/rlwe/src/elem.rs index bb8eb74..570833c 100644 --- a/rlwe/src/elem.rs +++ b/rlwe/src/elem.rs @@ -1,21 +1,59 @@ use crate::parameters::Parameters; -use base2k::{Infos, VecZnx}; +use base2k::{Infos, Module, VecZnx, VecZnxOps}; pub struct Elem { pub value: Vec, pub log_base2k: usize, pub log_q: usize, + pub log_scale: usize, } impl Elem { - pub fn new(n: usize, log_base2k: usize, log_q: usize, degree: usize) -> Self { + pub fn new( + module: &Module, + log_base2k: usize, + log_q: usize, + degree: usize, + log_scale: usize, + ) -> Self { let limbs: usize = (log_q + log_base2k - 1) / log_base2k; let mut value: Vec = Vec::new(); - (0..degree + 1).for_each(|_| value.push(VecZnx::new(n, limbs))); + (0..degree + 1).for_each(|_| value.push(module.new_vec_znx(limbs))); Self { value, - log_base2k, log_q, + log_base2k, + log_scale: log_scale, + } + } + + pub fn bytes_of(module: &Module, log_base2k: usize, log_q: usize, degree: usize) -> usize { + let cols = (log_q + log_base2k - 1) / log_base2k; + module.n() * cols * (degree + 1) * 8 + } + + pub fn from_bytes( + module: &Module, + log_base2k: usize, + log_q: usize, + degree: usize, + bytes: &mut [u8], + ) -> Self { + let n: usize = module.n(); + assert!(bytes.len() >= Self::bytes_of(module, log_base2k, log_q, degree)); + let mut value: Vec = Vec::new(); + let limbs: usize = (log_q + log_base2k - 1) / log_base2k; + let size = VecZnx::bytes(n, limbs); + let mut ptr: usize = 0; + (0..degree + 1).for_each(|_| { + value.push(VecZnx::from_bytes(n, limbs, &mut bytes[ptr..])); + ptr += size + }); + Self { + value, + log_q, + log_base2k, + log_scale: 0, } } @@ -35,6 +73,10 @@ impl Elem { self.log_base2k } + pub fn log_scale(&self) -> usize { + self.log_scale + } + pub fn log_q(&self) -> usize { self.log_q } @@ -49,3 +91,13 @@ impl Elem { &mut self.value[i] } } + +impl Parameters { + pub fn bytes_of_elem(&self, log_q: usize, degree: usize) -> usize { + Elem::bytes_of(self.module(), self.log_base2k(), log_q, degree) + } + + pub fn elem_from_bytes(&self, log_q: usize, degree: usize, bytes: &mut [u8]) -> Elem { + Elem::from_bytes(self.module(), self.log_base2k(), log_q, degree, bytes) + } +} diff --git a/rlwe/src/encryptor.rs b/rlwe/src/encryptor.rs index efb4c59..06ecf0f 100644 --- a/rlwe/src/encryptor.rs +++ b/rlwe/src/encryptor.rs @@ -1,84 +1,141 @@ -use crate::ciphertext::Ciphertext; +use crate::ciphertext::{Ciphertext, GadgetCiphertext}; use crate::elem::Elem; use crate::keys::SecretKey; use crate::parameters::Parameters; use crate::plaintext::Plaintext; +use base2k::ffi::znx::znx_zero_i64_ref; use base2k::sampling::Sampling; -use base2k::{Module, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxDft}; -use sampling::source::Source; +use base2k::{ + Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxDft, VecZnxOps, VmpPMatOps, +}; +use sampling::source::{Source, new_seed}; pub struct EncryptorSk { - pub sk: SvpPPol, + sk: SvpPPol, + source_xa: Source, + source_xe: Source, + initialized: bool, + tmp_bytes: Vec, } impl EncryptorSk { - pub fn new(params: &Parameters, sk: &SecretKey) -> Self { + pub fn new(params: &Parameters, sk: Option<&SecretKey>) -> Self { let mut sk_svp_ppol: SvpPPol = params.module().svp_new_ppol(); - params.module().svp_prepare(&mut sk_svp_ppol, &sk.0); - Self { sk: sk_svp_ppol } + let mut initialized: bool = false; + if let Some(sk) = sk { + sk.prepare(params.module(), &mut sk_svp_ppol); + initialized = true; + } + Self { + sk: sk_svp_ppol, + initialized, + source_xa: Source::new(new_seed()), + source_xe: Source::new(new_seed()), + tmp_bytes: vec![0u8; params.encrypt_rlwe_sk_tmp_bytes(params.limbs_qp())], + } + } + + pub fn set_sk(&mut self, module: &Module, sk: &SecretKey) { + sk.prepare(module, &mut self.sk); + self.initialized = true; + } + + pub fn seed_source_xa(&mut self, seed: [u8; 32]) { + self.source_xa = Source::new(seed) + } + + pub fn seed_source_xe(&mut self, seed: [u8; 32]) { + self.source_xe = Source::new(seed) } pub fn encrypt_rlwe_sk( + &mut self, + params: &Parameters, + ct: &mut Ciphertext, + pt: Option<&Plaintext>, + ) { + assert!( + self.initialized == true, + "invalid call to [EncryptorSk.encrypt_rlwe_sk]: [EncryptorSk] has not been initialized with a [SecretKey]" + ); + params.encrypt_rlwe_sk_thread_safe( + ct, + pt, + &self.sk, + &mut self.source_xa, + &mut self.source_xe, + &mut self.tmp_bytes, + ); + } + + pub fn encrypt_rlwe_sk_thread_safe( &self, params: &Parameters, ct: &mut Ciphertext, pt: Option<&Plaintext>, - xa_source: &mut Source, - xe_source: &mut Source, + source_xa: &mut Source, + source_xe: &mut Source, tmp_bytes: &mut [u8], ) { - params.encrypt_rlwe_sk(ct, pt, &self.sk, xa_source, xe_source, tmp_bytes); + assert!( + self.initialized == true, + "invalid call to [EncryptorSk.encrypt_rlwe_sk_thread_safe]: [EncryptorSk] has not been initialized with a [SecretKey]" + ); + params.encrypt_rlwe_sk_thread_safe(ct, pt, &self.sk, source_xa, source_xe, tmp_bytes); } } impl Parameters { - pub fn encrypt_rlwe_sk_tmp_bytes(&self, limbs: usize) -> usize { - encrypt_rlwe_sk_tmp_bytes(self.module(), limbs) + pub fn encrypt_rlwe_sk_tmp_bytes(&self, log_q: usize) -> usize { + encrypt_rlwe_sk_tmp_bytes(self.module(), self.log_base2k(), log_q) } - pub fn encrypt_rlwe_sk( + pub fn encrypt_rlwe_sk_thread_safe( &self, ct: &mut Ciphertext, pt: Option<&Plaintext>, sk: &SvpPPol, - xa_source: &mut Source, - xe_source: &mut Source, + source_xa: &mut Source, + source_xe: &mut Source, tmp_bytes: &mut [u8], ) { - encrypt_rlwe_sk( + encrypt_rlwe_sk_thread_safe( self.module(), &mut ct.0, pt.map(|pt| &pt.0), sk, - xa_source, - xe_source, + source_xa, + source_xe, self.xe(), tmp_bytes, ) } } -pub fn encrypt_rlwe_sk_tmp_bytes(module: &Module, limbs: usize) -> usize { - module.bytes_of_vec_znx_dft(limbs) + module.vec_znx_big_normalize_tmp_bytes() +pub fn encrypt_rlwe_sk_tmp_bytes(module: &Module, log_base2k: usize, log_q: usize) -> usize { + module.bytes_of_vec_znx_dft((log_q + log_base2k - 1) / log_base2k) + + module.vec_znx_big_normalize_tmp_bytes() } -pub fn encrypt_rlwe_sk( +pub fn encrypt_rlwe_sk_thread_safe( module: &Module, ct: &mut Elem, pt: Option<&Elem>, sk: &SvpPPol, - xa_source: &mut Source, - xe_source: &mut Source, + source_xa: &mut Source, + source_xe: &mut Source, sigma: f64, tmp_bytes: &mut [u8], ) { let limbs: usize = ct.limbs(); + let log_base2k: usize = ct.log_base2k(); + let log_q: usize = ct.log_q(); assert!( - tmp_bytes.len() >= encrypt_rlwe_sk_tmp_bytes(module, limbs), + tmp_bytes.len() >= encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q), "invalid tmp_bytes: tmp_bytes={} < encrypt_rlwe_sk_tmp_bytes={}", tmp_bytes.len(), - encrypt_rlwe_sk_tmp_bytes(module, limbs) + encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q) ); let log_q: usize = ct.log_q(); @@ -86,7 +143,7 @@ pub fn encrypt_rlwe_sk( let c1: &mut VecZnx = ct.at_mut(1); // c1 <- Z_{2^prec}[X]/(X^{N}+1) - c1.fill_uniform(limbs, log_base2k, xa_source); + c1.fill_uniform(log_base2k, limbs, source_xa); let bytes_of_vec_znx_dft = module.bytes_of_vec_znx_dft(limbs); @@ -95,7 +152,7 @@ pub fn encrypt_rlwe_sk( VecZnxDft::from_bytes(limbs, &mut tmp_bytes[..bytes_of_vec_znx_dft]); // Applies buf_dft <- s * c1 - module.svp_apply_dft(&mut buf_dft, sk, c1); + module.svp_apply_dft(&mut buf_dft, sk, c1, limbs); // Alias scratch space let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big(); @@ -110,5 +167,90 @@ pub fn encrypt_rlwe_sk( // c0 <- normalize(buf_big) + e let c0: &mut VecZnx = ct.at_mut(0); module.vec_znx_big_normalize(log_base2k, c0, &buf_big, carry); - c0.add_normal(log_base2k, log_q, xe_source, sigma, (sigma * 6.0).ceil()); + c0.add_normal(log_base2k, log_q, source_xe, sigma, (sigma * 6.0).ceil()); +} + +pub fn encrypt_grlwe_sk_tmp_bytes( + module: &Module, + log_base2k: usize, + rows: usize, + log_q: usize, +) -> usize { + let cols = (log_q + log_base2k - 1) / log_base2k; + Elem::bytes_of(module, log_base2k, log_q, 1) + + Plaintext::bytes_of(module, log_base2k, log_q) + + encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q) + + module.vmp_prepare_tmp_bytes(rows, cols) +} + +pub fn encrypt_grlwe_sk_thread_safe( + module: &Module, + ct: &mut GadgetCiphertext, + m: &Scalar, + sk: &SvpPPol, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + tmp_bytes: &mut [u8], +) { + let rows: usize = ct.rows(); + let log_q: usize = ct.log_q(); + let log_base2k: usize = ct.log_base2k(); + + let min_tmp_bytes_len = encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q); + + assert!( + tmp_bytes.len() >= min_tmp_bytes_len, + "invalid tmp_bytes: tmp_bytes.len()={} < encrypt_grlwe_sk_tmp_bytes={}", + tmp_bytes.len(), + min_tmp_bytes_len + ); + + let mut ptr: usize = 0; + let mut tmp_elem: Elem = Elem::from_bytes(module, log_base2k, ct.log_q(), 1, tmp_bytes); + let bytes_of_elem: usize = Elem::bytes_of(module, log_base2k, log_q, 1); + ptr += bytes_of_elem; + + let mut tmp_pt: Plaintext = + Plaintext::from_bytes(module, log_base2k, log_q, &mut tmp_bytes[ptr..]); + ptr += Plaintext::bytes_of(module, log_base2k, log_q); + + let (tmp_bytes_encrypt_sk, tmp_bytes_vmp_prepare_row) = + tmp_bytes[ptr..].split_at_mut(encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q)); + + (0..rows).for_each(|row_i| { + // Sets the i-th row of the RLWE sample to m (i.e. m * 2^{-log_base2k*i}) + tmp_pt.0.value[0].at_mut(row_i).copy_from_slice(&m.0); + + // Encrypts RLWE(m * 2^{-log_base2k*i}) + encrypt_rlwe_sk_thread_safe( + module, + &mut tmp_elem, + Some(&tmp_pt.0), + sk, + source_xa, + source_xe, + sigma, + tmp_bytes_encrypt_sk, + ); + + // Zeroes the ith-row of tmp_pt + tmp_pt.0.value[0].at_mut(row_i).fill(0); + + // GRLWE[row_i][0] = -as + m * 2^{-i*log_base2k} + e*2^{-log_q} + module.vmp_prepare_row( + &mut ct.value[0], + tmp_elem.at(0), + row_i, + tmp_bytes_vmp_prepare_row, + ); + + // GRLWE[row_i][1] = a + module.vmp_prepare_row( + &mut ct.value[1], + tmp_elem.at(1), + row_i, + tmp_bytes_vmp_prepare_row, + ); + }) } diff --git a/rlwe/src/evaluator.rs b/rlwe/src/evaluator.rs new file mode 100644 index 0000000..af9c041 --- /dev/null +++ b/rlwe/src/evaluator.rs @@ -0,0 +1,82 @@ +use crate::ciphertext::{Ciphertext, GadgetCiphertext}; +use base2k::{Module, VecZnxBig, VecZnxDft, VmpPMatOps}; + +pub fn gadget_product_tmp_bytes( + module: &Module, + log_base2k: usize, + out_log_q: usize, + in_log_q: usize, + gct_rows: usize, + gct_log_q: usize, +) -> usize { + let gct_cols: usize = (gct_log_q + log_base2k - 1) / log_base2k; + let in_limbs: usize = (in_log_q + log_base2k - 1) / log_base2k; + let out_limbs: usize = (out_log_q + log_base2k - 1) / log_base2k; + module.vmp_apply_dft_to_dft_tmp_bytes(out_limbs, in_limbs, gct_rows, gct_cols) + + 2 * module.bytes_of_vec_znx_dft(gct_cols) +} + +pub fn gadget_product_inplace( + module: &Module, + a: &mut Ciphertext, + b: &GadgetCiphertext, + tmp_bytes: &mut [u8], +) { + // This is safe to do because the relevant values of a are copied to a buffer before being + // overwritten. + unsafe { + let a_ptr: *mut Ciphertext = a; + gadget_product(module, a, &*a_ptr, b, tmp_bytes) + } +} + +pub fn gadget_product( + module: &Module, + res: &mut Ciphertext, + a: &Ciphertext, + b: &GadgetCiphertext, + tmp_bytes: &mut [u8], +) { + assert!( + a.log_base2k() == b.log_base2k(), + "invalid inputs: a.log_base2k={} != b.log_base2k={}", + a.log_base2k(), + b.log_base2k() + ); + + let log_base2k: usize = b.log_base2k(); + let cols: usize = b.cols(); + + let (tmp_bytes_vmp_apply_dft, tmp_bytes) = + tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols)); + + let (tmp_bytes_c1_dft, tmp_bytes_res_dft) = tmp_bytes.split_at_mut(tmp_bytes.len() >> 1); + + let mut c1_dft: VecZnxDft = module.new_vec_znx_from_bytes(cols, tmp_bytes_c1_dft); + let mut res_dft: VecZnxDft = module.new_vec_znx_from_bytes(cols, tmp_bytes_res_dft); + let mut res_big: VecZnxBig = res_dft.as_vec_znx_big(); + + // c1_dft <- DFT(b[1]) + module.vec_znx_dft(&mut c1_dft, a.at(1), a.limbs()); + + // res_dft <- DFT(c1) x GadgetCiphertext[0] + module.vmp_apply_dft_to_dft(&mut res_dft, &c1_dft, &b.value[0], tmp_bytes_vmp_apply_dft); + + // res_big <- IDFT(DFT(c1) x GadgetCiphertext[0]) + module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft, cols); + + // res_big <- c0 + c1_dft x GadgetCiphertext[0] + module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0), cols); + + // res[0] = normalize(c0 + c1_dft x GadgetCiphertext[0]) + module.vec_znx_big_normalize(log_base2k, res.at_mut(0), &res_big, tmp_bytes_vmp_apply_dft); + + // res_dft <- DFT(c1) x GadgetCiphertext[1] + module.vmp_apply_dft_to_dft(&mut res_dft, &c1_dft, &b.value[1], tmp_bytes_vmp_apply_dft); + + // res_big <- IDFT(DFT(c1) x GadgetCiphertext[1]) + module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft, cols); + + // res[1] = normalize(c1_dft x GadgetCiphertext[1]) + module.vec_znx_big_normalize(log_base2k, res.at_mut(1), &res_big, tmp_bytes_vmp_apply_dft); +} diff --git a/rlwe/src/key_generator.rs b/rlwe/src/key_generator.rs index 1cfaf1f..958db6a 100644 --- a/rlwe/src/key_generator.rs +++ b/rlwe/src/key_generator.rs @@ -1 +1,53 @@ +use crate::keys::{PublicKey, SecretKey, SwitchingKey}; +use crate::parameters::Parameters; +use base2k::SvpPPol; +use sampling::source::Source; + pub struct KeyGenerator {} + +impl KeyGenerator { + pub fn gen_secret_key_thread_safe( + &self, + params: &Parameters, + source: &mut Source, + ) -> SecretKey { + let mut sk: SecretKey = SecretKey::new(params.module()); + sk.fill_ternary_hw(params.xs(), source); + sk + } + + pub fn gen_public_key_thread_safe( + &self, + params: &Parameters, + sk_ppol: &SvpPPol, + source: &mut Source, + tmp_bytes: &mut [u8], + ) -> PublicKey { + let mut xa_source: Source = source.branch(); + let mut xe_source: Source = source.branch(); + let mut pk: PublicKey = + PublicKey::new(params.module(), params.log_base2k(), params.log_qp()); + pk.gen_thread_safe( + params.module(), + sk_ppol, + params.xe(), + &mut xa_source, + &mut xe_source, + tmp_bytes, + ); + pk + } + + pub fn gen_switching_key( + &self, + params: &Parameters, + sk_in: &SecretKey, + sk_out: &SecretKey, + rows: usize, + log_q: usize, + ) -> SwitchingKey { + let swk = SwitchingKey::new(params.module(), params.log_base2k(), rows, log_q, 0); + + swk + } +} diff --git a/rlwe/src/keys.rs b/rlwe/src/keys.rs index a7a6bbe..cf4edc0 100644 --- a/rlwe/src/keys.rs +++ b/rlwe/src/keys.rs @@ -1,44 +1,86 @@ +use crate::ciphertext::GadgetCiphertext; use crate::elem::Elem; -use crate::encryptor::{encrypt_rlwe_sk, encrypt_rlwe_sk_tmp_bytes}; +use crate::encryptor::{encrypt_rlwe_sk_thread_safe, encrypt_rlwe_sk_tmp_bytes}; use crate::parameters::Parameters; -use base2k::{Module, Sampling, Scalar, SvpPPol, SvpPPolOps, VecZnx}; +use base2k::{Module, Scalar, SvpPPol, SvpPPolOps, VmpPMat, VmpPMatOps}; use sampling::source::Source; pub struct SecretKey(pub Scalar); impl SecretKey { - pub fn new_ternary_prob(module: &Module, limbs: usize, prob: f64, source: &mut Source) -> Self { - let mut sk: Scalar = Scalar::new(module.n()); - sk.fill_ternary_prob(prob, source); - SecretKey(sk) + pub fn new(params: &Module) -> Self { + SecretKey(Scalar::new(params.n())) + } + + pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) { + self.0.fill_ternary_prob(prob, source); + } + + pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) { + self.0.fill_ternary_hw(hw, source); + } + + pub fn prepare(&self, module: &Module, sk_ppol: &mut SvpPPol) { + module.svp_prepare(sk_ppol, &self.0) } } pub struct PublicKey(pub Elem); impl PublicKey { - pub fn new( - params: &Parameters, + pub fn new(module: &Module, log_base2k: usize, log_q: usize) -> PublicKey { + PublicKey(Elem::new(module, log_base2k, log_q, 1, 0)) + } + + pub fn gen_thread_safe( + &mut self, + module: &Module, sk: &SvpPPol, + xe: f64, xa_source: &mut Source, xe_source: &mut Source, tmp_bytes: &mut [u8], - ) -> Self { - let mut pk: Elem = Elem::new(params.n(), params.log_base2k(), params.log_qp(), 1); - encrypt_rlwe_sk( - params.module(), - &mut pk, + ) { + encrypt_rlwe_sk_thread_safe( + module, + &mut self.0, None, sk, xa_source, xe_source, - params.xe(), + xe, tmp_bytes, ); - PublicKey(pk) } - pub fn new_tmp_bytes(params: &Parameters) -> usize { - encrypt_rlwe_sk_tmp_bytes(params.module(), params.limbs_qp()) + pub fn gen_thread_safe_tmp_bytes(module: &Module, log_base2k: usize, log_q: usize) -> usize { + encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q) + } +} + +pub struct SwitchingKey(GadgetCiphertext); + +impl SwitchingKey { + pub fn new( + module: &Module, + log_base2k: usize, + rows: usize, + log_q: usize, + log_scale: usize, + ) -> SwitchingKey { + SwitchingKey(GadgetCiphertext::new( + module, log_base2k, rows, log_q, log_scale, + )) + } + + pub fn gen_thread_safe( + &mut self, + params: &mut Parameters, + sk_in: &SvpPPol, + sk_out: &SvpPPol, + xa_source: &mut Source, + xe_source: &mut Source, + tmp_bytes: &mut [u8], + ) { } } diff --git a/rlwe/src/lib.rs b/rlwe/src/lib.rs index 5bb7cc0..a559681 100644 --- a/rlwe/src/lib.rs +++ b/rlwe/src/lib.rs @@ -1,6 +1,8 @@ pub mod ciphertext; +pub mod decryptor; pub mod elem; pub mod encryptor; +pub mod evaluator; pub mod key_generator; pub mod keys; pub mod parameters; diff --git a/rlwe/src/plaintext.rs b/rlwe/src/plaintext.rs index aeec2bd..174a07d 100644 --- a/rlwe/src/plaintext.rs +++ b/rlwe/src/plaintext.rs @@ -1,20 +1,35 @@ use crate::ciphertext::Ciphertext; use crate::elem::Elem; -use base2k::VecZnx; +use crate::parameters::Parameters; +use base2k::{Module, VecZnx}; pub struct Plaintext(pub Elem); -/* impl Parameters { pub fn new_plaintext(&self, log_q: usize) -> Plaintext { - Plaintext(self.new_elem(0, log_q)) + Plaintext::new(self.module(), self.log_base2k(), log_q, self.log_scale()) + } + + pub fn bytes_of_plaintext(&self, log_q: usize) -> usize { + Elem::bytes_of(self.module(), self.log_base2k(), log_q, 0) + } + + pub fn plaintext_from_bytes(&self, log_q: usize, bytes: &mut [u8]) -> Plaintext { + Plaintext(self.elem_from_bytes(log_q, 0, bytes)) } } -*/ impl Plaintext { - pub fn new(n: usize, log_base2k: usize, log_q: usize) -> Self { - Self(Elem::new(n, log_base2k, log_q, 0)) + pub fn new(module: &Module, log_base2k: usize, log_q: usize, log_scale: usize) -> Self { + Self(Elem::new(module, log_base2k, log_q, 0, log_scale)) + } + + pub fn bytes_of(module: &Module, log_base2k: usize, log_q: usize) -> usize { + Elem::bytes_of(module, log_base2k, log_q, 0) + } + + pub fn from_bytes(module: &Module, log_base2k: usize, log_q: usize, bytes: &mut [u8]) -> Self { + Self(Elem::from_bytes(module, log_base2k, log_q, 0, bytes)) } pub fn n(&self) -> usize { @@ -45,6 +60,10 @@ impl Plaintext { self.0.log_base2k() } + pub fn log_scale(&self) -> usize { + self.0.log_scale() + } + pub fn as_ciphertext(&self) -> Ciphertext { unsafe { Ciphertext(std::ptr::read(&self.0)) } }