From 46c577409eecd89c637d6523b63a31998efbc3d7 Mon Sep 17 00:00:00 2001
From: Jean-Philippe Bossuat
Date: Mon, 17 Mar 2025 12:07:40 +0100
Subject: [PATCH] Various improvements to memory management and API

[module]: added enum for backend
[VecZnx, VecZnxDft, VecZnxBig, VmpPMat]: added ptr to data
[VecZnxBorrow]: removed
[VecZnxAPI]: removed
---
 .vscode/settings.json                    |   3 +-
 Cargo.lock                               |   1 +
 base2k/examples/rlwe_encrypt.rs          |   8 +-
 base2k/examples/vector_matrix_product.rs |  14 +-
 base2k/spqlios-arithmetic                |   2 +-
 base2k/src/encoding.rs                   |  51 +--
 base2k/src/free.rs                       |  43 --
 base2k/src/infos.rs                      |  70 ----
 base2k/src/lib.rs                        |  10 -
 base2k/src/module.rs                     |  44 +-
 base2k/src/sampling.rs                   |  18 +-
 base2k/src/svp.rs                        |  79 ++--
 base2k/src/vec_znx.rs                    | 505 +++++++----------------
 base2k/src/vec_znx_big.rs                | 166 +++++---
 base2k/src/vec_znx_dft.rs                | 133 ++++--
 base2k/src/vmp.rs                        | 316 +++++---------
 rlwe/Cargo.toml                          |   1 +
 rlwe/benches/gadget_product.rs           |   7 +-
 rlwe/examples/encryption.rs              |  17 +-
 rlwe/examples/rlk_experiments.rs         | 151 +++++++
 rlwe/src/decryptor.rs                    |  39 +-
 rlwe/src/elem.rs                         |  66 +--
 rlwe/src/encryptor.rs                    |  92 ++---
 rlwe/src/gadget_product.rs               |  26 +-
 rlwe/src/keys.rs                         |   4 +-
 rlwe/src/parameters.rs                   |   5 +-
 rlwe/src/plaintext.rs                    |  72 ++--
 rlwe/src/rgsw_product.rs                 |  17 +-
 28 files changed, 896 insertions(+), 1064 deletions(-)
 delete mode 100644 base2k/src/free.rs
 create mode 100644 rlwe/examples/rlk_experiments.rs

diff --git a/.vscode/settings.json b/.vscode/settings.json
index 6abc417..bf18bf9 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -56,6 +56,7 @@
         "xlocnum": "cpp",
         "xloctime": "cpp",
         "xmemory": "cpp",
-        "xtr1common": "cpp"
+        "xtr1common": "cpp",
+        "vec_znx_arithmetic_private.h": "c"
     }
 }
\ No newline at end of file
diff --git a/Cargo.lock b/Cargo.lock
index c32b92a..99114f6 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -655,6 +655,7 @@ version = "0.1.0"
 dependencies = [
  "base2k",
  "criterion",
+ "itertools 0.14.0",
 "rand_distr",
 "rug",
 "sampling",
diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs
index 9856362..592112f 100644
--- a/base2k/examples/rlwe_encrypt.rs
+++ b/base2k/examples/rlwe_encrypt.rs
@@ -1,6 +1,6 @@
 use base2k::{
-    Encoding, Infos, Module, Sampling, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxApi, VecZnxBig,
-    VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, FFT64,
+    alloc_aligned, Encoding, Infos, Module, Sampling, Scalar, SvpPPol, SvpPPolOps, VecZnx,
+    VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, MODULETYPE,
 };
 use itertools::izip;
 use sampling::source::Source;
@@ -11,9 +11,9 @@ fn main() {
     let cols: usize = 3;
     let msg_cols: usize = 2;
     let log_scale: usize = msg_cols * log_base2k - 5;
-    let module: Module = Module::new::<FFT64>(n);
+    let module: Module = Module::new(n, MODULETYPE::FFT64);
 
-    let mut carry: Vec<u8> = vec![0; module.vec_znx_big_normalize_tmp_bytes()];
+    let mut carry: Vec<u8> = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes());
 
     let seed: [u8; 32] = [0; 32];
     let mut source: Source = Source::new(seed);
diff --git a/base2k/examples/vector_matrix_product.rs b/base2k/examples/vector_matrix_product.rs
index 0eae265..cb2ba58 100644
--- a/base2k/examples/vector_matrix_product.rs
+++ b/base2k/examples/vector_matrix_product.rs
@@ -1,13 +1,13 @@
 use base2k::{
-    Encoding, Free, Infos, Module, VecZnx, VecZnxApi, VecZnxBig, VecZnxBigOps, VecZnxDft,
-    VecZnxDftOps, VecZnxOps, VecZnxVec, VmpPMat, VmpPMatOps, FFT64,
+    alloc_aligned, Encoding, Infos, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft,
+    VecZnxDftOps, VecZnxOps, VecZnxVec,
VmpPMat, VmpPMatOps, MODULETYPE, }; fn main() { let log_n: i32 = 5; let n: usize = 1 << log_n; - let module: Module = Module::new::(n); + let module: Module = Module::new(n, MODULETYPE::FFT64); let log_base2k: usize = 15; let cols: usize = 5; let log_k: usize = log_base2k * cols - 5; @@ -19,7 +19,7 @@ fn main() { let tmp_bytes: usize = module.vmp_prepare_tmp_bytes(rows, cols) | module.vmp_apply_dft_tmp_bytes(cols, cols, rows, cols); - let mut buf: Vec = vec![0; tmp_bytes]; + let mut buf: Vec = alloc_aligned(tmp_bytes); let mut a_values: Vec = vec![i64::default(); n]; a_values[1] = (1 << log_base2k) + 1; @@ -37,7 +37,7 @@ fn main() { }); (0..rows).for_each(|i| { - vecznx[i].data[i * n + 1] = 1 as i64; + vecznx[i].raw_mut()[i * n + 1] = 1 as i64; }); let slices: Vec<&[i64]> = vecznx.dblptr(); @@ -60,8 +60,6 @@ fn main() { res.print(res.cols(), n); module.free(); - c_dft.free(); - vmp_pmat.free(); - //println!("{:?}", values_res) + println!("{:?}", values_res) } diff --git a/base2k/spqlios-arithmetic b/base2k/spqlios-arithmetic index 5461131..52f002b 160000 --- a/base2k/spqlios-arithmetic +++ b/base2k/spqlios-arithmetic @@ -1 +1 @@ -Subproject commit 546113166e0e204cdfcd7a78ed96b6df7c457e40 +Subproject commit 52f002b13c4fbae044d732376f2bc6061289473d diff --git a/base2k/src/encoding.rs b/base2k/src/encoding.rs index fcea61f..3e8a188 100644 --- a/base2k/src/encoding.rs +++ b/base2k/src/encoding.rs @@ -1,5 +1,5 @@ use crate::ffi::znx::znx_zero_i64_ref; -use crate::{VecZnx, VecZnxBorrow, VecZnxCommon}; +use crate::{Infos, VecZnx}; use itertools::izip; use rug::{Assign, Float}; use std::cmp::min; @@ -89,42 +89,7 @@ impl Encoding for VecZnx { } } -impl Encoding for VecZnxBorrow { - fn encode_vec_i64(&mut self, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) { - encode_vec_i64(self, log_base2k, log_k, data, log_max) - } - - fn decode_vec_i64(&self, log_base2k: usize, log_k: usize, data: &mut [i64]) { - decode_vec_i64(self, log_base2k, log_k, data) - } - - fn decode_vec_float(&self, log_base2k: usize, data: &mut [Float]) { - decode_vec_float(self, log_base2k, data) - } - - fn encode_coeff_i64( - &mut self, - log_base2k: usize, - log_k: usize, - i: usize, - value: i64, - log_max: usize, - ) { - encode_coeff_i64(self, log_base2k, log_k, i, value, log_max) - } - - fn decode_coeff_i64(&self, log_base2k: usize, log_k: usize, i: usize) -> i64 { - decode_coeff_i64(self, log_base2k, log_k, i) - } -} - -fn encode_vec_i64( - a: &mut T, - log_base2k: usize, - log_k: usize, - data: &[i64], - log_max: usize, -) { +fn encode_vec_i64(a: &mut VecZnx, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) { let cols: usize = (log_k + log_base2k - 1) / log_base2k; debug_assert!( @@ -170,7 +135,7 @@ fn encode_vec_i64( } } -fn decode_vec_i64(a: &T, log_base2k: usize, log_k: usize, data: &mut [i64]) { +fn decode_vec_i64(a: &VecZnx, log_base2k: usize, log_k: usize, data: &mut [i64]) { let cols: usize = (log_k + log_base2k - 1) / log_base2k; debug_assert!( data.len() >= a.n(), @@ -194,7 +159,7 @@ fn decode_vec_i64(a: &T, log_base2k: usize, log_k: usize, data: }) } -fn decode_vec_float(a: &T, log_base2k: usize, data: &mut [Float]) { +fn decode_vec_float(a: &VecZnx, log_base2k: usize, data: &mut [Float]) { let cols: usize = a.cols(); debug_assert!( data.len() >= a.n(), @@ -224,8 +189,8 @@ fn decode_vec_float(a: &T, log_base2k: usize, data: &mut [Float }); } -fn encode_coeff_i64( - a: &mut T, +fn encode_coeff_i64( + a: &mut VecZnx, log_base2k: usize, log_k: usize, i: usize, @@ -247,7 +212,7 @@ fn 
encode_coeff_i64( // values on the last limb. // Else we decompose values base2k. if log_max + log_k_rem < 63 || log_k_rem == log_base2k { - a.at_mut(cols-1)[i] = value; + a.at_mut(cols - 1)[i] = value; } else { let mask: i64 = (1 << log_base2k) - 1; let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k); @@ -268,7 +233,7 @@ fn encode_coeff_i64( } } -fn decode_coeff_i64(a: &T, log_base2k: usize, log_k: usize, i: usize) -> i64 { +fn decode_coeff_i64(a: &VecZnx, log_base2k: usize, log_k: usize, i: usize) -> i64 { let cols: usize = (log_k + log_base2k - 1) / log_base2k; debug_assert!(i < a.n()); let data: &[i64] = a.raw(); diff --git a/base2k/src/free.rs b/base2k/src/free.rs deleted file mode 100644 index 3ba787a..0000000 --- a/base2k/src/free.rs +++ /dev/null @@ -1,43 +0,0 @@ -use crate::ffi::svp; -use crate::ffi::vec_znx_big; -use crate::ffi::vec_znx_dft; -use crate::ffi::vmp; -use crate::{SvpPPol, VecZnxBig, VecZnxDft, VmpPMat}; - -/// This trait should be implemented by structs that point to -/// memory allocated through C. -pub trait Free { - // Frees the memory and self destructs. - fn free(self); -} - -impl Free for VmpPMat { - /// Frees the C allocated memory of the [VmpPMat] and self destructs the struct. - fn free(self) { - unsafe { vmp::delete_vmp_pmat(self.data) }; - drop(self); - } -} - -impl Free for VecZnxDft { - fn free(self) { - unsafe { vec_znx_dft::delete_vec_znx_dft(self.0) }; - drop(self); - } -} - -impl Free for VecZnxBig { - fn free(self) { - unsafe { - vec_znx_big::delete_vec_znx_big(self.0); - } - drop(self); - } -} - -impl Free for SvpPPol { - fn free(self) { - unsafe { svp::delete_svp_ppol(self.0) }; - let _ = drop(self); - } -} diff --git a/base2k/src/infos.rs b/base2k/src/infos.rs index 4714f01..6898c94 100644 --- a/base2k/src/infos.rs +++ b/base2k/src/infos.rs @@ -1,5 +1,3 @@ -use crate::{VecZnx, VecZnxBorrow, VmpPMat}; - pub trait Infos { /// Returns the ring degree of the receiver. fn n(&self) -> usize; @@ -14,71 +12,3 @@ pub trait Infos { /// Returns the number of rows of the receiver. fn rows(&self) -> usize; } - -impl Infos for VecZnx { - /// Returns the base 2 logarithm of the [VecZnx] degree. - fn log_n(&self) -> usize { - (usize::BITS - (self.n - 1).leading_zeros()) as _ - } - - /// Returns the [VecZnx] degree. - fn n(&self) -> usize { - self.n - } - - /// Returns the number of cols of the [VecZnx]. - fn cols(&self) -> usize { - self.data.len() / self.n - } - - /// Returns the number of rows of the [VecZnx]. - fn rows(&self) -> usize { - 1 - } -} - -impl Infos for VecZnxBorrow { - /// Returns the base 2 logarithm of the [VecZnx] degree. - fn log_n(&self) -> usize { - (usize::BITS - (self.n - 1).leading_zeros()) as _ - } - - /// Returns the [VecZnx] degree. - fn n(&self) -> usize { - self.n - } - - /// Returns the number of cols of the [VecZnx]. - fn cols(&self) -> usize { - self.cols - } - - /// Returns the number of rows of the [VecZnx]. - fn rows(&self) -> usize { - 1 - } -} - -impl Infos for VmpPMat { - /// Returns the ring dimension of the [VmpPMat]. - fn n(&self) -> usize { - self.n - } - - fn log_n(&self) -> usize { - (usize::BITS - (self.n() - 1).leading_zeros()) as _ - } - - /// Returns the number of rows (i.e. of [VecZnxDft]) of the [VmpPMat] - fn rows(&self) -> usize { - self.rows - } - - /// Returns the number of cols of the [VmpPMat]. - /// The number of cols refers to the number of cols - /// of each [VecZnxDft]. - /// This method is equivalent to [Self::cols]. 
- fn cols(&self) -> usize { - self.cols - } -} diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index 28888fb..8679bb2 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -8,7 +8,6 @@ pub mod encoding; )] // Other modules and exports pub mod ffi; -pub mod free; pub mod infos; pub mod module; pub mod sampling; @@ -20,7 +19,6 @@ pub mod vec_znx_dft; pub mod vmp; pub use encoding::*; -pub use free::*; pub use infos::*; pub use module::*; pub use sampling::*; @@ -124,11 +122,3 @@ pub fn alloc_aligned_custom(size: usize, align: usize) -> Vec { pub fn alloc_aligned(size: usize) -> Vec { alloc_aligned_custom::(size, DEFAULTALIGN) } - -fn alias_mut_slice_to_vec(slice: &[T]) -> Vec { - unsafe { - let ptr: *mut T = slice.as_ptr() as *mut T; - let len: usize = slice.len(); - Vec::from_raw_parts(ptr, len, len) - } -} diff --git a/base2k/src/module.rs b/base2k/src/module.rs index 6d1ce47..7a16147 100644 --- a/base2k/src/module.rs +++ b/base2k/src/module.rs @@ -1,26 +1,46 @@ use crate::ffi::module::{delete_module_info, module_info_t, new_module_info, MODULE}; -use crate::{Free, GALOISGENERATOR}; +use crate::GALOISGENERATOR; -pub type MODULETYPE = u8; -pub const FFT64: u8 = 0; -pub const NTT120: u8 = 1; +#[derive(Copy, Clone)] +#[repr(u8)] +pub enum MODULETYPE { + FFT64, + NTT120, +} -pub struct Module(pub *mut MODULE, pub usize); +pub struct Module { + pub ptr: *mut MODULE, + pub n: usize, + pub backend: MODULETYPE, +} impl Module { // Instantiates a new module. - pub fn new(n: usize) -> Self { + pub fn new(n: usize, module_type: MODULETYPE) -> Self { unsafe { - let m: *mut module_info_t = new_module_info(n as u64, MODULETYPE as u32); + let module_type_u32: u32; + match module_type { + MODULETYPE::FFT64 => module_type_u32 = 0, + MODULETYPE::NTT120 => module_type_u32 = 1, + } + let m: *mut module_info_t = new_module_info(n as u64, module_type_u32); if m.is_null() { panic!("Failed to create module."); } - Self(m, n) + Self { + ptr: m, + n: n, + backend: module_type, + } } } + pub fn backend(&self) -> MODULETYPE { + self.backend + } + pub fn n(&self) -> usize { - self.1 + self.n } pub fn log_n(&self) -> usize { @@ -53,11 +73,9 @@ impl Module { (gal_el as i64) * gen.signum() } -} -impl Free for Module { - fn free(self) { - unsafe { delete_module_info(self.0) } + pub fn free(self) { + unsafe { delete_module_info(self.ptr) } drop(self); } } diff --git a/base2k/src/sampling.rs b/base2k/src/sampling.rs index fb94930..416d3a6 100644 --- a/base2k/src/sampling.rs +++ b/base2k/src/sampling.rs @@ -1,16 +1,16 @@ -use crate::{Infos, Module, VecZnxApi}; +use crate::{Infos, Module, VecZnx}; use rand_distr::{Distribution, Normal}; use sampling::source::Source; -pub trait Sampling { +pub trait Sampling { /// Fills the first `cols` cols with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\] - fn fill_uniform(&self, log_base2k: usize, a: &mut T, cols: usize, source: &mut Source); + fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, cols: usize, source: &mut Source); /// Adds vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\]. 
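[Editor's note: for reference, a minimal sketch of the enum-based constructor introduced in module.rs above. Illustrative only, not part of this commit; it assumes the crate-root re-exports shown in this patch.]

```
use base2k::{Module, MODULETYPE};

fn main() {
    // The backend is now selected with an enum value instead of a turbofish type parameter.
    let module: Module = Module::new(1024, MODULETYPE::FFT64);
    assert_eq!(module.n(), 1024);
    assert_eq!(module.log_n(), 10);
    module.free(); // the Module still wraps C memory and must be freed explicitly
}
```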
    fn add_dist_f64<D: Distribution<f64>>(
        &self,
        log_base2k: usize,
-        a: &mut T,
+        a: &mut VecZnx,
        log_k: usize,
        source: &mut Source,
        dist: D,
@@ -21,7 +21,7 @@ pub trait Sampling {
     fn add_normal(
         &self,
         log_base2k: usize,
-        a: &mut T,
+        a: &mut VecZnx,
         log_k: usize,
         source: &mut Source,
         sigma: f64,
@@ -29,8 +29,8 @@
     );
 }
 
-impl<T: VecZnxApi + Infos> Sampling<T> for Module {
-    fn fill_uniform(&self, log_base2k: usize, a: &mut T, cols: usize, source: &mut Source) {
+impl Sampling for Module {
+    fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, cols: usize, source: &mut Source) {
         let base2k: u64 = 1 << log_base2k;
         let mask: u64 = base2k - 1;
         let base2k_half: i64 = (base2k >> 1) as i64;
@@ -43,7 +43,7 @@ impl Sampling for Module {
     fn add_dist_f64<D: Distribution<f64>>(
         &self,
         log_base2k: usize,
-        a: &mut T,
+        a: &mut VecZnx,
         log_k: usize,
         source: &mut Source,
         dist: D,
@@ -79,7 +79,7 @@
     fn add_normal(
         &self,
         log_base2k: usize,
-        a: &mut T,
+        a: &mut VecZnx,
         log_k: usize,
         source: &mut Source,
         sigma: f64,
diff --git a/base2k/src/svp.rs b/base2k/src/svp.rs
index bfb0c26..8f62195 100644
--- a/base2k/src/svp.rs
+++ b/base2k/src/svp.rs
@@ -1,13 +1,18 @@
 use crate::ffi::svp;
-use crate::{alias_mut_slice_to_vec, assert_alignement, Module, VecZnxApi, VecZnxDft};
+use crate::ffi::vec_znx_dft::vec_znx_dft_t;
+use crate::{assert_alignement, Module, VecZnx, VecZnxDft};
 
-use crate::{alloc_aligned, cast, Infos};
+use crate::{alloc_aligned, cast_mut, Infos};
 use rand::seq::SliceRandom;
 use rand_core::RngCore;
 use rand_distr::{Distribution, WeightedIndex};
 use sampling::source::Source;
 
-pub struct Scalar(pub Vec<i64>);
+pub struct Scalar {
+    pub n: usize,
+    pub data: Vec<i64>,
+    pub ptr: *mut i64,
+}
 
 impl Module {
     pub fn new_scalar(&self) -> Scalar {
@@ -17,52 +22,70 @@ impl Module {
 impl Scalar {
     pub fn new(n: usize) -> Self {
-        Self(alloc_aligned::<i64>(n))
+        let mut data: Vec<i64> = alloc_aligned::<i64>(n);
+        let ptr: *mut i64 = data.as_mut_ptr();
+        Self {
+            n: n,
+            data: data,
+            ptr: ptr,
+        }
     }
 
     pub fn n(&self) -> usize {
-        self.0.len()
+        self.n
     }
 
     pub fn buffer_size(n: usize) -> usize {
         n
     }
 
-    pub fn from_buffer(&mut self, n: usize, buf: &mut [u8]) {
+    pub fn from_buffer(&mut self, n: usize, bytes: &mut [u8]) -> Self {
         let size: usize = Self::buffer_size(n);
         debug_assert!(
-            buf.len() >= size,
-            "invalid buffer: buf.len()={} < self.buffer_size(n={})={}",
-            buf.len(),
+            bytes.len() == size,
+            "invalid buffer: bytes.len()={} != self.buffer_size(n={})={}",
+            bytes.len(),
             n,
             size
         );
         #[cfg(debug_assertions)]
         {
-            assert_alignement(buf.as_ptr())
+            assert_alignement(bytes.as_ptr())
+        }
+        unsafe {
+            let bytes_i64: &mut [i64] = cast_mut(bytes);
+            let ptr: *mut i64 = bytes_i64.as_mut_ptr();
+            Self {
+                n: n,
+                data: Vec::from_raw_parts(bytes_i64.as_mut_ptr(), bytes_i64.len(), bytes_i64.len()),
+                ptr: ptr,
+            }
         }
-        self.0 = alias_mut_slice_to_vec(cast(&buf[..size]))
     }
 
     pub fn as_ptr(&self) -> *const i64 {
-        self.0.as_ptr()
+        self.ptr
+    }
+
+    pub fn raw(&self) -> &[i64] {
+        unsafe { std::slice::from_raw_parts(self.ptr, self.n) }
     }
 
     pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) {
         let choices: [i64; 3] = [-1, 0, 1];
         let weights: [f64; 3] = [prob / 2.0, 1.0 - prob, prob / 2.0];
         let dist: WeightedIndex<f64> = WeightedIndex::new(&weights).unwrap();
-        self.0
+        self.data
             .iter_mut()
             .for_each(|x: &mut i64| *x = choices[dist.sample(source)]);
     }
 
     pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) {
         assert!(hw <= self.n());
-        self.0[..hw]
+        self.data[..hw]
            .iter_mut()
            .for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1);
-        self.0.shuffle(source);
+        self.data.shuffle(source);
     }
 }
 
@@ -105,35 +128,23 @@ pub trait SvpPPolOps {
 
     /// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of
     /// the [VecZnxDft] is multiplied with [SvpPPol].
-    fn svp_apply_dft<T: VecZnxApi>(
-        &self,
-        c: &mut VecZnxDft,
-        a: &SvpPPol,
-        b: &T,
-        b_cols: usize,
-    );
+    fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx, b_cols: usize);
 }
 
 impl SvpPPolOps for Module {
     fn new_svp_ppol(&self) -> SvpPPol {
-        unsafe { SvpPPol(svp::new_svp_ppol(self.0), self.n()) }
+        unsafe { SvpPPol(svp::new_svp_ppol(self.ptr), self.n()) }
     }
 
     fn bytes_of_svp_ppol(&self) -> usize {
-        unsafe { svp::bytes_of_svp_ppol(self.0) as usize }
+        unsafe { svp::bytes_of_svp_ppol(self.ptr) as usize }
     }
 
     fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar) {
-        unsafe { svp::svp_prepare(self.0, svp_ppol.0, a.as_ptr()) }
+        unsafe { svp::svp_prepare(self.ptr, svp_ppol.0, a.as_ptr()) }
     }
 
-    fn svp_apply_dft<T: VecZnxApi>(
-        &self,
-        c: &mut VecZnxDft,
-        a: &SvpPPol,
-        b: &T,
-        b_cols: usize,
-    ) {
+    fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx, b_cols: usize) {
         debug_assert!(
             c.cols() >= b_cols,
             "invalid c_vector: c_vector.cols()={} < b.cols()={}",
             c.cols(),
             b_cols
         );
         unsafe {
             svp::svp_apply_dft(
-                self.0,
-                c.0,
+                self.ptr,
+                c.ptr as *mut vec_znx_dft_t,
                 b_cols as u64,
                 a.0,
                 b.as_ptr(),
diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs
index 6daa95e..aeb64f6 100644
--- a/base2k/src/vec_znx.rs
+++ b/base2k/src/vec_znx.rs
@@ -1,18 +1,35 @@
 use crate::cast_mut;
 use crate::ffi::vec_znx;
 use crate::ffi::znx;
-use crate::ffi::znx::znx_zero_i64_ref;
-use crate::{alias_mut_slice_to_vec, alloc_aligned, assert_alignement};
+use crate::{alloc_aligned, assert_alignement};
 use crate::{Infos, Module};
 use itertools::izip;
 use std::cmp::min;
 
+/// [VecZnx] represents a vector of small norm polynomials of Zn\[X\] with [i64] coefficients.
+/// A [VecZnx] is composed of multiple Zn\[X\] polynomials stored in a single contiguous array
+/// in the memory.
+#[derive(Clone)]
+pub struct VecZnx {
+    /// Polynomial degree.
+    n: usize,
+
+    /// Number of columns.
+    cols: usize,
+
+    /// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n.
+    data: Vec<i64>,
+
+    /// Pointer to data (data can be empty if [VecZnx] borrows space instead of owning it).
+    ptr: *mut i64,
+}
+
 pub trait VecZnxVec {
     fn dblptr(&self) -> Vec<&[i64]>;
     fn dblptr_mut(&mut self) -> Vec<&mut [i64]>;
 }
 
-impl VecZnxVec for Vec<T> {
+impl VecZnxVec for Vec<VecZnx> {
     fn dblptr(&self) -> Vec<&[i64]> {
         self.iter().map(|v| v.raw()).collect()
     }
@@ -22,328 +39,141 @@
 }
 
-pub trait VecZnxApi: AsRef + AsMut {
-    type Owned: VecZnxCommon;
-
-    fn from_bytes(n: usize, cols: usize, bytes: &mut [u8]) -> Self::Owned;
-
-    /// Returns the minimum size of the [u8] array required to assign a
-    /// new backend array.
-    fn bytes_of(n: usize, cols: usize) -> usize;
-
-    /// Copy the data of a onto self.
-    fn copy_from(&mut self, a: &A)
-    where
-        Self: AsMut;
-
-    /// Returns the backing array.
-    fn raw(&self) -> &[i64];
-
-    /// Returns the mutable backing array.
-    fn raw_mut(&mut self) -> &mut [i64];
-
-    /// Returns a non-mutable pointer to the backing array.
-    fn as_ptr(&self) -> *const i64;
-
-    /// Returns a mutable pointer to the backing array.
-    fn as_mut_ptr(&mut self) -> *mut i64;
-
-    /// Returns a non-mutable reference to the i-th col.
-    fn at(&self, i: usize) -> &[i64];
-
-    /// Returns a mutable reference to the i-th col.
- fn at_mut(&mut self, i: usize) -> &mut [i64]; - - /// Returns a non-mutable pointer to the i-th cols. - fn at_ptr(&self, i: usize) -> *const i64; - - /// Returns a mutable pointer to the i-th cols. - fn at_mut_ptr(&mut self, i: usize) -> *mut i64; - - /// Zeroes the backing array. - fn zero(&mut self); - - /// Normalization: propagates carry and ensures each coefficients - /// falls into the range [-2^{K-1}, 2^{K-1}]. - fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]); - - /// Right shifts the coefficients by k bits. - /// - /// # Arguments - /// - /// * `log_base2k`: the base two logarithm of the coefficients decomposition. - /// * `k`: the shift amount. - /// * `carry`: scratch space of size at least equal to self.n() * self.cols() << 3. - /// - /// # Panics - /// - /// The method will panic if carry.len() < self.n() * self.cols() << 3. - fn rsh(&mut self, log_base2k: usize, k: usize, carry: &mut [u8]); - - /// If self.n() > a.n(): Extracts X^{i*self.n()/a.n()} -> X^{i}. - /// If self.n() < a.n(): Extracts X^{i} -> X^{i*a.n()/self.n()}. - /// - /// # Arguments - /// - /// * `a`: the receiver polynomial in which the extracted coefficients are stored. - fn switch_degree(&self, a: &mut A) - where - Self: AsRef; - - fn print(&self, cols: usize, n: usize); -} - pub fn bytes_of_vec_znx(n: usize, cols: usize) -> usize { n * cols * 8 } -pub struct VecZnxBorrow { - pub n: usize, - pub cols: usize, - pub data: *mut i64, -} - -impl AsMut for VecZnxBorrow { - fn as_mut(&mut self) -> &mut VecZnxBorrow { - self - } -} - -impl AsRef for VecZnxBorrow { - fn as_ref(&self) -> &VecZnxBorrow { - self - } -} - -impl VecZnxCommon for VecZnxBorrow {} - -impl VecZnxApi for VecZnxBorrow { - type Owned = VecZnxBorrow; - - /// Returns a new struct implementing [VecZnxBorrow] with the provided data as backing array. - /// - /// The struct will *NOT* take ownership of buf[..[VecZnx::bytes_of]] - /// - /// User must ensure that data is properly alligned and that - /// the size of data is at least equal to [VecZnx::bytes_of]. 
- fn from_bytes(n: usize, cols: usize, bytes: &mut [u8]) -> Self::Owned { - let size = Self::bytes_of(n, cols); - debug_assert!( - bytes.len() >= size, - "invalid buffer: buf.len()={} < self.buffer_size(n={}, cols={})={}", - bytes.len(), - n, - cols, - size - ); - #[cfg(debug_assertions)] - { - assert_alignement(bytes.as_ptr()) - } - VecZnxBorrow { - n: n, - cols: cols, - data: cast_mut(&mut bytes[..size]).as_mut_ptr(), - } - } - - fn bytes_of(n: usize, cols: usize) -> usize { - bytes_of_vec_znx(n, cols) - } - - fn copy_from(&mut self, a: &A) - where - Self: AsMut, - { - copy_vec_znx_from::(self.as_mut(), a); - } - - fn as_ptr(&self) -> *const i64 { - self.data - } - - fn as_mut_ptr(&mut self) -> *mut i64 { - self.data - } - - fn raw(&self) -> &[i64] { - unsafe { std::slice::from_raw_parts(self.data, self.n * self.cols) } - } - - fn raw_mut(&mut self) -> &mut [i64] { - unsafe { std::slice::from_raw_parts_mut(self.data, self.n * self.cols) } - } - - fn at(&self, i: usize) -> &[i64] { - let n: usize = self.n(); - &self.raw()[n * i..n * (i + 1)] - } - - fn at_mut(&mut self, i: usize) -> &mut [i64] { - let n: usize = self.n(); - &mut self.raw_mut()[n * i..n * (i + 1)] - } - - fn at_ptr(&self, i: usize) -> *const i64 { - self.data.wrapping_add(self.n * i) - } - - fn at_mut_ptr(&mut self, i: usize) -> *mut i64 { - self.data.wrapping_add(self.n * i) - } - - fn zero(&mut self) { - unsafe { - znx_zero_i64_ref((self.n * self.cols) as u64, self.data); - } - } - - fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) { - normalize(log_base2k, self, carry) - } - - fn rsh(&mut self, log_base2k: usize, k: usize, carry: &mut [u8]) { - rsh(log_base2k, self, k, carry) - } - - fn switch_degree(&self, a: &mut A) - where - Self: AsRef, - { - switch_degree(a, self.as_ref()); - } - - fn print(&self, cols: usize, n: usize) { - (0..cols).for_each(|i| println!("{}: {:?}", i, &self.at(i)[..n])) - } -} - -impl VecZnxCommon for VecZnx {} - -impl VecZnxApi for VecZnx { - type Owned = VecZnx; - +impl VecZnx { /// Returns a new struct implementing [VecZnx] with the provided data as backing array. /// /// The struct will take ownership of buf[..[VecZnx::bytes_of]] /// /// User must ensure that data is properly alligned and that /// the size of data is at least equal to [VecZnx::bytes_of]. 
-    fn from_bytes(n: usize, cols: usize, bytes: &mut [u8]) -> Self::Owned {
-        let size = Self::bytes_of(n, cols);
-        debug_assert!(
-            bytes.len() >= size,
-            "invalid bytes: bytes.len()={} < self.bytes_of(n={}, cols={})={}",
-            bytes.len(),
-            n,
-            cols,
-            size
-        );
+    pub fn from_bytes(n: usize, cols: usize, bytes: &mut [u8]) -> Self {
         #[cfg(debug_assertions)]
         {
-            assert_alignement(bytes.as_ptr())
+            assert_eq!(bytes.len(), Self::bytes_of(n, cols));
+            assert_alignement(bytes.as_ptr());
+        }
+        unsafe {
+            let bytes_i64: &mut [i64] = cast_mut(bytes);
+            let ptr: *mut i64 = bytes_i64.as_mut_ptr();
+            VecZnx {
+                n: n,
+                cols: cols,
+                data: Vec::from_raw_parts(bytes_i64.as_mut_ptr(), bytes_i64.len(), bytes_i64.len()),
+                ptr: ptr,
+            }
+        }
+    }
+
+    pub fn from_bytes_borrow(n: usize, cols: usize, bytes: &mut [u8]) -> Self {
+        #[cfg(debug_assertions)]
+        {
+            assert!(bytes.len() >= Self::bytes_of(n, cols));
+            assert_alignement(bytes.as_ptr());
         }
         VecZnx {
             n: n,
-            data: alias_mut_slice_to_vec(cast_mut(&mut bytes[..size])),
+            cols: cols,
+            data: Vec::new(),
+            ptr: bytes.as_mut_ptr() as *mut i64,
         }
     }
 
-    fn bytes_of(n: usize, cols: usize) -> usize {
+    pub fn bytes_of(n: usize, cols: usize) -> usize {
         bytes_of_vec_znx(n, cols)
     }
 
-    fn copy_from(&mut self, a: &A)
-    where
-        Self: AsMut,
-    {
-        copy_vec_znx_from(self.as_mut(), a);
+    pub fn copy_from(&mut self, a: &VecZnx) {
+        copy_vec_znx_from(self, a);
     }
 
-    fn raw(&self) -> &[i64] {
-        &self.data
+    pub fn raw(&self) -> &[i64] {
+        unsafe { std::slice::from_raw_parts(self.ptr, self.n * self.cols) }
     }
 
-    fn raw_mut(&mut self) -> &mut [i64] {
-        &mut self.data
+    pub fn borrowing(&self) -> bool {
+        self.data.len() == 0
     }
 
-    fn as_ptr(&self) -> *const i64 {
-        self.data.as_ptr()
+    pub fn raw_mut(&mut self) -> &mut [i64] {
+        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.n * self.cols) }
     }
 
-    fn as_mut_ptr(&mut self) -> *mut i64 {
-        self.data.as_mut_ptr()
+    pub fn as_ptr(&self) -> *const i64 {
+        self.ptr
    }
 
+    pub fn as_mut_ptr(&mut self) -> *mut i64 {
+        self.ptr
+    }
+
+    pub fn at(&self, i: usize) -> &[i64] {
        let n: usize = self.n();
        &self.raw()[n * i..n * (i + 1)]
    }
 
-    fn at(&self, i: usize) -> &[i64] {
+    pub fn at_mut(&mut self, i: usize) -> &mut [i64] {
         let n: usize = self.n();
-        &self.raw()[n * i..n * (i + 1)]
-    }
-
-    fn at_mut(&mut self, i: usize) -> &mut [i64] {
-        let n: usize = self.n();
         &mut self.raw_mut()[n * i..n * (i + 1)]
     }
 
-    fn at_ptr(&self, i: usize) -> *const i64 {
-        &self.data[i * self.n] as *const i64
+    pub fn at_ptr(&self, i: usize) -> *const i64 {
+        self.ptr.wrapping_add(i * self.n)
     }
 
-    fn at_mut_ptr(&mut self, i: usize) -> *mut i64 {
-        &mut self.data[i * self.n] as *mut i64
+    pub fn at_mut_ptr(&mut self, i: usize) -> *mut i64 {
+        self.ptr.wrapping_add(i * self.n)
     }
 
-    fn zero(&mut self) {
-        unsafe { znx::znx_zero_i64_ref(self.data.len() as u64, self.data.as_mut_ptr()) }
+    pub fn zero(&mut self) {
+        unsafe { znx::znx_zero_i64_ref((self.n * self.cols) as u64, self.ptr) }
     }
 
-    fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) {
+    pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) {
         normalize(log_base2k, self, carry)
     }
 
-    fn rsh(&mut self, log_base2k: usize, k: usize, carry: &mut [u8]) {
+    pub fn rsh(&mut self, log_base2k: usize, k: usize, carry: &mut [u8]) {
         rsh(log_base2k, self, k, carry)
     }
 
-    fn switch_degree(&self, a: &mut A)
-    where
-        Self: AsRef,
-    {
-        switch_degree(a, self.as_ref())
+    pub fn switch_degree(&self, a: &mut VecZnx) {
+        switch_degree(a, self)
     }
 
-    fn print(&self, cols: usize, n: usize) {
+    pub fn print(&self, cols: usize, n: usize) {
         (0..cols).for_each(|i| println!("{}: {:?}", i, &self.at(i)[..n]))
} } -/// [VecZnx] represents a vector of small norm polynomials of Zn\[X\] with [i64] coefficients. -/// A [VecZnx] is composed of multiple Zn\[X\] polynomials stored in a single contiguous array -/// in the memory. -#[derive(Clone)] -pub struct VecZnx { - /// Polynomial degree. - pub n: usize, - /// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n. - pub data: Vec, -} - -impl AsMut for VecZnx { - fn as_mut(&mut self) -> &mut VecZnx { - self +impl Infos for VecZnx { + /// Returns the base 2 logarithm of the [VecZnx] degree. + fn log_n(&self) -> usize { + (usize::BITS - (self.n - 1).leading_zeros()) as _ } -} -impl AsRef for VecZnx { - fn as_ref(&self) -> &VecZnx { - self + /// Returns the [VecZnx] degree. + fn n(&self) -> usize { + self.n + } + + /// Returns the number of cols of the [VecZnx]. + fn cols(&self) -> usize { + self.cols + } + + /// Returns the number of rows of the [VecZnx]. + fn rows(&self) -> usize { + 1 } } /// Copies the coefficients of `a` on the receiver. /// Copy is done with the minimum size matching both backing arrays. -pub fn copy_vec_znx_from(b: &mut B, a: &A) { +pub fn copy_vec_znx_from(b: &mut VecZnx, a: &VecZnx) { let data_a: &[i64] = a.raw(); let data_b: &mut [i64] = b.raw_mut(); let size = min(data_b.len(), data_a.len()); @@ -353,9 +183,13 @@ pub fn copy_vec_znx_from(b: &mut B, a: &A) { impl VecZnx { /// Allocates a new [VecZnx] composed of #cols polynomials of Z\[X\]. pub fn new(n: usize, cols: usize) -> Self { + let mut data: Vec = alloc_aligned::(n * cols); + let ptr: *mut i64 = data.as_mut_ptr(); Self { n: n, - data: alloc_aligned::(n * cols), + cols: cols, + data: data, + ptr: ptr, } } @@ -370,8 +204,12 @@ impl VecZnx { return; } - self.data - .truncate((self.cols() - k / log_base2k) * self.n()); + if !self.borrowing() { + self.data + .truncate((self.cols() - k / log_base2k) * self.n()); + } + + self.cols -= k / log_base2k; let k_rem: usize = k % log_base2k; @@ -384,7 +222,7 @@ impl VecZnx { } } -pub fn switch_degree(b: &mut B, a: &A) { +pub fn switch_degree(b: &mut VecZnx, a: &VecZnx) { let (n_in, n_out) = (a.n(), b.n()); let (gap_in, gap_out): (usize, usize); @@ -406,7 +244,7 @@ pub fn switch_degree(b: &mut B, a: &A) { }); } -fn normalize(log_base2k: usize, a: &mut T, tmp_bytes: &mut [u8]) { +fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) { let n: usize = a.n(); debug_assert!( @@ -437,7 +275,7 @@ fn normalize(log_base2k: usize, a: &mut T, tmp_bytes: &mut [u8] } } -pub fn rsh(log_base2k: usize, a: &mut T, k: usize, tmp_bytes: &mut [u8]) { +pub fn rsh(log_base2k: usize, a: &mut VecZnx, k: usize, tmp_bytes: &mut [u8]) { let n: usize = a.n(); debug_assert!( @@ -469,26 +307,23 @@ pub fn rsh(log_base2k: usize, a: &mut T, k: usize, tmp_bytes: & znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr()); } - let mask: i64 = (1 << k_rem) - 1; let log_base2k: usize = log_base2k; (cols_steps..cols).for_each(|i| { izip!(carry_i64.iter_mut(), a.at_mut(i).iter_mut()).for_each(|(ci, xi)| { *xi += *ci << log_base2k; *ci = get_base_k_carry(*xi, k_rem); - *xi = (*xi-*ci)>>k_rem; + *xi = (*xi - *ci) >> k_rem; }); }) } } #[inline(always)] -fn get_base_k_carry(x: i64, k: usize) -> i64{ - (x<<64-k) >> (64-k) +fn get_base_k_carry(x: i64, k: usize) -> i64 { + (x << 64 - k) >> (64 - k) } -pub trait VecZnxCommon: VecZnxApi + Infos {} - pub trait VecZnxOps { /// Allocates a new [VecZnx]. /// @@ -504,50 +339,34 @@ pub trait VecZnxOps { fn vec_znx_normalize_tmp_bytes(&self) -> usize; /// c <- a + b. 
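[Editor's note: before the per-signature changes below, a minimal sketch of how the reworked VecZnx now owns its buffer. Illustrative only, not part of this commit; it assumes the accessors introduced in this hunk.]

```
use base2k::{alloc_aligned, Infos, Module, VecZnx, VecZnxOps, MODULETYPE};

fn main() {
    let module: Module = Module::new(16, MODULETYPE::FFT64);
    let mut a: VecZnx = module.new_vec_znx(3); // owns an aligned buffer of 3 cols
    a.raw_mut()[0] = 1 << 20; // oversized coefficient, to be redistributed
    let mut carry: Vec<u8> = alloc_aligned(module.vec_znx_normalize_tmp_bytes());
    a.normalize(17, &mut carry); // propagates carries across the cols
    assert!(!a.borrowing()); // data is owned, not borrowed scratch space
    module.free();
}
```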
-    fn vec_znx_add(
-        &self,
-        c: &mut C,
-        a: &A,
-        b: &B,
-    );
+    fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx);
 
     /// b <- b + a.
-    fn vec_znx_add_inplace(&self, b: &mut B, a: &A);
+    fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx);
 
     /// c <- a - b.
-    fn vec_znx_sub(
-        &self,
-        c: &mut C,
-        a: &A,
-        b: &B,
-    );
+    fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx);
 
     /// b <- b - a.
-    fn vec_znx_sub_inplace(&self, b: &mut B, a: &A);
+    fn vec_znx_sub_inplace(&self, b: &mut VecZnx, a: &VecZnx);
 
     /// b <- -a.
-    fn vec_znx_negate(&self, b: &mut B, a: &A);
+    fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx);
 
     /// b <- -b.
-    fn vec_znx_negate_inplace(&self, a: &mut A);
+    fn vec_znx_negate_inplace(&self, a: &mut VecZnx);
 
     /// b <- a * X^k (mod X^{n} + 1)
-    fn vec_znx_rotate(&self, k: i64, b: &mut B, a: &A);
+    fn vec_znx_rotate(&self, k: i64, b: &mut VecZnx, a: &VecZnx);
 
     /// a <- a * X^k (mod X^{n} + 1)
-    fn vec_znx_rotate_inplace(&self, k: i64, a: &mut A);
+    fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx);
 
     /// b <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1))
-    fn vec_znx_automorphism(
-        &self,
-        k: i64,
-        b: &mut B,
-        a: &A,
-        a_cols: usize,
-    );
+    fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx, a_cols: usize);
 
     /// a <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1))
-    fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_cols: usize);
+    fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_cols: usize);
 
     /// Splits b into subrings and copies them into a.
     ///
     /// # Panics
     ///
     /// This method requires that all [VecZnx] of b have the same ring degree
     /// and that b.n() * b.len() <= a.n()
-    fn vec_znx_split(
-        &self,
-        b: &mut Vec,
-        a: &A,
-        buf: &mut C,
-    );
+    fn vec_znx_split(&self, b: &mut Vec<VecZnx>, a: &VecZnx, buf: &mut VecZnx);
 
     /// Merges the subrings a into b.
/// @@ -568,7 +382,7 @@ pub trait VecZnxOps { /// /// This method requires that all [VecZnx] of a have the same ring degree /// and that a.n() * a.len() <= b.n() - fn vec_znx_merge(&self, b: &mut B, a: &Vec); + fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec); } impl VecZnxOps for Module { @@ -581,19 +395,14 @@ impl VecZnxOps for Module { } fn vec_znx_normalize_tmp_bytes(&self) -> usize { - unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.0) as usize } + unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize } } // c <- a + b - fn vec_znx_add( - &self, - c: &mut C, - a: &A, - b: &B, - ) { + fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) { unsafe { vec_znx::vec_znx_add( - self.0, + self.ptr, c.as_mut_ptr(), c.cols() as u64, c.n() as u64, @@ -608,10 +417,10 @@ impl VecZnxOps for Module { } // b <- a + b - fn vec_znx_add_inplace(&self, b: &mut B, a: &A) { + fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx) { unsafe { vec_znx::vec_znx_add( - self.0, + self.ptr, b.as_mut_ptr(), b.cols() as u64, b.n() as u64, @@ -626,15 +435,10 @@ impl VecZnxOps for Module { } // c <- a + b - fn vec_znx_sub( - &self, - c: &mut C, - a: &A, - b: &B, - ) { + fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) { unsafe { vec_znx::vec_znx_sub( - self.0, + self.ptr, c.as_mut_ptr(), c.cols() as u64, c.n() as u64, @@ -649,10 +453,10 @@ impl VecZnxOps for Module { } // b <- a + b - fn vec_znx_sub_inplace(&self, b: &mut B, a: &A) { + fn vec_znx_sub_inplace(&self, b: &mut VecZnx, a: &VecZnx) { unsafe { vec_znx::vec_znx_sub( - self.0, + self.ptr, b.as_mut_ptr(), b.cols() as u64, b.n() as u64, @@ -666,10 +470,10 @@ impl VecZnxOps for Module { } } - fn vec_znx_negate(&self, b: &mut B, a: &A) { + fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx) { unsafe { vec_znx::vec_znx_negate( - self.0, + self.ptr, b.as_mut_ptr(), b.cols() as u64, b.n() as u64, @@ -680,10 +484,10 @@ impl VecZnxOps for Module { } } - fn vec_znx_negate_inplace(&self, a: &mut A) { + fn vec_znx_negate_inplace(&self, a: &mut VecZnx) { unsafe { vec_znx::vec_znx_negate( - self.0, + self.ptr, a.as_mut_ptr(), a.cols() as u64, a.n() as u64, @@ -694,10 +498,10 @@ impl VecZnxOps for Module { } } - fn vec_znx_rotate(&self, k: i64, b: &mut B, a: &A) { + fn vec_znx_rotate(&self, k: i64, b: &mut VecZnx, a: &VecZnx) { unsafe { vec_znx::vec_znx_rotate( - self.0, + self.ptr, k, b.as_mut_ptr(), b.cols() as u64, @@ -709,10 +513,10 @@ impl VecZnxOps for Module { } } - fn vec_znx_rotate_inplace(&self, k: i64, a: &mut A) { + fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx) { unsafe { vec_znx::vec_znx_rotate( - self.0, + self.ptr, k, a.as_mut_ptr(), a.cols() as u64, @@ -739,11 +543,11 @@ impl VecZnxOps for Module { /// /// # Example /// ``` - /// use base2k::{Module, FFT64, VecZnx, Encoding, Infos, VecZnxApi, VecZnxOps}; + /// use base2k::{Module, MODULETYPE, VecZnx, Encoding, Infos, VecZnxOps}; /// use itertools::izip; /// /// let n: usize = 8; // polynomial degree - /// let module = Module::new::(n); + /// let module = Module::new(n, MODULETYPE::FFT64); /// let mut a: VecZnx = module.new_vec_znx(2); /// let mut b: VecZnx = module.new_vec_znx(2); /// let mut c: VecZnx = module.new_vec_znx(2); @@ -759,21 +563,15 @@ impl VecZnxOps for Module { /// (1..col.len()).for_each(|i|{ /// col[n-i] = -(i as i64) /// }); - /// izip!(b.data.iter(), c.data.iter()).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); + /// izip!(b.raw().iter(), c.raw().iter()).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); /// ``` - 
fn vec_znx_automorphism( - &self, - k: i64, - b: &mut B, - a: &A, - a_cols: usize, - ) { + fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx, a_cols: usize) { debug_assert_eq!(a.n(), self.n()); debug_assert_eq!(b.n(), self.n()); debug_assert!(a.cols() >= a_cols); unsafe { vec_znx::vec_znx_automorphism( - self.0, + self.ptr, k, b.as_mut_ptr(), b.cols() as u64, @@ -799,11 +597,11 @@ impl VecZnxOps for Module { /// /// # Example /// ``` - /// use base2k::{Module, FFT64, VecZnx, Encoding, Infos, VecZnxApi, VecZnxOps}; + /// use base2k::{Module, MODULETYPE, VecZnx, Encoding, Infos, VecZnxOps}; /// use itertools::izip; /// /// let n: usize = 8; // polynomial degree - /// let module = Module::new::(n); + /// let module = Module::new(n, MODULETYPE::FFT64); /// let mut a: VecZnx = VecZnx::new(n, 2); /// let mut b: VecZnx = VecZnx::new(n, 2); /// @@ -818,14 +616,14 @@ impl VecZnxOps for Module { /// (1..col.len()).for_each(|i|{ /// col[n-i] = -(i as i64) /// }); - /// izip!(a.data.iter(), b.data.iter()).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); + /// izip!(a.raw().iter(), b.raw().iter()).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); /// ``` - fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_cols: usize) { + fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_cols: usize) { debug_assert_eq!(a.n(), self.n()); debug_assert!(a.cols() >= a_cols); unsafe { vec_znx::vec_znx_automorphism( - self.0, + self.ptr, k, a.as_mut_ptr(), a.cols() as u64, @@ -837,12 +635,7 @@ impl VecZnxOps for Module { } } - fn vec_znx_split( - &self, - b: &mut Vec, - a: &A, - buf: &mut C, - ) { + fn vec_znx_split(&self, b: &mut Vec, a: &VecZnx, buf: &mut VecZnx) { let (n_in, n_out) = (a.n(), b[0].n()); debug_assert!( @@ -868,7 +661,7 @@ impl VecZnxOps for Module { }) } - fn vec_znx_merge(&self, b: &mut B, a: &Vec) { + fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec) { let (n_in, n_out) = (b.n(), a[0].n()); debug_assert!( diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index 2e36dd5..90942c7 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,29 +1,69 @@ -use crate::ffi::vec_znx_big; -use crate::ffi::vec_znx_dft; -use crate::{assert_alignement, Infos, Module, VecZnxApi, VecZnxDft}; +use crate::ffi::vec_znx_big::{self, vec_znx_bigcoeff_t}; +use crate::{alloc_aligned, assert_alignement, Infos, Module, VecZnx, VecZnxDft, MODULETYPE}; -pub struct VecZnxBig(pub *mut vec_znx_big::vec_znx_bigcoeff_t, pub usize); +pub struct VecZnxBig { + pub data: Vec, + pub ptr: *mut u8, + pub n: usize, + pub cols: usize, + pub backend: MODULETYPE, +} impl VecZnxBig { /// Returns a new [VecZnxBig] with the provided data as backing array. /// User must ensure that data is properly alligned and that /// the size of data is at least equal to [Module::bytes_of_vec_znx_big]. 
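[Editor's note: an illustrative sketch of the borrowing constructor added here for VecZnxBig. Not part of the commit; it assumes the sizes reported by Module::bytes_of_vec_znx_big as above.]

```
use base2k::{alloc_aligned, Module, VecZnxBig, VecZnxBigOps, MODULETYPE};

fn main() {
    let module: Module = Module::new(1024, MODULETYPE::FFT64);
    let cols: usize = 4;
    // A temporary VecZnxBig can borrow scratch space instead of allocating:
    let mut tmp_bytes: Vec<u8> = alloc_aligned(module.bytes_of_vec_znx_big(cols));
    let big: VecZnxBig = VecZnxBig::from_bytes_borrow(&module, cols, &mut tmp_bytes);
    assert_eq!(big.cols(), cols);
    // big.data stays empty, so dropping `big` never frees `tmp_bytes`.
    module.free();
}
```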
- pub fn from_bytes(cols: usize, bytes: &mut [u8]) -> VecZnxBig { + pub fn from_bytes(module: &Module, cols: usize, bytes: &mut [u8]) -> Self { #[cfg(debug_assertions)] { assert_alignement(bytes.as_ptr()) }; - VecZnxBig( - bytes.as_mut_ptr() as *mut vec_znx_big::vec_znx_bigcoeff_t, - cols, - ) + unsafe { + Self { + data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()), + ptr: bytes.as_mut_ptr(), + n: module.n(), + cols: cols, + backend: module.backend, + } + } + } + + pub fn from_bytes_borrow(module: &Module, cols: usize, bytes: &mut [u8]) -> Self { + #[cfg(debug_assertions)] + { + assert_eq!(bytes.len(), module.bytes_of_vec_znx_big(cols)); + assert_alignement(bytes.as_ptr()); + } + Self { + data: Vec::new(), + ptr: bytes.as_mut_ptr(), + n: module.n(), + cols: cols, + backend: module.backend, + } } pub fn as_vec_znx_dft(&mut self) -> VecZnxDft { - VecZnxDft(self.0 as *mut vec_znx_dft::vec_znx_dft_t, self.1) + VecZnxDft { + data: Vec::new(), + ptr: self.ptr, + n: self.n, + cols: self.cols, + backend: self.backend, + } } + + pub fn n(&self) -> usize { + self.n + } + pub fn cols(&self) -> usize { - self.1 + self.cols + } + + pub fn backend(&self) -> MODULETYPE { + self.backend } } @@ -47,39 +87,34 @@ pub trait VecZnxBigOps { fn bytes_of_vec_znx_big(&self, cols: usize) -> usize; /// b <- b - a - fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig, a: &T); + fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig, a: &VecZnx); /// c <- b - a - fn vec_znx_big_sub_small_a( - &self, - c: &mut VecZnxBig, - a: &T, - b: &VecZnxBig, - ); + fn vec_znx_big_sub_small_a(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig); /// c <- b + a - fn vec_znx_big_add_small(&self, c: &mut VecZnxBig, a: &T, b: &VecZnxBig); + fn vec_znx_big_add_small(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig); /// b <- b + a - fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &T); + fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx); fn vec_znx_big_normalize_tmp_bytes(&self) -> usize; /// b <- normalize(a) - fn vec_znx_big_normalize( + fn vec_znx_big_normalize( &self, log_base2k: usize, - b: &mut T, + b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8], ); fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize; - fn vec_znx_big_range_normalize_base2k( + fn vec_znx_big_range_normalize_base2k( &self, log_base2k: usize, - res: &mut T, + res: &mut VecZnx, a: &VecZnxBig, a_range_begin: usize, a_range_xend: usize, @@ -94,7 +129,15 @@ pub trait VecZnxBigOps { impl VecZnxBigOps for Module { fn new_vec_znx_big(&self, cols: usize) -> VecZnxBig { - unsafe { VecZnxBig(vec_znx_big::new_vec_znx_big(self.0, cols as u64), cols) } + let mut data: Vec = alloc_aligned::(self.bytes_of_vec_znx_big(cols)); + let ptr: *mut u8 = data.as_mut_ptr(); + VecZnxBig { + data: data, + ptr: ptr, + n: self.n(), + cols: cols, + backend: self.backend(), + } } fn new_vec_znx_big_from_bytes(&self, cols: usize, bytes: &mut [u8]) -> VecZnxBig { @@ -108,55 +151,50 @@ impl VecZnxBigOps for Module { { assert_alignement(bytes.as_ptr()) } - VecZnxBig::from_bytes(cols, bytes) + VecZnxBig::from_bytes(self, cols, bytes) } fn bytes_of_vec_znx_big(&self, cols: usize) -> usize { - unsafe { vec_znx_big::bytes_of_vec_znx_big(self.0, cols as u64) as usize } + unsafe { vec_znx_big::bytes_of_vec_znx_big(self.ptr, cols as u64) as usize } } - fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig, a: &T) { + fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) { unsafe { 
vec_znx_big::vec_znx_big_sub_small_a( - self.0, - b.0, + self.ptr, + b.ptr as *mut vec_znx_bigcoeff_t, b.cols() as u64, a.as_ptr(), a.cols() as u64, a.n() as u64, - b.0, + b.ptr as *mut vec_znx_bigcoeff_t, b.cols() as u64, ) } } - fn vec_znx_big_sub_small_a( - &self, - c: &mut VecZnxBig, - a: &T, - b: &VecZnxBig, - ) { + fn vec_znx_big_sub_small_a(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig) { unsafe { vec_znx_big::vec_znx_big_sub_small_a( - self.0, - c.0, + self.ptr, + c.ptr as *mut vec_znx_bigcoeff_t, c.cols() as u64, a.as_ptr(), a.cols() as u64, a.n() as u64, - b.0, + b.ptr as *mut vec_znx_bigcoeff_t, b.cols() as u64, ) } } - fn vec_znx_big_add_small(&self, c: &mut VecZnxBig, a: &T, b: &VecZnxBig) { + fn vec_znx_big_add_small(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig) { unsafe { vec_znx_big::vec_znx_big_add_small( - self.0, - c.0, + self.ptr, + c.ptr as *mut vec_znx_bigcoeff_t, c.cols() as u64, - b.0, + b.ptr as *mut vec_znx_bigcoeff_t, b.cols() as u64, a.as_ptr(), a.cols() as u64, @@ -165,13 +203,13 @@ impl VecZnxBigOps for Module { } } - fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &T) { + fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) { unsafe { vec_znx_big::vec_znx_big_add_small( - self.0, - b.0, + self.ptr, + b.ptr as *mut vec_znx_bigcoeff_t, b.cols() as u64, - b.0, + b.ptr as *mut vec_znx_bigcoeff_t, b.cols() as u64, a.as_ptr(), a.cols() as u64, @@ -181,13 +219,13 @@ impl VecZnxBigOps for Module { } fn vec_znx_big_normalize_tmp_bytes(&self) -> usize { - unsafe { vec_znx_big::vec_znx_big_normalize_base2k_tmp_bytes(self.0) as usize } + unsafe { vec_znx_big::vec_znx_big_normalize_base2k_tmp_bytes(self.ptr) as usize } } - fn vec_znx_big_normalize( + fn vec_znx_big_normalize( &self, log_base2k: usize, - b: &mut T, + b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8], ) { @@ -203,12 +241,12 @@ impl VecZnxBigOps for Module { } unsafe { vec_znx_big::vec_znx_big_normalize_base2k( - self.0, + self.ptr, log_base2k as u64, b.as_mut_ptr(), b.cols() as u64, b.n() as u64, - a.0, + a.ptr as *mut vec_znx_bigcoeff_t, a.cols() as u64, tmp_bytes.as_mut_ptr(), ) @@ -216,13 +254,13 @@ impl VecZnxBigOps for Module { } fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize { - unsafe { vec_znx_big::vec_znx_big_range_normalize_base2k_tmp_bytes(self.0) as usize } + unsafe { vec_znx_big::vec_znx_big_range_normalize_base2k_tmp_bytes(self.ptr) as usize } } - fn vec_znx_big_range_normalize_base2k( + fn vec_znx_big_range_normalize_base2k( &self, log_base2k: usize, - res: &mut T, + res: &mut VecZnx, a: &VecZnxBig, a_range_begin: usize, a_range_xend: usize, @@ -241,12 +279,12 @@ impl VecZnxBigOps for Module { } unsafe { vec_znx_big::vec_znx_big_range_normalize_base2k( - self.0, + self.ptr, log_base2k as u64, res.as_mut_ptr(), res.cols() as u64, res.n() as u64, - a.0, + a.ptr as *mut vec_znx_bigcoeff_t, a_range_begin as u64, a_range_xend as u64, a_range_step as u64, @@ -258,11 +296,11 @@ impl VecZnxBigOps for Module { fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig, a: &VecZnxBig) { unsafe { vec_znx_big::vec_znx_big_automorphism( - self.0, + self.ptr, gal_el, - b.0, + b.ptr as *mut vec_znx_bigcoeff_t, b.cols() as u64, - a.0, + a.ptr as *mut vec_znx_bigcoeff_t, a.cols() as u64, ); } @@ -271,11 +309,11 @@ impl VecZnxBigOps for Module { fn vec_znx_big_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnxBig) { unsafe { vec_znx_big::vec_znx_big_automorphism( - self.0, + self.ptr, gal_el, - a.0, + a.ptr as *mut vec_znx_bigcoeff_t, 
                a.cols() as u64,
-                a.0,
+                a.ptr as *mut vec_znx_bigcoeff_t,
                a.cols() as u64,
            );
        }
    }
 }
diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs
index dfd4370..275b401 100644
--- a/base2k/src/vec_znx_dft.rs
+++ b/base2k/src/vec_znx_dft.rs
@@ -1,33 +1,104 @@
-use crate::ffi::vec_znx_big;
+use crate::ffi::vec_znx_big::vec_znx_bigcoeff_t;
 use crate::ffi::vec_znx_dft;
-use crate::ffi::vec_znx_dft::bytes_of_vec_znx_dft;
-use crate::{assert_alignement, Infos, Module, VecZnxApi, VecZnxBig};
+use crate::ffi::vec_znx_dft::{bytes_of_vec_znx_dft, vec_znx_dft_t};
+use crate::{alloc_aligned, VecZnx};
+use crate::{assert_alignement, Infos, Module, VecZnxBig, MODULETYPE};
 
-pub struct VecZnxDft(pub *mut vec_znx_dft::vec_znx_dft_t, pub usize);
+pub struct VecZnxDft {
+    pub data: Vec<u8>,
+    pub ptr: *mut u8,
+    pub n: usize,
+    pub cols: usize,
+    pub backend: MODULETYPE,
+}
 
 impl VecZnxDft {
     /// Returns a new [VecZnxDft] with the provided data as backing array.
     /// User must ensure that data is properly aligned and that
     /// the size of data is at least equal to [Module::bytes_of_vec_znx_dft].
-    pub fn from_bytes(cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft {
+    pub fn from_bytes(module: &Module, cols: usize, bytes: &mut [u8]) -> VecZnxDft {
         #[cfg(debug_assertions)]
         {
-            assert_alignement(tmp_bytes.as_ptr())
+            assert_eq!(bytes.len(), module.bytes_of_vec_znx_dft(cols));
+            assert_alignement(bytes.as_ptr())
+        }
+        unsafe {
+            VecZnxDft {
+                data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()),
+                ptr: bytes.as_mut_ptr(),
+                n: module.n(),
+                cols: cols,
+                backend: module.backend,
+            }
+        }
+    }
+
+    pub fn from_bytes_borrow(module: &Module, cols: usize, bytes: &mut [u8]) -> VecZnxDft {
+        #[cfg(debug_assertions)]
+        {
+            assert_eq!(bytes.len(), module.bytes_of_vec_znx_dft(cols));
+            assert_alignement(bytes.as_ptr());
+        }
+        VecZnxDft {
+            data: Vec::new(),
+            ptr: bytes.as_mut_ptr(),
+            n: module.n(),
+            cols: cols,
+            backend: module.backend,
         }
-        VecZnxDft(
-            tmp_bytes.as_mut_ptr() as *mut vec_znx_dft::vec_znx_dft_t,
-            cols,
-        )
     }
 
     /// Cast a [VecZnxDft] into a [VecZnxBig].
     /// The returned [VecZnxBig] shares the backing array
     /// with the original [VecZnxDft].
     pub fn as_vec_znx_big(&mut self) -> VecZnxBig {
-        VecZnxBig(self.0 as *mut vec_znx_big::vec_znx_bigcoeff_t, self.1)
+        VecZnxBig {
+            data: Vec::new(),
+            ptr: self.ptr,
+            n: self.n,
+            cols: self.cols,
+            backend: self.backend,
+        }
     }
+
+    pub fn n(&self) -> usize {
+        self.n
+    }
+
     pub fn cols(&self) -> usize {
-        self.1
+        self.cols
+    }
+
+    pub fn backend(&self) -> MODULETYPE {
+        self.backend
+    }
+
+    /// Returns a non-mutable reference of `T` of the entire contiguous array of the [VecZnxDft].
+    /// When using [`MODULETYPE::FFT64`] as backend, `T` should be [f64].
+    /// When using [`MODULETYPE::NTT120`] as backend, `T` should be [i64].
+    /// The length of the returned array is cols * n.
+    pub fn raw<T>(&self, module: &Module) -> &[T] {
+        let ptr: *const T = self.ptr as *const T;
+        let len: usize = (self.cols() * module.n() * 8) / std::mem::size_of::<T>();
+        unsafe { std::slice::from_raw_parts(ptr, len) }
+    }
+
+    pub fn at<T>(&self, module: &Module, col_i: usize) -> &[T] {
+        &self.raw::<T>(module)[col_i * module.n()..(col_i + 1) * module.n()]
+    }
+
+    /// Returns a mutable reference of `T` of the entire contiguous array of the [VecZnxDft].
+    /// When using [`MODULETYPE::FFT64`] as backend, `T` should be [f64].
+    /// When using [`MODULETYPE::NTT120`] as backend, `T` should be [i64].
+    /// The length of the returned array is cols * n.
+    pub fn raw_mut<T>(&self, module: &Module) -> &mut [T] {
+        let ptr: *mut T = self.ptr as *mut T;
+        let len: usize = (self.cols() * module.n() * 8) / std::mem::size_of::<T>();
+        unsafe { std::slice::from_raw_parts_mut(ptr, len) }
+    }
+
+    pub fn at_mut<T>(&self, module: &Module, col_i: usize) -> &mut [T] {
+        &mut self.raw_mut::<T>(module)[col_i * module.n()..(col_i + 1) * module.n()]
     }
 }
 
@@ -72,12 +143,20 @@ pub trait VecZnxDftOps {
         tmp_bytes: &mut [u8],
     );
 
-    fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &T, a_limbs: usize);
+    fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx, a_limbs: usize);
 }
 
 impl VecZnxDftOps for Module {
     fn new_vec_znx_dft(&self, cols: usize) -> VecZnxDft {
-        unsafe { VecZnxDft(vec_znx_dft::new_vec_znx_dft(self.0, cols as u64), cols) }
+        let mut data: Vec<u8> = alloc_aligned::<u8>(self.bytes_of_vec_znx_dft(cols));
+        let ptr: *mut u8 = data.as_mut_ptr();
+        VecZnxDft {
+            data: data,
+            ptr: ptr,
+            n: self.n(),
+            cols: cols,
+            backend: self.backend(),
+        }
     }
 
     fn new_vec_znx_dft_from_bytes(&self, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft {
@@ -91,11 +170,11 @@
         {
             assert_alignement(tmp_bytes.as_ptr())
         }
-        VecZnxDft::from_bytes(cols, tmp_bytes)
+        VecZnxDft::from_bytes(self, cols, tmp_bytes)
     }
 
     fn bytes_of_vec_znx_dft(&self, cols: usize) -> usize {
-        unsafe { bytes_of_vec_znx_dft(self.0, cols as u64) as usize }
+        unsafe { bytes_of_vec_znx_dft(self.ptr, cols as u64) as usize }
     }
 
     fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_limbs: usize) {
         debug_assert!(
             b.cols() >= a_limbs,
             "invalid a_limbs: b.cols()={} < a_limbs={}",
             b.cols(),
             a_limbs
         );
         unsafe {
-            vec_znx_dft::vec_znx_idft_tmp_a(self.0, b.0, b.cols() as u64, a.0, a_limbs as u64)
+            vec_znx_dft::vec_znx_idft_tmp_a(
+                self.ptr,
+                b.ptr as *mut vec_znx_bigcoeff_t,
+                b.cols() as u64,
+                a.ptr as *mut vec_znx_dft_t,
+                a_limbs as u64,
+            )
         }
     }
 
     fn vec_znx_idft_tmp_bytes(&self) -> usize {
-        unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.0) as usize }
+        unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.ptr) as usize }
     }
 
     /// b <- DFT(a)
     ///
     /// # Panics
     /// If b.cols < a_cols
-    fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &T, a_cols: usize) {
+    fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx, a_cols: usize) {
         debug_assert!(
             b.cols() >= a_cols,
             "invalid a_cols: b.cols()={} < a_cols={}",
             b.cols(),
             a_cols
         );
         unsafe {
             vec_znx_dft::vec_znx_dft(
-                self.0,
-                b.0,
+                self.ptr,
+                b.ptr as *mut vec_znx_dft_t,
                 b.cols() as u64,
                 a.as_ptr(),
                 a_cols as u64,
@@ -169,10 +254,10 @@
         }
         unsafe {
             vec_znx_dft::vec_znx_idft(
-                self.0,
-                b.0,
+                self.ptr,
+                b.ptr as *mut vec_znx_bigcoeff_t,
                 a.cols() as u64,
-                a.0,
+                a.ptr as *mut vec_znx_dft_t,
                 a_cols as u64,
                 tmp_bytes.as_mut_ptr(),
             )
diff --git a/base2k/src/vmp.rs b/base2k/src/vmp.rs
index edbc5e7..7195cfb 100644
--- a/base2k/src/vmp.rs
+++ b/base2k/src/vmp.rs
@@ -1,5 +1,6 @@
-use crate::ffi::vmp;
-use crate::{assert_alignement, Infos, Module, VecZnxApi, VecZnxDft};
+use crate::ffi::vec_znx_dft::vec_znx_dft_t;
+use crate::ffi::vmp::{self, vmp_pmat_t};
+use crate::{alloc_aligned, assert_alignement, Infos, Module, VecZnx, VecZnxDft, MODULETYPE};
 
 /// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
 /// stored as a 3D matrix in the DFT domain in a single contiguous array.
@@ -11,20 +12,75 @@
 /// [VmpPMat] is used to perform a vector matrix product between a [VecZnx] and a [VmpPMat].
 /// See the trait [VmpPMatOps] for additional information.
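[Editor's note: an illustrative sketch of reading DFT coefficients through the typed accessors added above. Not part of the commit; it assumes the FFT64 backend, where `T` is f64.]

```
use base2k::{Module, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, MODULETYPE};

fn main() {
    let module: Module = Module::new(64, MODULETYPE::FFT64);
    let a: VecZnx = module.new_vec_znx(2);
    let mut a_dft: VecZnxDft = module.new_vec_znx_dft(2);
    module.vec_znx_dft(&mut a_dft, &a, 2);
    // FFT64 stores f64 coefficients; NTT120 would use i64 instead.
    let col0: &[f64] = a_dft.at::<f64>(&module, 0);
    assert_eq!(col0.len(), module.n());
    module.free();
}
```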
 pub struct VmpPMat {
-    /// The pointer to the C memory.
-    pub data: *mut vmp::vmp_pmat_t,
+    /// Raw data, is empty if borrowing scratch space.
+    data: Vec<u8>,
+    /// Pointer to data. Can point to scratch space.
+    ptr: *mut u8,
     /// The number of [VecZnxDft].
-    pub rows: usize,
+    rows: usize,
     /// The number of cols in each [VecZnxDft].
-    pub cols: usize,
+    cols: usize,
     /// The ring degree of each [VecZnxDft].
-    pub n: usize,
+    n: usize,
+
+    backend: MODULETYPE,
+}
+
+impl Infos for VmpPMat {
+    /// Returns the ring dimension of the [VmpPMat].
+    fn n(&self) -> usize {
+        self.n
+    }
+
+    fn log_n(&self) -> usize {
+        (usize::BITS - (self.n() - 1).leading_zeros()) as _
+    }
+
+    /// Returns the number of rows (i.e. of [VecZnxDft]) of the [VmpPMat].
+    fn rows(&self) -> usize {
+        self.rows
+    }
+
+    /// Returns the number of cols of the [VmpPMat].
+    /// The number of cols refers to the number of cols
+    /// of each [VecZnxDft].
+    fn cols(&self) -> usize {
+        self.cols
+    }
 }
 
 impl VmpPMat {
-    /// Returns the pointer to the [vmp_pmat_t].
-    pub fn data(&self) -> *mut vmp::vmp_pmat_t {
-        self.data
+    pub fn as_ptr(&self) -> *const u8 {
+        self.ptr
+    }
+
+    pub fn as_mut_ptr(&self) -> *mut u8 {
+        self.ptr
+    }
+
+    pub fn borrowed(&self) -> bool {
+        self.data.len() == 0
+    }
+
+    /// Returns a non-mutable reference of `T` of the entire contiguous array of the [VmpPMat].
+    /// When using [`MODULETYPE::FFT64`] as backend, `T` should be [f64].
+    /// When using [`MODULETYPE::NTT120`] as backend, `T` should be [i64].
+    /// The length of the returned array is rows * cols * n.
+    pub fn raw<T>(&self) -> &[T] {
+        let ptr: *const T = self.ptr as *const T;
+        let len: usize = (self.rows() * self.cols() * self.n() * 8) / std::mem::size_of::<T>();
+        unsafe { std::slice::from_raw_parts(ptr, len) }
+    }
+
+    /// Returns a mutable reference of `T` of the entire contiguous array of the [VmpPMat].
+    /// When using [`MODULETYPE::FFT64`] as backend, `T` should be [f64].
+    /// When using [`MODULETYPE::NTT120`] as backend, `T` should be [i64].
+    /// The length of the returned array is rows * cols * n.
+    pub fn raw_mut<T>(&self) -> &mut [T] {
+        let ptr: *mut T = self.ptr as *mut T;
+        let len: usize = (self.rows() * self.cols() * self.n() * 8) / std::mem::size_of::<T>();
+        unsafe { std::slice::from_raw_parts_mut(ptr, len) }
     }
 
     /// Returns a copy of the backend array at index (i, j) of the [VmpPMat].
     ///
     /// # Arguments
     ///
     /// * `row`: row index (i).
     /// * `col`: col index (j).
     pub fn at<T: Default + Copy>(&self, row: usize, col: usize) -> Vec<T> {
-        let mut res: Vec<T> = vec![T::default(); self.n];
+        let mut res: Vec<T> = alloc_aligned(self.n);
 
         if self.n < 8 {
             res.copy_from_slice(
-                &self.get_backend_array::<T>()[(row + col * self.rows()) * self.n()
+                &self.raw::<T>()[(row + col * self.rows()) * self.n()
                     ..(row + col * self.rows() + 1) * self.n()],
             );
         } else {
             (0..self.n >> 3).for_each(|blk| {
-                res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.get_array(row, col, blk)[..8]);
+                res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.at_block(row, col, blk)[..8]);
             });
         }
 
         res
     }
 
     /// When using [`crate::FFT64`] as backend, `T` should be [f64].
     /// When using [`crate::NTT120`] as backend, `T` should be [i64].
-    fn get_array<T>(&self, row: usize, col: usize, blk: usize) -> &[T] {
+    fn at_block<T>(&self, row: usize, col: usize, blk: usize) -> &[T] {
         let nrows: usize = self.rows();
         let ncols: usize = self.cols();
 
         if col == (ncols - 1) && (ncols & 1 == 1) {
-            &self.get_backend_array::<T>()[blk * nrows * ncols * 8 + col * nrows * 8 + row * 8..]
+            &self.raw::<T>()[blk * nrows * ncols * 8 + col * nrows * 8 + row * 8..]
         } else {
-            &self.get_backend_array::<T>()[blk * nrows * ncols * 8
+            &self.raw::<T>()[blk * nrows * ncols * 8
                 + (col / 2) * (2 * nrows) * 8
                 + row * 2 * 8
                 + (col % 2) * 8..]
         }
     }
-
-    /// Returns a non-mutable reference of `T` of the entire contiguous array of the [VmpPMat].
-    /// When using [`crate::FFT64`] as backend, `T` should be [f64].
-    /// When using [`crate::NTT120`] as backend, `T` should be [i64].
-    /// The length of the returned array is rows * cols * n.
-    pub fn get_backend_array<T>(&self) -> &[T] {
-        let ptr: *const T = self.data as *const T;
-        let len: usize = (self.rows() * self.cols() * self.n() * 8) / std::mem::size_of::<T>();
-        unsafe { &std::slice::from_raw_parts(ptr, len) }
-    }
 }
 
 /// This trait implements methods for vector matrix product,
 /// that is, multiplying a [VecZnx] with a [VmpPMat].
 pub trait VmpPMatOps {
+    fn bytes_of_vmp_pmat(&self, rows: usize, cols: usize) -> usize;
+
     /// Allocates a new [VmpPMat] with the given number of rows and columns.
     ///
     /// # Arguments
@@ -106,26 +154,6 @@ pub trait VmpPMatOps {
     /// * `b`: [VmpPMat] on which the values are encoded.
     /// * `a`: the contiguous array of [i64] of the 3D matrix to encode on the [VmpPMat].
     /// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
-    ///
-    /// # Example
-    /// ```
-    /// use base2k::{Module, VmpPMat, VmpPMatOps, FFT64, Free, alloc_aligned};
-    /// use std::cmp::min;
-    ///
-    /// let n: usize = 1024;
-    /// let module = Module::new::<FFT64>(n);
-    /// let rows = 5;
-    /// let cols = 6;
-    ///
-    /// let mut b_mat: Vec<i64> = vec![0i64; n * cols * rows];
-    ///
-    /// let mut buf: Vec<u8> = alloc_aligned(module.vmp_prepare_tmp_bytes(rows, cols));
-    ///
-    /// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
-    /// module.vmp_prepare_contiguous(&mut vmp_pmat, &b_mat, &mut buf);
-    ///
-    /// vmp_pmat.free() // don't forget to free the memory once vmp_pmat is not needed anymore.
-    /// ```
     fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], buf: &mut [u8]);
 
     /// Prepares a [VmpPMat] from a vector of [VecZnx].
@@ -137,32 +165,6 @@ pub trait VmpPMatOps {
     /// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
     ///
     /// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
-    ///
-    /// # Example
-    /// ```
-    /// use base2k::{Module, FFT64, VmpPMat, VmpPMatOps, VecZnx, VecZnxApi, VecZnxOps, Free, alloc_aligned};
-    /// use std::cmp::min;
-    ///
-    /// let n: usize = 1024;
-    /// let module: Module = Module::new::<FFT64>(n);
-    /// let rows: usize = 5;
-    /// let cols: usize = 6;
-    ///
-    /// let mut vecznx: Vec<VecZnx> = Vec::new();
-    /// (0..rows).for_each(|_|{
-    ///     vecznx.push(module.new_vec_znx(cols));
-    /// });
-    ///
-    /// let slices: Vec<&[i64]> = vecznx.iter().map(|v| v.data.as_slice()).collect();
-    ///
-    /// let mut buf: Vec<u8> = alloc_aligned(module.vmp_prepare_tmp_bytes(rows, cols));
-    ///
-    /// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
-    /// module.vmp_prepare_dblptr(&mut vmp_pmat, &slices, &mut buf);
-    ///
-    /// vmp_pmat.free();
-    /// module.free();
-    /// ```
     fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]);
 
     /// Prepares the ith-row of [VmpPMat] from a vector of [VecZnx].
@@ -175,26 +177,6 @@ pub trait VmpPMatOps {
     /// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
     ///
     /// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
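+    ///
+    /// # Example
+    /// A sketch under the reworked constructors (`Module::new(n, MODULETYPE::FFT64)`,
+    /// owned buffers, no `free`), assuming `VecZnx::raw` exposes the coefficient
+    /// slice as in the examples this patch removes.
+    /// ```
+    /// use base2k::{alloc_aligned, Module, VecZnx, VecZnxOps, VmpPMat, VmpPMatOps, MODULETYPE};
+    ///
+    /// let n: usize = 1024;
+    /// let module: Module = Module::new(n, MODULETYPE::FFT64);
+    /// let (rows, cols): (usize, usize) = (5, 6);
+    ///
+    /// let vecznx: VecZnx = module.new_vec_znx(cols);
+    /// let mut buf: Vec<u8> = alloc_aligned(module.vmp_prepare_tmp_bytes(rows, cols));
+    ///
+    /// // Encodes the polynomial vector as row 0 of the prepared matrix.
+    /// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
+    /// module.vmp_prepare_row(&mut vmp_pmat, vecznx.raw(), 0, &mut buf);
+    /// ```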
-    ///
-    /// # Example
-    /// ```
-    /// use base2k::{Module, FFT64, VmpPMat, VmpPMatOps, VecZnx, VecZnxApi, VecZnxOps, Free, alloc_aligned};
-    /// use std::cmp::min;
-    ///
-    /// let n: usize = 1024;
-    /// let module: Module = Module::new::<FFT64>(n);
-    /// let rows: usize = 5;
-    /// let cols: usize = 6;
-    ///
-    /// let vecznx = module.new_vec_znx(cols);
-    ///
-    /// let mut buf: Vec<u8> = alloc_aligned(module.vmp_prepare_tmp_bytes(rows, cols));
-    ///
-    /// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
-    /// module.vmp_prepare_row(&mut vmp_pmat, vecznx.raw(), 0, &mut buf);
-    ///
-    /// vmp_pmat.free();
-    /// module.free();
-    /// ```
     fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]);
 
     /// Returns the size of the scratch space necessary for [VmpPMatOps::vmp_apply_dft].
@@ -237,38 +219,7 @@ pub trait VmpPMatOps {
     /// * `a`: the left operand [VecZnx] of the vector matrix product.
     /// * `b`: the right operand [VmpPMat] of the vector matrix product.
     /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes].
-    ///
-    /// # Example
-    /// ```
-    /// use base2k::{Module, VecZnx, VecZnxOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, FFT64, Free, VecZnxApi, alloc_aligned};
-    ///
-    /// let n = 1024;
-    ///
-    /// let module: Module = Module::new::<FFT64>(n);
-    /// let cols: usize = 5;
-    ///
-    /// let rows: usize = cols;
-    /// let cols: usize = cols + 1;
-    /// let c_cols: usize = cols;
-    /// let a_cols: usize = cols;
-    /// let mut buf: Vec<u8> = alloc_aligned(module.vmp_apply_dft_tmp_bytes(c_cols, a_cols, rows, cols));
-    /// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
-    ///
-    /// let a: VecZnx = module.new_vec_znx(cols);
-    /// let mut c_dft: VecZnxDft = module.new_vec_znx_dft(cols);
-    /// module.vmp_apply_dft(&mut c_dft, &a, &vmp_pmat, &mut buf);
-    ///
-    /// c_dft.free();
-    /// vmp_pmat.free();
-    /// module.free();
-    /// ```
-    fn vmp_apply_dft(
-        &self,
-        c: &mut VecZnxDft,
-        a: &T,
-        b: &VmpPMat,
-        buf: &mut [u8],
-    );
+    fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]);
 
     /// Returns the size of the scratch space necessary for [VmpPMatOps::vmp_apply_dft_to_dft].
     ///
@@ -311,32 +262,6 @@ pub trait VmpPMatOps {
     /// * `a`: the left operand [VecZnxDft] of the vector matrix product.
     /// * `b`: the right operand [VmpPMat] of the vector matrix product.
     /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
-    ///
-    /// # Example
-    /// ```
-    /// use base2k::{Module, VecZnx, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, FFT64, Free, alloc_aligned};
-    ///
-    /// let n = 1024;
-    ///
-    /// let module: Module = Module::new::<FFT64>(n);
-    /// let cols: usize = 5;
-    ///
-    /// let rows: usize = cols;
-    /// let cols: usize = cols + 1;
-    /// let c_cols: usize = cols;
-    /// let a_cols: usize = cols;
-    /// let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, rows, cols));
-    /// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
-    ///
-    /// let a_dft: VecZnxDft = module.new_vec_znx_dft(cols);
-    /// let mut c_dft: VecZnxDft = module.new_vec_znx_dft(cols);
-    /// module.vmp_apply_dft_to_dft(&mut c_dft, &a_dft, &vmp_pmat, &mut tmp_bytes);
-    ///
-    /// a_dft.free();
-    /// c_dft.free();
-    /// vmp_pmat.free();
-    /// module.free();
-    /// ```
     fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]);
 
     /// Applies the vector matrix product [VecZnxDft] x [VmpPMat] in place.
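+    ///
+    /// # Example
+    /// A sketch of the in-place variant under the reworked API; scratch sizing
+    /// reuses [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
+    /// ```
+    /// use base2k::{alloc_aligned, Module, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, MODULETYPE};
+    ///
+    /// let n: usize = 1024;
+    /// let module: Module = Module::new(n, MODULETYPE::FFT64);
+    /// let (rows, cols): (usize, usize) = (5, 6);
+    ///
+    /// let mut tmp_bytes: Vec<u8> =
+    ///     alloc_aligned(module.vmp_apply_dft_to_dft_tmp_bytes(cols, cols, rows, cols));
+    /// let vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
+    ///
+    /// // b_dft is both the input and the output of the product.
+    /// let mut b_dft: VecZnxDft = module.new_vec_znx_dft(cols);
+    /// module.vmp_apply_dft_to_dft_inplace(&mut b_dft, &vmp_pmat, &mut tmp_bytes);
+    /// ```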
@@ -363,46 +288,29 @@ pub trait VmpPMatOps { /// * `b`: the input and output of the vector matrix product, as a [VecZnxDft]. /// * `a`: the right operand [VmpPMat] of the vector matrix product. /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. - /// - /// # Example - /// ```rust - /// use base2k::{Module, VecZnx, VecZnxOps, VecZnxDft, VmpPMat, VmpPMatOps, FFT64, Free, VecZnxApi, VecZnxDftOps,alloc_aligned}; - /// - /// let n = 1024; - /// - /// let module: Module = Module::new::(n); - /// let cols: usize = 5; - /// - /// let rows: usize = cols; - /// let cols: usize = cols + 1; - /// let mut tmp_bytes: Vec = alloc_aligned(module.vmp_apply_dft_to_dft_tmp_bytes(cols, cols, rows, cols)); - /// let a: VecZnx = module.new_vec_znx(cols); - /// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols); - /// - /// let mut c_dft: VecZnxDft = module.new_vec_znx_dft(cols); - /// module.vmp_apply_dft_to_dft_inplace(&mut c_dft, &vmp_pmat, &mut tmp_bytes); - /// - /// c_dft.free(); - /// vmp_pmat.free(); - /// module.free(); - /// ``` fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, buf: &mut [u8]); } impl VmpPMatOps for Module { + fn bytes_of_vmp_pmat(&self, rows: usize, cols: usize) -> usize { + unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, cols as u64) as usize } + } + fn new_vmp_pmat(&self, rows: usize, cols: usize) -> VmpPMat { - unsafe { - VmpPMat { - data: vmp::new_vmp_pmat(self.0, rows as u64, cols as u64), - rows, - cols, - n: self.n(), - } + let mut data: Vec = alloc_aligned::(self.bytes_of_vmp_pmat(rows, cols)); + let ptr: *mut u8 = data.as_mut_ptr(); + VmpPMat { + data: data, + ptr: ptr, + n: self.n(), + cols: cols, + rows: rows, + backend: self.backend(), } } fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize) -> usize { - unsafe { vmp::vmp_prepare_tmp_bytes(self.0, rows as u64, cols as u64) as usize } + unsafe { vmp::vmp_prepare_tmp_bytes(self.ptr, rows as u64, cols as u64) as usize } } fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], tmp_bytes: &mut [u8]) { @@ -414,8 +322,8 @@ impl VmpPMatOps for Module { } unsafe { vmp::vmp_prepare_contiguous( - self.0, - b.data(), + self.ptr, + b.as_mut_ptr() as *mut vmp_pmat_t, a.as_ptr(), b.rows() as u64, b.cols() as u64, @@ -437,8 +345,8 @@ impl VmpPMatOps for Module { } unsafe { vmp::vmp_prepare_dblptr( - self.0, - b.data(), + self.ptr, + b.as_mut_ptr() as *mut vmp_pmat_t, ptrs.as_ptr(), b.rows() as u64, b.cols() as u64, @@ -456,8 +364,8 @@ impl VmpPMatOps for Module { } unsafe { vmp::vmp_prepare_row( - self.0, - b.data(), + self.ptr, + b.as_mut_ptr() as *mut vmp_pmat_t, a.as_ptr(), row_i as u64, b.rows() as u64, @@ -476,7 +384,7 @@ impl VmpPMatOps for Module { ) -> usize { unsafe { vmp::vmp_apply_dft_tmp_bytes( - self.0, + self.ptr, res_cols as u64, a_cols as u64, gct_rows as u64, @@ -485,13 +393,7 @@ impl VmpPMatOps for Module { } } - fn vmp_apply_dft( - &self, - c: &mut VecZnxDft, - a: &T, - b: &VmpPMat, - tmp_bytes: &mut [u8], - ) { + fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, tmp_bytes: &mut [u8]) { debug_assert!( tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols()) ); @@ -501,13 +403,13 @@ impl VmpPMatOps for Module { } unsafe { vmp::vmp_apply_dft( - self.0, - c.0, + self.ptr, + c.ptr as *mut vec_znx_dft_t, c.cols() as u64, a.as_ptr(), a.cols() as u64, a.n() as u64, - b.data(), + b.as_ptr() as *const vmp_pmat_t, b.rows() as u64, b.cols() as u64, tmp_bytes.as_mut_ptr(), @@ -524,7 
+426,7 @@ impl VmpPMatOps for Module { ) -> usize { unsafe { vmp::vmp_apply_dft_to_dft_tmp_bytes( - self.0, + self.ptr, res_cols as u64, a_cols as u64, gct_rows as u64, @@ -550,12 +452,12 @@ impl VmpPMatOps for Module { } unsafe { vmp::vmp_apply_dft_to_dft( - self.0, - c.0, + self.ptr, + c.ptr as *mut vec_znx_dft_t, c.cols() as u64, - a.0, + a.ptr as *const vec_znx_dft_t, a.cols() as u64, - b.data(), + b.as_ptr() as *const vmp_pmat_t, b.rows() as u64, b.cols() as u64, tmp_bytes.as_mut_ptr(), @@ -574,12 +476,12 @@ impl VmpPMatOps for Module { } unsafe { vmp::vmp_apply_dft_to_dft( - self.0, - b.0, + self.ptr, + b.ptr as *mut vec_znx_dft_t, b.cols() as u64, - b.0, + b.ptr as *mut vec_znx_dft_t, b.cols() as u64, - a.data(), + a.as_ptr() as *const vmp_pmat_t, a.rows() as u64, a.cols() as u64, tmp_bytes.as_mut_ptr(), diff --git a/rlwe/Cargo.toml b/rlwe/Cargo.toml index 21c5ddd..a8b8207 100644 --- a/rlwe/Cargo.toml +++ b/rlwe/Cargo.toml @@ -11,6 +11,7 @@ criterion = {workspace = true} base2k = {path="../base2k"} sampling = {path="../sampling"} rand_distr = {workspace = true} +itertools = {workspace = true} [[bench]] name = "gadget_product" diff --git a/rlwe/benches/gadget_product.rs b/rlwe/benches/gadget_product.rs index e5e5b12..e7dd8c2 100644 --- a/rlwe/benches/gadget_product.rs +++ b/rlwe/benches/gadget_product.rs @@ -1,5 +1,5 @@ use base2k::{ - FFT64, Infos, Module, Sampling, SvpPPolOps, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, + Infos, MODULETYPE, Module, Sampling, SvpPPolOps, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, alloc_aligned_u8, }; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; @@ -36,6 +36,7 @@ fn bench_gadget_product_inplace(c: &mut Criterion) { for log_n in 10..11 { let params_lit: ParametersLiteral = ParametersLiteral { + backend: MODULETYPE::FFT64, log_n: log_n, log_q: 32, log_p: 0, @@ -45,7 +46,7 @@ fn bench_gadget_product_inplace(c: &mut Criterion) { xs: 128, }; - let params: Parameters = Parameters::new::(¶ms_lit); + let params: Parameters = Parameters::new(¶ms_lit); let mut tmp_bytes: Vec = alloc_aligned_u8( params.encrypt_rlwe_sk_tmp_bytes(params.log_q()) @@ -101,7 +102,7 @@ fn bench_gadget_product_inplace(c: &mut Criterion) { let mut ct: Ciphertext = params.new_ciphertext(params.log_q()); - params.encrypt_rlwe_sk_thread_safe( + params.encrypt_rlwe_sk( &mut ct, None, &sk0_svp_ppol, diff --git a/rlwe/examples/encryption.rs b/rlwe/examples/encryption.rs index 4a05523..4002a95 100644 --- a/rlwe/examples/encryption.rs +++ b/rlwe/examples/encryption.rs @@ -1,4 +1,4 @@ -use base2k::{Encoding, FFT64, SvpPPolOps, VecZnx, VecZnxApi}; +use base2k::{Encoding, SvpPPolOps, VecZnx, alloc_aligned}; use rlwe::{ ciphertext::Ciphertext, elem::ElemCommon, @@ -10,6 +10,7 @@ use sampling::source::Source; fn main() { let params_lit: ParametersLiteral = ParametersLiteral { + backend: base2k::MODULETYPE::FFT64, log_n: 10, log_q: 54, log_p: 0, @@ -19,13 +20,12 @@ fn main() { xs: 128, }; - let params: Parameters = Parameters::new::(¶ms_lit); + let params: Parameters = Parameters::new(¶ms_lit); - let mut tmp_bytes: Vec = vec![ - 0u8; + let mut tmp_bytes: Vec = alloc_aligned( params.decrypt_rlwe_tmp_byte(params.log_q()) - | params.encrypt_rlwe_sk_tmp_bytes(params.log_q()) - ]; + | params.encrypt_rlwe_sk_tmp_bytes(params.log_q()), + ); let mut source: Source = Source::new([0; 32]); let mut sk: SecretKey = SecretKey::new(params.module()); @@ -35,7 +35,7 @@ fn main() { want.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); - let mut pt: Plaintext = 
params.new_plaintext(params.log_q());
+    let mut pt: Plaintext = params.new_plaintext(params.log_q());
 
     let log_base2k = pt.log_base2k();
 
@@ -56,7 +56,7 @@ fn main() {
     let mut sk_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol();
     params.module().svp_prepare(&mut sk_svp_ppol, &sk.0);
 
-    params.encrypt_rlwe_sk_thread_safe(
+    params.encrypt_rlwe_sk(
         &mut ct,
         Some(&pt),
         &sk_svp_ppol,
@@ -66,7 +66,6 @@ fn main() {
     );
 
     params.decrypt_rlwe(&mut pt, &ct, &sk_svp_ppol, &mut tmp_bytes);
-
     pt.0.value[0].print(pt.cols(), 16);
 
     let mut have = vec![i64::default(); params.n()];
diff --git a/rlwe/examples/rlk_experiments.rs b/rlwe/examples/rlk_experiments.rs
new file mode 100644
index 0000000..a35904c
--- /dev/null
+++ b/rlwe/examples/rlk_experiments.rs
@@ -0,0 +1,151 @@
+use base2k::{
+    Encoding, Infos, Module, Sampling, SvpPPol, SvpPPolOps, VecZnx, VecZnxDftOps, VecZnxOps,
+    VmpPMat, VmpPMatOps, is_aligned,
+};
+use itertools::izip;
+use rlwe::ciphertext::{Ciphertext, new_gadget_ciphertext};
+use rlwe::elem::ElemCommon;
+use rlwe::encryptor::encrypt_rlwe_sk;
+use rlwe::keys::SecretKey;
+use rlwe::plaintext::Plaintext;
+use sampling::source::{Source, new_seed};
+
+fn main() {
+    let n: usize = 32;
+    let module: Module = Module::new(n, base2k::MODULETYPE::FFT64);
+    let log_base2k: usize = 16;
+    let log_k: usize = 32;
+    let cols: usize = 4;
+
+    let mut a: VecZnx = module.new_vec_znx(cols);
+    let mut data: Vec<i64> = vec![0i64; n];
+    data[0] = 0;
+    data[1] = 0;
+    a.encode_vec_i64(log_base2k, log_k, &data, 16);
+
+    let mut a_dft: base2k::VecZnxDft = module.new_vec_znx_dft(cols);
+
+    module.vec_znx_dft(&mut a_dft, &a, cols);
+
+    (0..cols).for_each(|i| {
+        println!("{:?}", a_dft.at::<f64>(&module, i));
+    })
+}
+
+pub struct GadgetCiphertextProtocol {}
+
+impl GadgetCiphertextProtocol {
+    pub fn new() -> GadgetCiphertextProtocol {
+        Self {}
+    }
+
+    pub fn allocate(
+        module: &Module,
+        log_base2k: usize,
+        rows: usize,
+        log_q: usize,
+    ) -> GadgetCiphertextShare {
+        GadgetCiphertextShare::new(module, log_base2k, rows, log_q)
+    }
+
+    pub fn gen_share(
+        module: &Module,
+        sk: &SecretKey,
+        pt: &Plaintext,
+        seed: &[u8; 32],
+        share: &mut GadgetCiphertextShare,
+        tmp_bytes: &mut [u8],
+    ) {
+        share.seed.copy_from_slice(seed);
+        let mut source_xe: Source = Source::new(new_seed());
+        let mut source_xa: Source = Source::new(*seed);
+        let mut sk_ppol: SvpPPol = module.new_svp_ppol();
+        sk.prepare(module, &mut sk_ppol);
+        share.value.iter_mut().for_each(|ai| {
+            //let elem = Elem{};
+            //encrypt_rlwe_sk_thread_safe(module, ai, Some(pt.elem()), &sk_ppol, &mut source_xa, &mut source_xe, 3.2, tmp_bytes);
+        })
+    }
+}
+
+pub struct GadgetCiphertextShare {
+    pub seed: [u8; 32],
+    pub log_q: usize,
+    pub log_base2k: usize,
+    pub value: Vec<VecZnx>,
+}
+
+impl GadgetCiphertextShare {
+    pub fn new(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> Self {
+        let mut value: Vec<VecZnx> = Vec::new();
+        let cols: usize = (log_q + log_base2k - 1) / log_base2k;
+        (0..rows).for_each(|_| {
+            value.push(module.new_vec_znx(cols));
+        });
+        Self {
+            seed: [u8::default(); 32],
+            log_q,
+            log_base2k,
+            value,
+        }
+    }
+
+    pub fn rows(&self) -> usize {
+        self.value.len()
+    }
+
+    pub fn cols(&self) -> usize {
+        self.value[0].cols()
+    }
+
+    pub fn aggregate_inplace(&mut self, module: &Module, a: &GadgetCiphertextShare) {
+        izip!(self.value.iter_mut(), a.value.iter()).for_each(|(bi, ai)| {
+            module.vec_znx_add_inplace(bi, ai);
+        })
+    }
+
+    pub fn get(&self, module: &Module, b: &mut Ciphertext, tmp_bytes: &mut [u8]) {
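+        // Assumed semantics, spelled out for readability: row `row_i` of
+        // `b.at_mut(0)` receives this share's aggregated polynomial, while the
+        // matching row of `b.at_mut(1)` is re-derived from the common seed, so
+        // every party expands the shares into the same gadget ciphertext.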
assert!(is_aligned(tmp_bytes.as_ptr())); + + let rows: usize = b.rows(); + let cols: usize = b.cols(); + + assert!(tmp_bytes.len() >= gadget_ciphertext_share_get_tmp_bytes(module, rows, cols)); + + assert_eq!(self.value.len(), rows); + assert_eq!(self.value[0].cols(), cols); + + let (tmp_bytes_vmp_prepare_row, tmp_bytes_vec_znx) = + tmp_bytes.split_at_mut(module.vmp_prepare_tmp_bytes(rows, cols)); + + let mut c: VecZnx = VecZnx::from_bytes_borrow(module.n(), cols, tmp_bytes_vec_znx); + + let mut source: Source = Source::new(self.seed); + + (0..self.value.len()).for_each(|row_i| { + module.vmp_prepare_row( + b.at_mut(0), + self.value[row_i].raw(), + row_i, + tmp_bytes_vmp_prepare_row, + ); + module.fill_uniform(self.log_base2k, &mut c, cols, &mut source); + module.vmp_prepare_row(b.at_mut(1), c.raw(), row_i, tmp_bytes_vmp_prepare_row) + }) + } + + pub fn get_new(&self, module: &Module, tmp_bytes: &mut [u8]) -> Ciphertext { + let mut b: Ciphertext = + new_gadget_ciphertext(module, self.log_base2k, self.rows(), self.log_q); + self.get(module, &mut b, tmp_bytes); + b + } +} + +pub fn gadget_ciphertext_share_get_tmp_bytes(module: &Module, rows: usize, cols: usize) -> usize { + module.vmp_prepare_tmp_bytes(rows, cols) + module.bytes_of_vec_znx(cols) +} + +pub struct CircularCiphertextProtocol {} + +pub struct CircularGadgetCiphertextProtocol {} diff --git a/rlwe/src/decryptor.rs b/rlwe/src/decryptor.rs index 63a892b..e4d9545 100644 --- a/rlwe/src/decryptor.rs +++ b/rlwe/src/decryptor.rs @@ -1,11 +1,11 @@ use crate::{ ciphertext::Ciphertext, - elem::{Elem, ElemCommon, VecZnxCommon}, + elem::{Elem, ElemCommon}, keys::SecretKey, parameters::Parameters, plaintext::Plaintext, }; -use base2k::{Module, SvpPPol, SvpPPolOps, VecZnxBigOps, VecZnxDft, VecZnxDftOps}; +use base2k::{Module, SvpPPol, SvpPPolOps, VecZnx, VecZnxBigOps, VecZnxDft, VecZnxDftOps}; use std::cmp::min; pub struct Decryptor { @@ -32,30 +32,24 @@ impl Parameters { ) } - pub fn decrypt_rlwe( + pub fn decrypt_rlwe( &self, - res: &mut Plaintext, - ct: &Ciphertext, + res: &mut Plaintext, + ct: &Ciphertext, sk: &SvpPPol, tmp_bytes: &mut [u8], - ) where - T: VecZnxCommon, - Elem: ElemCommon, - { + ) { decrypt_rlwe(self.module(), &mut res.0, &ct.0, sk, tmp_bytes) } } -pub fn decrypt_rlwe( +pub fn decrypt_rlwe( module: &Module, - res: &mut Elem, - a: &Elem, + res: &mut Elem, + a: &Elem, sk: &SvpPPol, tmp_bytes: &mut [u8], -) where - T: VecZnxCommon, - Elem: ElemCommon, -{ +) { let cols: usize = a.cols(); assert!( @@ -65,9 +59,11 @@ pub fn decrypt_rlwe( decrypt_rlwe_tmp_byte(module, cols) ); - let res_dft_bytes: usize = module.bytes_of_vec_znx_dft(cols); + let (tmp_bytes_vec_znx_dft, tmp_bytes_normalize) = + tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols)); - let mut res_dft: VecZnxDft = VecZnxDft::from_bytes(a.cols(), tmp_bytes); + let mut res_dft: VecZnxDft = + VecZnxDft::from_bytes_borrow(module, a.cols(), tmp_bytes_vec_znx_dft); let mut res_big: base2k::VecZnxBig = res_dft.as_vec_znx_big(); // res_dft <- DFT(ct[1]) * DFT(sk) @@ -77,12 +73,7 @@ pub fn decrypt_rlwe( // res_big <- ct[1] x sk + ct[0] module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0)); // res <- normalize(ct[1] x sk + ct[0]) - module.vec_znx_big_normalize( - a.log_base2k(), - res.at_mut(0), - &res_big, - &mut tmp_bytes[res_dft_bytes..], - ); + module.vec_znx_big_normalize(a.log_base2k(), res.at_mut(0), &res_big, tmp_bytes_normalize); res.log_base2k = a.log_base2k(); res.log_q = min(res.log_q(), a.log_q()); diff --git a/rlwe/src/elem.rs b/rlwe/src/elem.rs index 
4466c07..6be3038 100644 --- a/rlwe/src/elem.rs +++ b/rlwe/src/elem.rs @@ -1,17 +1,7 @@ -use base2k::{Infos, Module, VecZnx, VecZnxBorrow, VecZnxOps, VmpPMat, VmpPMatOps}; +use base2k::{Infos, Module, VecZnx, VecZnxOps, VmpPMat, VmpPMatOps}; use crate::parameters::Parameters; -impl Parameters { - pub fn elem_from_bytes(&self, log_q: usize, size: usize, bytes: &mut [u8]) -> Elem - where - T: VecZnxCommon, - Elem: ElemVecZnx, - { - Elem::::from_bytes(self.module(), self.log_base2k(), log_q, size, bytes) - } -} - pub struct Elem { pub value: Vec, pub log_base2k: usize, @@ -19,26 +9,26 @@ pub struct Elem { pub log_scale: usize, } -pub trait VecZnxCommon: base2k::VecZnxCommon {} -impl VecZnxCommon for VecZnx {} -impl VecZnxCommon for VecZnxBorrow {} - -pub trait ElemVecZnx> { +pub trait ElemVecZnx { fn from_bytes( module: &Module, log_base2k: usize, log_q: usize, size: usize, bytes: &mut [u8], - ) -> Elem; + ) -> Elem; + fn from_bytes_borrow( + module: &Module, + log_base2k: usize, + log_q: usize, + size: usize, + bytes: &mut [u8], + ) -> Elem; fn bytes_of(module: &Module, log_base2k: usize, log_q: usize, size: usize) -> usize; fn zero(&mut self); } -impl ElemVecZnx for Elem -where - T: VecZnxCommon, -{ +impl ElemVecZnx for Elem { fn bytes_of(module: &Module, log_base2k: usize, log_q: usize, size: usize) -> usize { let cols = (log_q + log_base2k - 1) / log_base2k; module.n() * cols * size * 8 @@ -50,16 +40,42 @@ where log_q: usize, size: usize, bytes: &mut [u8], - ) -> Elem { + ) -> Elem { assert!(size > 0); let n: usize = module.n(); assert!(bytes.len() >= Self::bytes_of(module, log_base2k, log_q, size)); - let mut value: Vec = Vec::new(); + let mut value: Vec = Vec::new(); let limbs: usize = (log_q + log_base2k - 1) / log_base2k; - let elem_size = T::bytes_of(n, limbs); + let elem_size = VecZnx::bytes_of(n, limbs); let mut ptr: usize = 0; (0..size).for_each(|_| { - value.push(T::from_bytes(n, limbs, &mut bytes[ptr..])); + value.push(VecZnx::from_bytes(n, limbs, &mut bytes[ptr..])); + ptr += elem_size + }); + Self { + value, + log_q, + log_base2k, + log_scale: 0, + } + } + + fn from_bytes_borrow( + module: &Module, + log_base2k: usize, + log_q: usize, + size: usize, + bytes: &mut [u8], + ) -> Elem { + assert!(size > 0); + let n: usize = module.n(); + assert!(bytes.len() >= Self::bytes_of(module, log_base2k, log_q, size)); + let mut value: Vec = Vec::new(); + let limbs: usize = (log_q + log_base2k - 1) / log_base2k; + let elem_size = VecZnx::bytes_of(n, limbs); + let mut ptr: usize = 0; + (0..size).for_each(|_| { + value.push(VecZnx::from_bytes_borrow(n, limbs, &mut bytes[ptr..])); ptr += elem_size }); Self { diff --git a/rlwe/src/encryptor.rs b/rlwe/src/encryptor.rs index cb69944..dd25280 100644 --- a/rlwe/src/encryptor.rs +++ b/rlwe/src/encryptor.rs @@ -1,12 +1,12 @@ use crate::ciphertext::Ciphertext; -use crate::elem::{Elem, ElemCommon, ElemVecZnx, VecZnxCommon}; +use crate::elem::{Elem, ElemCommon, ElemVecZnx}; use crate::keys::SecretKey; use crate::parameters::Parameters; use crate::plaintext::Plaintext; use base2k::sampling::Sampling; use base2k::{ - Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxApi, VecZnxBig, VecZnxBigOps, VecZnxBorrow, - VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, + Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, + VecZnxOps, VmpPMat, VmpPMatOps, }; use sampling::source::{Source, new_seed}; @@ -49,20 +49,17 @@ impl EncryptorSk { self.source_xe = Source::new(seed) } - pub fn encrypt_rlwe_sk( + pub fn 
encrypt_rlwe_sk( &mut self, params: &Parameters, - ct: &mut Ciphertext, - pt: Option<&Plaintext>, - ) where - T: VecZnxCommon, - Elem: ElemCommon, - { + ct: &mut Ciphertext, + pt: Option<&Plaintext>, + ) { assert!( self.initialized == true, "invalid call to [EncryptorSk.encrypt_rlwe_sk]: [EncryptorSk] has not been initialized with a [SecretKey]" ); - params.encrypt_rlwe_sk_thread_safe( + params.encrypt_rlwe_sk( ct, pt, &self.sk, @@ -72,23 +69,20 @@ impl EncryptorSk { ); } - pub fn encrypt_rlwe_sk_thread_safe( + pub fn encrypt_rlwe_sk_core( &self, params: &Parameters, - ct: &mut Ciphertext, - pt: Option<&Plaintext>, + ct: &mut Ciphertext, + pt: Option<&Plaintext>, source_xa: &mut Source, source_xe: &mut Source, tmp_bytes: &mut [u8], - ) where - T: VecZnxCommon, - Elem: ElemCommon, - { + ) { assert!( self.initialized == true, - "invalid call to [EncryptorSk.encrypt_rlwe_sk_thread_safe]: [EncryptorSk] has not been initialized with a [SecretKey]" + "invalid call to [EncryptorSk.encrypt_rlwe_sk]: [EncryptorSk] has not been initialized with a [SecretKey]" ); - params.encrypt_rlwe_sk_thread_safe(ct, pt, &self.sk, source_xa, source_xe, tmp_bytes); + params.encrypt_rlwe_sk(ct, pt, &self.sk, source_xa, source_xe, tmp_bytes); } } @@ -97,19 +91,16 @@ impl Parameters { encrypt_rlwe_sk_tmp_bytes(self.module(), self.log_base2k(), log_q) } - pub fn encrypt_rlwe_sk_thread_safe( + pub fn encrypt_rlwe_sk( &self, - ct: &mut Ciphertext, - pt: Option<&Plaintext>, + ct: &mut Ciphertext, + pt: Option<&Plaintext>, sk: &SvpPPol, source_xa: &mut Source, source_xe: &mut Source, tmp_bytes: &mut [u8], - ) where - T: VecZnxCommon, - Elem: ElemCommon, - { - encrypt_rlwe_sk_thread_safe( + ) { + encrypt_rlwe_sk( self.module(), &mut ct.0, pt.map(|pt| &pt.0), @@ -127,19 +118,16 @@ pub fn encrypt_rlwe_sk_tmp_bytes(module: &Module, log_base2k: usize, log_q: usiz + module.vec_znx_big_normalize_tmp_bytes() } -pub fn encrypt_rlwe_sk_thread_safe( +pub fn encrypt_rlwe_sk( module: &Module, - ct: &mut Elem, - pt: Option<&Elem>, + ct: &mut Elem, + pt: Option<&Elem>, sk: &SvpPPol, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, tmp_bytes: &mut [u8], -) where - T: VecZnxCommon, - Elem: ElemCommon, -{ +) { let cols: usize = ct.cols(); let log_base2k: usize = ct.log_base2k(); let log_q: usize = ct.log_q(); @@ -153,16 +141,16 @@ pub fn encrypt_rlwe_sk_thread_safe( let log_q: usize = ct.log_q(); let log_base2k: usize = ct.log_base2k(); - let c1: &mut T = ct.at_mut(1); + let c1: &mut VecZnx = ct.at_mut(1); // c1 <- Z_{2^prec}[X]/(X^{N}+1) module.fill_uniform(log_base2k, c1, cols, source_xa); - let bytes_of_vec_znx_dft: usize = module.bytes_of_vec_znx_dft(cols); + let (tmp_bytes_vec_znx_dft, tmp_bytes_normalize) = + tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols)); // Scratch space for DFT values - let mut buf_dft: VecZnxDft = - VecZnxDft::from_bytes(cols, &mut tmp_bytes[..bytes_of_vec_znx_dft]); + let mut buf_dft: VecZnxDft = VecZnxDft::from_bytes_borrow(module, cols, tmp_bytes_vec_znx_dft); // Applies buf_dft <- DFT(s) * DFT(c1) module.svp_apply_dft(&mut buf_dft, sk, c1, cols); @@ -173,16 +161,14 @@ pub fn encrypt_rlwe_sk_thread_safe( // buf_big = s x c1 module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft, cols); - let carry: &mut [u8] = &mut tmp_bytes[bytes_of_vec_znx_dft..]; - // c0 <- -s x c1 + m - let c0: &mut T = ct.at_mut(0); + let c0: &mut VecZnx = ct.at_mut(0); if let Some(pt) = pt { module.vec_znx_big_sub_small_a_inplace(&mut buf_big, pt.at(0)); - module.vec_znx_big_normalize(log_base2k, c0, &buf_big, carry); 
+ module.vec_znx_big_normalize(log_base2k, c0, &buf_big, tmp_bytes_normalize); } else { - module.vec_znx_big_normalize(log_base2k, c0, &buf_big, carry); + module.vec_znx_big_normalize(log_base2k, c0, &buf_big, tmp_bytes_normalize); module.vec_znx_negate_inplace(c0); } @@ -211,7 +197,7 @@ pub fn encrypt_grlwe_sk_tmp_bytes( ) -> usize { let cols = (log_q + log_base2k - 1) / log_base2k; Elem::::bytes_of(module, log_base2k, log_q, 2) - + Plaintext::::bytes_of(module, log_base2k, log_q) + + Plaintext::bytes_of(module, log_base2k, log_q) + encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q) + module.vmp_prepare_tmp_bytes(rows, cols) } @@ -240,25 +226,25 @@ pub fn encrypt_grlwe_sk( min_tmp_bytes_len ); - let bytes_of_elem: usize = Elem::::bytes_of(module, log_base2k, log_q, 2); - let bytes_of_pt: usize = Plaintext::::bytes_of(module, log_base2k, log_q); + let bytes_of_elem: usize = Elem::::bytes_of(module, log_base2k, log_q, 2); + let bytes_of_pt: usize = Plaintext::bytes_of(module, log_base2k, log_q); let bytes_of_enc_sk: usize = encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q); let (tmp_bytes_pt, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_pt); let (tmp_bytes_enc_sk, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_enc_sk); let (tmp_bytes_elem, tmp_bytes_vmp_prepare_row) = tmp_bytes.split_at_mut(bytes_of_elem); - let mut tmp_elem: Elem = - Elem::::from_bytes(module, log_base2k, ct.log_q(), 2, tmp_bytes_elem); - let mut tmp_pt: Plaintext = - Plaintext::::from_bytes(module, log_base2k, log_q, tmp_bytes_pt); + let mut tmp_elem: Elem = + Elem::::from_bytes_borrow(module, log_base2k, ct.log_q(), 2, tmp_bytes_elem); + let mut tmp_pt: Plaintext = + Plaintext::from_bytes_borrow(module, log_base2k, log_q, tmp_bytes_pt); (0..rows).for_each(|row_i| { // Sets the i-th row of the RLWE sample to m (i.e. m * 2^{-log_base2k*i}) - tmp_pt.at_mut(0).at_mut(row_i).copy_from_slice(&m.0); + tmp_pt.at_mut(0).at_mut(row_i).copy_from_slice(&m.raw()); // Encrypts RLWE(m * 2^{-log_base2k*i}) - encrypt_rlwe_sk_thread_safe( + encrypt_rlwe_sk( module, &mut tmp_elem, Some(&tmp_pt.0), diff --git a/rlwe/src/gadget_product.rs b/rlwe/src/gadget_product.rs index 73651dc..0139f88 100644 --- a/rlwe/src/gadget_product.rs +++ b/rlwe/src/gadget_product.rs @@ -1,9 +1,5 @@ -use crate::{ - ciphertext::Ciphertext, - elem::{Elem, ElemCommon, ElemVecZnx, VecZnxCommon}, - parameters::Parameters, -}; -use base2k::{Module, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps}; +use crate::{ciphertext::Ciphertext, elem::ElemCommon, parameters::Parameters}; +use base2k::{Module, VecZnx, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps}; use std::cmp::min; pub fn gadget_product_tmp_bytes( @@ -53,19 +49,16 @@ impl Parameters { /// /// res = sum[min(a_ncols, b_nrows)] decomp(a, i) * (-B[i]s + m * 2^{-k*i} + E[i], B[i]) /// = (cs + m * a + e, c) with min(res_cols, b_cols) cols. 
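+///
+/// # Example
+/// A sketch of the call pattern (setup elided, hence `ignore`); it mirrors the
+/// body below, with the `b.at(1)` component handled analogously to `b.at(0)`.
+/// ```ignore
+/// // res_dft_1 <- DFT(decomp(a)), limb-wise
+/// module.vec_znx_dft(res_dft_1, a, min(a_cols, b_cols));
+/// // res_dft_0 <- <decomp(a), b[0]>
+/// module.vmp_apply_dft_to_dft(res_dft_0, res_dft_1, b.at(0), tmp_bytes);
+/// ```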
-pub fn gadget_product_core( +pub fn gadget_product_core( module: &Module, res_dft_0: &mut VecZnxDft, res_dft_1: &mut VecZnxDft, - a: &T, + a: &VecZnx, a_cols: usize, b: &Ciphertext, b_cols: usize, tmp_bytes: &mut [u8], -) where - T: VecZnxCommon, - Elem: ElemVecZnx, -{ +) { assert!(b_cols <= b.cols()); module.vec_znx_dft(res_dft_1, a, min(a_cols, b_cols)); module.vmp_apply_dft_to_dft(res_dft_0, res_dft_1, b.at(0), tmp_bytes); @@ -104,7 +97,7 @@ mod test { plaintext::Plaintext, }; use base2k::{ - FFT64, Infos, Sampling, SvpPPolOps, VecZnx, VecZnxApi, VecZnxBig, VecZnxBigOps, VecZnxDft, + Infos, MODULETYPE, Sampling, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, alloc_aligned_u8, }; use sampling::source::{Source, new_seed}; @@ -117,6 +110,7 @@ mod test { // Basic parameters with enough limbs to test edge cases let params_lit: ParametersLiteral = ParametersLiteral { + backend: MODULETYPE::FFT64, log_n: 12, log_q: q_cols * log_base2k, log_p: p_cols * log_base2k, @@ -126,7 +120,7 @@ mod test { xs: 1 << 11, }; - let params: Parameters = Parameters::new::(¶ms_lit); + let params: Parameters = Parameters::new(¶ms_lit); // scratch space let mut tmp_bytes: Vec = alloc_aligned_u8( @@ -213,8 +207,8 @@ mod test { ); // Plaintext for decrypted output of gadget product - let mut pt: Plaintext = - Plaintext::::new(params.module(), params.log_base2k(), params.log_qp()); + let mut pt: Plaintext = + Plaintext::new(params.module(), params.log_base2k(), params.log_qp()); // Iterates over all possible cols values for input/output polynomials and gadget ciphertext. diff --git a/rlwe/src/keys.rs b/rlwe/src/keys.rs index df78e11..da7c412 100644 --- a/rlwe/src/keys.rs +++ b/rlwe/src/keys.rs @@ -1,6 +1,6 @@ use crate::ciphertext::{Ciphertext, new_gadget_ciphertext}; use crate::elem::{Elem, ElemCommon}; -use crate::encryptor::{encrypt_rlwe_sk_thread_safe, encrypt_rlwe_sk_tmp_bytes}; +use crate::encryptor::{encrypt_rlwe_sk, encrypt_rlwe_sk_tmp_bytes}; use base2k::{Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VmpPMat}; use sampling::source::Source; @@ -40,7 +40,7 @@ impl PublicKey { xe_source: &mut Source, tmp_bytes: &mut [u8], ) { - encrypt_rlwe_sk_thread_safe( + encrypt_rlwe_sk( module, &mut self.0, None, diff --git a/rlwe/src/parameters.rs b/rlwe/src/parameters.rs index 8cc948f..a0860d0 100644 --- a/rlwe/src/parameters.rs +++ b/rlwe/src/parameters.rs @@ -1,6 +1,7 @@ use base2k::module::{MODULETYPE, Module}; pub struct ParametersLiteral { + pub backend: MODULETYPE, pub log_n: usize, pub log_q: usize, pub log_p: usize, @@ -22,7 +23,7 @@ pub struct Parameters { } impl Parameters { - pub fn new(p: &ParametersLiteral) -> Self { + pub fn new(p: &ParametersLiteral) -> Self { assert!( p.log_n + 2 * p.log_base2k <= 53, "invalid parameters: p.log_n + 2*p.log_base2k > 53" @@ -35,7 +36,7 @@ impl Parameters { log_base2k: p.log_base2k, xe: p.xe, xs: p.xs, - module: Module::new::(1 << p.log_n), + module: Module::new(1 << p.log_n, p.backend), } } diff --git a/rlwe/src/plaintext.rs b/rlwe/src/plaintext.rs index 78a62cb..40dd5d3 100644 --- a/rlwe/src/plaintext.rs +++ b/rlwe/src/plaintext.rs @@ -1,61 +1,65 @@ use crate::ciphertext::Ciphertext; -use crate::elem::{Elem, ElemCommon, ElemVecZnx, VecZnxCommon}; +use crate::elem::{Elem, ElemCommon, ElemVecZnx}; use crate::parameters::Parameters; use base2k::{Module, VecZnx}; -pub struct Plaintext(pub Elem); +pub struct Plaintext(pub Elem); impl Parameters { - pub fn new_plaintext(&self, log_q: usize) -> Plaintext { + pub fn new_plaintext(&self, 
log_q: usize) -> Plaintext { Plaintext::new(self.module(), self.log_base2k(), log_q) } - pub fn bytes_of_plaintext(&self, log_q: usize) -> usize - where - T: VecZnxCommon, - Elem: ElemVecZnx, - { - Elem::::bytes_of(self.module(), self.log_base2k(), log_q, 1) + pub fn bytes_of_plaintext(&self, log_q: usize) -> usize +where { + Elem::::bytes_of(self.module(), self.log_base2k(), log_q, 1) } - pub fn plaintext_from_bytes(&self, log_q: usize, bytes: &mut [u8]) -> Plaintext - where - T: VecZnxCommon, - Elem: ElemVecZnx, - { - Plaintext::(self.elem_from_bytes::(log_q, 1, bytes)) + pub fn plaintext_from_bytes(&self, log_q: usize, bytes: &mut [u8]) -> Plaintext { + Plaintext(Elem::::from_bytes( + self.module(), + self.log_base2k(), + log_q, + 1, + bytes, + )) } } -impl Plaintext { +impl Plaintext { pub fn new(module: &Module, log_base2k: usize, log_q: usize) -> Self { Self(Elem::::new(module, log_base2k, log_q, 1)) } } -impl Plaintext -where - T: VecZnxCommon, - Elem: ElemVecZnx, -{ +impl Plaintext { pub fn bytes_of(module: &Module, log_base2k: usize, log_q: usize) -> usize { - Elem::::bytes_of(module, log_base2k, log_q, 1) + Elem::::bytes_of(module, log_base2k, log_q, 1) } pub fn from_bytes(module: &Module, log_base2k: usize, log_q: usize, bytes: &mut [u8]) -> Self { - Self(Elem::::from_bytes(module, log_base2k, log_q, 1, bytes)) + Self(Elem::::from_bytes( + module, log_base2k, log_q, 1, bytes, + )) } - pub fn as_ciphertext(&self) -> Ciphertext { - unsafe { Ciphertext::(std::ptr::read(&self.0)) } + pub fn from_bytes_borrow( + module: &Module, + log_base2k: usize, + log_q: usize, + bytes: &mut [u8], + ) -> Self { + Self(Elem::::from_bytes_borrow( + module, log_base2k, log_q, 1, bytes, + )) + } + + pub fn as_ciphertext(&self) -> Ciphertext { + unsafe { Ciphertext::(std::ptr::read(&self.0)) } } } -impl ElemCommon for Plaintext -where - T: VecZnxCommon, - Elem: ElemVecZnx, -{ +impl ElemCommon for Plaintext { fn n(&self) -> usize { self.0.n() } @@ -68,11 +72,11 @@ where self.0.log_q } - fn elem(&self) -> &Elem { + fn elem(&self) -> &Elem { &self.0 } - fn elem_mut(&mut self) -> &mut Elem { + fn elem_mut(&mut self) -> &mut Elem { &mut self.0 } @@ -88,11 +92,11 @@ where self.0.cols() } - fn at(&self, i: usize) -> &T { + fn at(&self, i: usize) -> &VecZnx { self.0.at(i) } - fn at_mut(&mut self, i: usize) -> &mut T { + fn at_mut(&mut self, i: usize) -> &mut VecZnx { self.0.at_mut(i) } diff --git a/rlwe/src/rgsw_product.rs b/rlwe/src/rgsw_product.rs index f2bda0f..51ee793 100644 --- a/rlwe/src/rgsw_product.rs +++ b/rlwe/src/rgsw_product.rs @@ -1,20 +1,19 @@ use crate::{ ciphertext::Ciphertext, - elem::{Elem, ElemCommon, ElemVecZnx, VecZnxCommon}, + elem::{Elem, ElemCommon, ElemVecZnx}, +}; +use base2k::{ + Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, }; -use base2k::{Module, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps}; use std::cmp::min; -pub fn rgsw_product( +pub fn rgsw_product( module: &Module, - _res: &mut Elem, - a: &Ciphertext, + _res: &mut Elem, + a: &Ciphertext, b: &Ciphertext, tmp_bytes: &mut [u8], -) where - T: VecZnxCommon, - Elem: ElemVecZnx, -{ +) { let _log_base2k: usize = b.log_base2k(); let rows: usize = min(b.rows(), a.cols()); let cols: usize = b.cols();