From 71f33f59832995ee11b0cd7ca46a17b98d716460 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 18 Feb 2025 17:15:24 +0100 Subject: [PATCH] wip on generic traits --- rlwe/benches/gadget_product.rs | 40 +++-------- rlwe/src/ciphertext.rs | 52 +++++--------- rlwe/src/decryptor.rs | 12 ++-- rlwe/src/elem.rs | 126 ++++++++++++--------------------- rlwe/src/encryptor.rs | 20 +++--- rlwe/src/evaluator.rs | 22 +++--- rlwe/src/plaintext.rs | 14 ++-- 7 files changed, 104 insertions(+), 182 deletions(-) diff --git a/rlwe/benches/gadget_product.rs b/rlwe/benches/gadget_product.rs index 457dfda..fc75341 100644 --- a/rlwe/benches/gadget_product.rs +++ b/rlwe/benches/gadget_product.rs @@ -1,18 +1,15 @@ -use base2k::{ - FFT64, Module, Sampling, SvpPPolOps, VecZnx, VecZnxBig, VecZnxDft, VecZnxDftOps, VmpPMat, - VmpPMatOps, alloc_aligned_u8, -}; +use base2k::{FFT64, Module, SvpPPolOps, VecZnx, VmpPMat, alloc_aligned_u8}; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; use rlwe::{ ciphertext::{Ciphertext, new_gadget_ciphertext}, elem::Elem, encryptor::{encrypt_grlwe_sk_thread_safe, encrypt_grlwe_sk_tmp_bytes}, - evaluator::gadget_product_tmp_bytes, + evaluator::{gadget_product_inplace_thread_safe, gadget_product_tmp_bytes}, key_generator::gen_switching_key_thread_safe_tmp_bytes, keys::SecretKey, parameters::{Parameters, ParametersLiteral}, }; -use sampling::source::{Source, new_seed}; +use sampling::source::Source; fn gadget_product_inplace(c: &mut Criterion) { fn gadget_product<'a>( @@ -21,31 +18,8 @@ fn gadget_product_inplace(c: &mut Criterion) { gadget_ct: &'a Ciphertext, tmp_bytes: &'a mut [u8], ) -> Box { - let factor: usize = 2; - - let log_base2k: usize = 32; - let limbs: usize = 2; - let rows: usize = factor * limbs; - let cols: usize = factor * limbs + 1; - - let pmat: VmpPMat = module.new_vmp_pmat(rows, cols); - - let mut tmp_bytes: Vec = - alloc_aligned_u8(module.vmp_apply_dft_tmp_bytes(cols, rows, rows, cols), 64); - - let mut a_dft: VecZnxDft = module.new_vec_znx_dft(rows); - let mut res_dft: VecZnxDft = module.new_vec_znx_dft(cols); - let mut res_big: VecZnxBig = res_dft.as_vec_znx_big(); - let mut a: VecZnx = VecZnx::new(module.n(), rows); - let mut source = Source::new(new_seed()); - module.fill_uniform(log_base2k, &mut a, limbs, &mut source); - Box::new(move || { - module.vec_znx_dft(&mut a_dft, &a, rows); - module.vmp_apply_dft_to_dft(&mut res_dft, &mut a_dft, &pmat, &mut tmp_bytes); - module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft, cols); - - //gadget_product_inplace_thread_safe::(module, elem, gadget_ct, tmp_bytes) + gadget_product_inplace_thread_safe::(module, elem, gadget_ct, tmp_bytes) }) } @@ -55,9 +29,9 @@ fn gadget_product_inplace(c: &mut Criterion) { for log_n in 10..11 { let params_lit: ParametersLiteral = ParametersLiteral { log_n: log_n, - log_q: 54, + log_q: 32, log_p: 0, - log_base2k: 7, + log_base2k: 16, log_scale: 20, xe: 3.2, xs: 128, @@ -95,6 +69,8 @@ fn gadget_product_inplace(c: &mut Criterion) { let mut sk0: SecretKey = SecretKey::new(params.module()); let mut sk1: SecretKey = SecretKey::new(params.module()); + sk0.fill_ternary_hw(params.xs(), &mut source); + sk1.fill_ternary_hw(params.xs(), &mut source); let mut source_xe: Source = Source::new([4; 32]); let mut source_xa: Source = Source::new([5; 32]); diff --git a/rlwe/src/ciphertext.rs b/rlwe/src/ciphertext.rs index 0147fa9..bf9a231 100644 --- a/rlwe/src/ciphertext.rs +++ b/rlwe/src/ciphertext.rs @@ -1,7 +1,7 @@ use crate::elem::{Elem, ElemVecZnx, VecZnxCommon}; use crate::parameters::Parameters; use crate::plaintext::Plaintext; -use base2k::{Infos, Module, VecZnx, VecZnxApi, VmpPMat}; +use base2k::{Infos, Module, VecZnx, VmpPMat}; pub struct Ciphertext(pub Elem); @@ -13,8 +13,20 @@ impl Ciphertext { impl Ciphertext where - T: VecZnxCommon, - Elem: Infos + ElemVecZnx, + T: VecZnxCommon, +{ + pub fn zero(&mut self) { + self.0.zero() + } + + pub fn as_plaintext(&self) -> Plaintext { + unsafe { Plaintext::(std::ptr::read(&self.0)) } + } +} + +impl Ciphertext +where + T: Infos, { pub fn n(&self) -> usize { self.0.n() @@ -47,14 +59,6 @@ where pub fn log_scale(&self) -> usize { self.0.log_scale } - - pub fn zero(&mut self) { - self.0.zero() - } - - pub fn as_plaintext(&self) -> Plaintext { - unsafe { Plaintext::(std::ptr::read(&self.0)) } - } } impl Parameters { @@ -70,7 +74,7 @@ pub fn new_gadget_ciphertext( log_q: usize, ) -> Ciphertext { let cols: usize = (log_q + log_base2k - 1) / log_base2k; - let mut elem: Elem = Elem::::new(module, log_base2k, rows, 2 * cols); + let mut elem: Elem = Elem::::new(module, log_base2k, 1, rows, 2 * cols); elem.log_q = log_q; Ciphertext(elem) } @@ -82,29 +86,7 @@ pub fn new_rgsw_ciphertext( log_q: usize, ) -> Ciphertext { let cols: usize = (log_q + log_base2k - 1) / log_base2k; - let mut elem: Elem = Elem::::new(module, log_base2k, 2 * rows, 2 * cols); + let mut elem: Elem = Elem::::new(module, log_base2k, 2, rows, 2 * cols); elem.log_q = log_q; Ciphertext(elem) } - -impl Ciphertext { - pub fn n(&self) -> usize { - self.0.n() - } - - pub fn rows(&self) -> usize { - self.0.rows() - } - - pub fn cols(&self) -> usize { - self.0.cols() - } - - pub fn log_base2k(&self) -> usize { - self.0.log_base2k - } - - pub fn log_q(&self) -> usize { - self.0.log_q - } -} diff --git a/rlwe/src/decryptor.rs b/rlwe/src/decryptor.rs index 81725c6..4074315 100644 --- a/rlwe/src/decryptor.rs +++ b/rlwe/src/decryptor.rs @@ -5,9 +5,7 @@ use crate::{ parameters::Parameters, plaintext::Plaintext, }; -use base2k::{ - Infos, Module, SvpPPol, SvpPPolOps, VecZnx, VecZnxApi, VecZnxBigOps, VecZnxDft, VecZnxDftOps, -}; +use base2k::{Module, SvpPPol, SvpPPolOps, VecZnxBigOps, VecZnxDft, VecZnxDftOps}; use std::cmp::min; pub struct Decryptor { @@ -41,8 +39,8 @@ impl Parameters { sk: &SvpPPol, tmp_bytes: &mut [u8], ) where - T: VecZnxCommon, - Elem: Infos + ElemVecZnx, + T: VecZnxCommon, + Elem: ElemVecZnx, { decrypt_rlwe_thread_safe(self.module(), &mut res.0, &ct.0, sk, tmp_bytes) } @@ -55,8 +53,8 @@ pub fn decrypt_rlwe_thread_safe( sk: &SvpPPol, tmp_bytes: &mut [u8], ) where - T: VecZnxCommon, - Elem: Infos + ElemVecZnx, + T: VecZnxCommon, + Elem: ElemVecZnx, { let cols: usize = a.cols(); diff --git a/rlwe/src/elem.rs b/rlwe/src/elem.rs index d66eb93..31ce48e 100644 --- a/rlwe/src/elem.rs +++ b/rlwe/src/elem.rs @@ -3,12 +3,12 @@ use base2k::{Infos, Module, VecZnx, VecZnxApi, VecZnxBorrow, VecZnxOps, VmpPMat, use crate::parameters::Parameters; impl Parameters { - pub fn elem_from_bytes(&self, log_q: usize, rows: usize, bytes: &mut [u8]) -> Elem + pub fn elem_from_bytes(&self, log_q: usize, size: usize, bytes: &mut [u8]) -> Elem where - T: VecZnxCommon, - Elem: Infos + ElemVecZnx, + T: VecZnxCommon, + Elem: ElemVecZnx, { - Elem::::from_bytes(self.module(), self.log_base2k(), log_q, rows, bytes) + Elem::::from_bytes(self.module(), self.log_base2k(), log_q, size, bytes) } } @@ -23,47 +23,45 @@ pub trait VecZnxCommon: VecZnxApi + Infos {} impl VecZnxCommon for VecZnx {} impl VecZnxCommon for VecZnxBorrow {} -pub trait ElemVecZnx { +pub trait ElemVecZnx> { fn from_bytes( module: &Module, log_base2k: usize, log_q: usize, - rows: usize, + size: usize, bytes: &mut [u8], ) -> Elem; - fn bytes_of(module: &Module, log_base2k: usize, log_q: usize, rows: usize) -> usize; - fn at(&self, i: usize) -> &T; - fn at_mut(&mut self, i: usize) -> &mut T; + fn bytes_of(module: &Module, log_base2k: usize, log_q: usize, size: usize) -> usize; fn zero(&mut self); } impl ElemVecZnx for Elem where T: VecZnxCommon, - Elem: Infos, { - fn bytes_of(module: &Module, log_base2k: usize, log_q: usize, rows: usize) -> usize { + fn bytes_of(module: &Module, log_base2k: usize, log_q: usize, size: usize) -> usize { let cols = (log_q + log_base2k - 1) / log_base2k; - module.n() * cols * (rows + 1) * 8 + module.n() * cols * size * 8 } fn from_bytes( module: &Module, log_base2k: usize, log_q: usize, - rows: usize, + size: usize, bytes: &mut [u8], ) -> Elem { - assert!(rows > 0); + assert!(size > 0); let n: usize = module.n(); - assert!(bytes.len() >= Self::bytes_of(module, log_base2k, log_q, rows)); + assert!(bytes.len() >= Self::bytes_of(module, log_base2k, log_q, size)); let mut value: Vec = Vec::new(); let limbs: usize = (log_q + log_base2k - 1) / log_base2k; - let size = T::bytes_of(n, limbs); + let elem_size = T::bytes_of(n, limbs); let mut ptr: usize = 0; - (0..rows).for_each(|_| { + println!("{} {} {}", size, elem_size, bytes.len()); + (0..size).for_each(|_| { value.push(T::from_bytes(n, limbs, &mut bytes[ptr..])); - ptr += size + ptr += elem_size }); Self { value, @@ -73,22 +71,32 @@ where } } - fn at(&self, i: usize) -> &T { - assert!(i < self.rows()); - &self.value[i] - } - - fn at_mut(&mut self, i: usize) -> &mut T { - assert!(i < self.rows()); - &mut self.value[i] - } - fn zero(&mut self) { self.value.iter_mut().for_each(|i| i.zero()); } } -impl Elem { +impl Elem { + pub fn n(&self) -> usize { + self.value[0].n() + } + + pub fn log_n(&self) -> usize { + self.value[0].log_n() + } + + pub fn size(&self) -> usize { + self.value.len() + } + + pub fn rows(&self) -> usize { + self.value[0].rows() + } + + pub fn cols(&self) -> usize { + self.value[0].cols() + } + pub fn log_base2k(&self) -> usize { self.log_base2k } @@ -100,39 +108,15 @@ impl Elem { pub fn log_scale(&self) -> usize { self.log_scale } -} -impl Infos for Elem { - fn n(&self) -> usize { - self.value[0].n() + pub fn at(&self, i: usize) -> &T { + assert!(i < self.size()); + &self.value[i] } - fn log_n(&self) -> usize { - self.value[0].log_n() - } - - fn rows(&self) -> usize { - self.value.len() - } - fn cols(&self) -> usize { - self.value[0].cols() - } -} - -impl Infos for Elem { - fn n(&self) -> usize { - self.value[0].n() - } - - fn log_n(&self) -> usize { - self.value[0].log_n() - } - - fn rows(&self) -> usize { - self.value.len() - } - fn cols(&self) -> usize { - self.value[0].cols() + pub fn at_mut(&mut self, i: usize) -> &mut T { + assert!(i < self.size()); + &mut self.value[i] } } @@ -151,30 +135,14 @@ impl Elem { } } -impl Infos for Elem { - fn n(&self) -> usize { - self.value[0].n() - } - - fn log_n(&self) -> usize { - self.value[0].log_n() - } - - fn rows(&self) -> usize { - self.value[0].rows() - } - - fn cols(&self) -> usize { - self.value[0].cols() - } -} - impl Elem { - pub fn new(module: &Module, log_base2k: usize, rows: usize, cols: usize) -> Self { + pub fn new(module: &Module, log_base2k: usize, size: usize, rows: usize, cols: usize) -> Self { assert!(rows > 0); assert!(cols > 0); + let mut value: Vec = Vec::new(); + (0..size).for_each(|_| value.push(module.new_vmp_pmat(rows, cols))); Self { - value: Vec::from([module.new_vmp_pmat(rows, cols); 1]), + value: value, log_q: 0, log_base2k: log_base2k, log_scale: 0, diff --git a/rlwe/src/encryptor.rs b/rlwe/src/encryptor.rs index f57d9b7..8abfcb0 100644 --- a/rlwe/src/encryptor.rs +++ b/rlwe/src/encryptor.rs @@ -5,8 +5,8 @@ use crate::parameters::Parameters; use crate::plaintext::Plaintext; use base2k::sampling::Sampling; use base2k::{ - Infos, Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxApi, VecZnxBig, VecZnxBigOps, - VecZnxBorrow, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, cast_mut, + Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxApi, VecZnxBig, VecZnxBigOps, VecZnxBorrow, + VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, cast_mut, }; use sampling::source::{Source, new_seed}; @@ -55,8 +55,8 @@ impl EncryptorSk { ct: &mut Ciphertext, pt: Option<&Plaintext>, ) where - T: VecZnxCommon, - Elem: Infos + ElemVecZnx, + T: VecZnxCommon, + Elem: ElemVecZnx, { assert!( self.initialized == true, @@ -81,8 +81,8 @@ impl EncryptorSk { source_xe: &mut Source, tmp_bytes: &mut [u8], ) where - T: VecZnxCommon, - Elem: Infos + ElemVecZnx, + T: VecZnxCommon, + Elem: ElemVecZnx, { assert!( self.initialized == true, @@ -106,8 +106,8 @@ impl Parameters { source_xe: &mut Source, tmp_bytes: &mut [u8], ) where - T: VecZnxCommon, - Elem: Infos + ElemVecZnx, + T: VecZnxCommon, + Elem: ElemVecZnx, { encrypt_rlwe_sk_thread_safe( self.module(), @@ -137,8 +137,8 @@ pub fn encrypt_rlwe_sk_thread_safe( sigma: f64, tmp_bytes: &mut [u8], ) where - T: VecZnxCommon, - Elem: Infos + ElemVecZnx, + T: VecZnxCommon, + Elem: ElemVecZnx, { let cols: usize = ct.cols(); let log_base2k: usize = ct.log_base2k(); diff --git a/rlwe/src/evaluator.rs b/rlwe/src/evaluator.rs index cdd834f..e224af2 100644 --- a/rlwe/src/evaluator.rs +++ b/rlwe/src/evaluator.rs @@ -2,9 +2,7 @@ use crate::{ ciphertext::Ciphertext, elem::{Elem, ElemVecZnx, VecZnxCommon}, }; -use base2k::{ - Infos, Module, VecZnxApi, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, -}; +use base2k::{Module, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps}; use std::cmp::min; pub fn gadget_product_tmp_bytes( @@ -22,14 +20,14 @@ pub fn gadget_product_tmp_bytes( + 2 * module.bytes_of_vec_znx_dft(gct_cols) } -pub fn gadget_product_inplace_thread_safe + Infos>( +pub fn gadget_product_inplace_thread_safe( module: &Module, res: &mut Elem, b: &Ciphertext, tmp_bytes: &mut [u8], ) where - T: VecZnxCommon, - Elem: Infos + ElemVecZnx, + T: VecZnxCommon, + Elem: ElemVecZnx, { unsafe { let a_ptr: *const T = res.at(1) as *const T; @@ -51,15 +49,15 @@ pub fn gadget_product_inplace_thread_safe + Infos>( +pub fn gadget_product_thread_safe( module: &Module, res: &mut Elem, a: &T, b: &Ciphertext, tmp_bytes: &mut [u8], ) where - T: VecZnxCommon, - Elem: Infos + ElemVecZnx, + T: VecZnxCommon, + Elem: ElemVecZnx, { let log_base2k: usize = b.log_base2k(); let rows: usize = min(b.rows(), a.cols()); @@ -112,15 +110,15 @@ pub fn gadget_product_thread_safe } } -pub fn rgsw_product_thread_safe + Infos>( +pub fn rgsw_product_thread_safe( module: &Module, res: &mut Elem, a: &Ciphertext, b: &Ciphertext, tmp_bytes: &mut [u8], ) where - T: VecZnxCommon, - Elem: Infos + ElemVecZnx, + T: VecZnxCommon, + Elem: ElemVecZnx, { let log_base2k: usize = b.log_base2k(); let rows: usize = min(b.rows(), a.cols()); diff --git a/rlwe/src/plaintext.rs b/rlwe/src/plaintext.rs index 38ae6da..36f9932 100644 --- a/rlwe/src/plaintext.rs +++ b/rlwe/src/plaintext.rs @@ -1,7 +1,7 @@ use crate::ciphertext::Ciphertext; use crate::elem::{Elem, ElemVecZnx, VecZnxCommon}; use crate::parameters::Parameters; -use base2k::{Infos, Module, VecZnx, VecZnxApi}; +use base2k::{Module, VecZnx}; pub struct Plaintext(pub Elem); @@ -12,16 +12,16 @@ impl Parameters { pub fn bytes_of_plaintext(&self, log_q: usize) -> usize where - T: VecZnxCommon, - Elem: Infos + ElemVecZnx, + T: VecZnxCommon, + Elem: ElemVecZnx, { Elem::::bytes_of(self.module(), self.log_base2k(), log_q, 1) } pub fn plaintext_from_bytes(&self, log_q: usize, bytes: &mut [u8]) -> Plaintext where - T: VecZnxCommon, - Elem: Infos + ElemVecZnx, + T: VecZnxCommon, + Elem: ElemVecZnx, { Plaintext::(self.elem_from_bytes::(log_q, 1, bytes)) } @@ -35,8 +35,8 @@ impl Plaintext { impl Plaintext where - T: VecZnxCommon, - Elem: Infos + ElemVecZnx, + T: VecZnxCommon, + Elem: ElemVecZnx, { pub fn bytes_of(module: &Module, log_base2k: usize, log_q: usize) -> usize { Elem::::bytes_of(module, log_base2k, log_q, 1)