diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index cb9dfa8..3d53141 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -1,6 +1,6 @@ use base2k::{ - Encoding, FFT64, Infos, Module, Sampling, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, - VecZnxDftOps, VecZnxLayout, VecZnxOps, alloc_aligned, + Encoding, FFT64, Module, Sampling, Scalar, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, + VecZnxDftOps, VecZnxOps, ZnxInfos, ZnxLayout, alloc_aligned, }; use itertools::izip; use sampling::source::Source; @@ -25,7 +25,7 @@ fn main() { s.fill_ternary_prob(0.5, &mut source); // Buffer to store s in the DFT domain - let mut s_ppol: SvpPPol = module.new_svp_ppol(); + let mut s_ppol: ScalarZnxDft = module.new_svp_ppol(); // s_ppol <- DFT(s) module.svp_prepare(&mut s_ppol, &s); diff --git a/base2k/examples/vector_matrix_product.rs b/base2k/examples/vector_matrix_product.rs index 8d4a33d..0120f61 100644 --- a/base2k/examples/vector_matrix_product.rs +++ b/base2k/examples/vector_matrix_product.rs @@ -1,6 +1,6 @@ use base2k::{ - Encoding, FFT64, Infos, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxLayout, VecZnxOps, VmpPMat, - VmpPMatOps, alloc_aligned, + Encoding, FFT64, MatZnxDft, MatZnxDftOps, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, + ZnxInfos, ZnxLayout, alloc_aligned, }; fn main() { @@ -31,16 +31,16 @@ fn main() { a.print(n); println!(); - let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows_mat, 1, limbs_mat); + let mut mat_znx_dft: MatZnxDft = module.new_mat_znx_dft(rows_mat, 1, limbs_mat); (0..a.limbs()).for_each(|row_i| { let mut tmp: VecZnx = module.new_vec_znx(1, limbs_mat); tmp.at_limb_mut(row_i)[1] = 1 as i64; - module.vmp_prepare_row(&mut vmp_pmat, tmp.raw(), row_i, &mut buf); + module.vmp_prepare_row(&mut mat_znx_dft, tmp.raw(), row_i, &mut buf); }); let mut c_dft: VecZnxDft = module.new_vec_znx_dft(1, limbs_mat); - module.vmp_apply_dft(&mut c_dft, &a, &vmp_pmat, &mut buf); + module.vmp_apply_dft(&mut c_dft, &a, &mat_znx_dft, &mut buf); let mut c_big: VecZnxBig = c_dft.as_vec_znx_big(); module.vec_znx_idft_tmp_a(&mut c_big, &mut c_dft); diff --git a/base2k/src/commons.rs b/base2k/src/commons.rs index ef7a410..290599d 100644 --- a/base2k/src/commons.rs +++ b/base2k/src/commons.rs @@ -1,4 +1,6 @@ -pub trait Infos { +use crate::{Backend, Module}; + +pub trait ZnxInfos { /// Returns the ring degree of the polynomials. fn n(&self) -> usize; @@ -18,20 +20,34 @@ pub trait Infos { fn poly_count(&self) -> usize; } -pub trait VecZnxLayout: Infos { +pub trait ZnxBase { + type Scalar; + fn new(module: &Module, cols: usize, limbs: usize) -> Self; + fn from_bytes(module: &Module, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self; + fn from_bytes_borrow(module: &Module, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self; + fn bytes_of(module: &Module, cols: usize, limbs: usize) -> usize; +} + +pub trait ZnxLayout: ZnxInfos { type Scalar; + /// Returns a non-mutable pointer to the underlying coefficients array. fn as_ptr(&self) -> *const Self::Scalar; + + /// Returns a mutable pointer to the underlying coefficients array. fn as_mut_ptr(&mut self) -> *mut Self::Scalar; + /// Returns a non-mutable reference to the entire underlying coefficient array. fn raw(&self) -> &[Self::Scalar] { unsafe { std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) } } + /// Returns a mutable reference to the entire underlying coefficient array. fn raw_mut(&mut self) -> &mut [Self::Scalar] { unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) } } + /// Returns a non-mutable pointer starting at the (i, j)-th small polynomial. fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar { #[cfg(debug_assertions)] { @@ -42,6 +58,7 @@ pub trait VecZnxLayout: Infos { unsafe { self.as_ptr().add(offset) } } + /// Returns a mutable pointer starting at the (i, j)-th small polynomial. fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar { #[cfg(debug_assertions)] { @@ -52,18 +69,22 @@ pub trait VecZnxLayout: Infos { unsafe { self.as_mut_ptr().add(offset) } } + /// Returns non-mutable reference to the (i, j)-th small polynomial. fn at_poly(&self, i: usize, j: usize) -> &[Self::Scalar] { unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) } } + /// Returns mutable reference to the (i, j)-th small polynomial. fn at_poly_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] { unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) } } + /// Returns non-mutable reference to the i-th limb. fn at_limb(&self, j: usize) -> &[Self::Scalar] { unsafe { std::slice::from_raw_parts(self.at_ptr(0, j), self.n() * self.cols()) } } + /// Returns mutable reference to the i-th limb. fn at_limb_mut(&mut self, j: usize) -> &mut [Self::Scalar] { unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(0, j), self.n() * self.cols()) } } diff --git a/base2k/src/encoding.rs b/base2k/src/encoding.rs index 5944f3c..6034b95 100644 --- a/base2k/src/encoding.rs +++ b/base2k/src/encoding.rs @@ -1,5 +1,5 @@ use crate::ffi::znx::znx_zero_i64_ref; -use crate::{Infos, VecZnx, VecZnxLayout}; +use crate::{VecZnx, ZnxInfos, ZnxLayout}; use itertools::izip; use rug::{Assign, Float}; use std::cmp::min; @@ -262,17 +262,18 @@ fn decode_coeff_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i #[cfg(test)] mod tests { - use crate::{Encoding, Infos, VecZnx, VecZnxLayout}; + use crate::{Encoding, FFT64, Module, VecZnx, ZnxBase, ZnxInfos, ZnxLayout}; use itertools::izip; use sampling::source::Source; #[test] fn test_set_get_i64_lo_norm() { let n: usize = 8; + let module: Module = Module::::new(n); let log_base2k: usize = 17; let cols: usize = 5; let log_k: usize = cols * log_base2k - 5; - let mut a: VecZnx = VecZnx::new(n, 2, cols); + let mut a: VecZnx = VecZnx::new(&module, 2, cols); let mut source: Source = Source::new([0u8; 32]); let raw: &mut [i64] = a.raw_mut(); raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); @@ -290,10 +291,11 @@ mod tests { #[test] fn test_set_get_i64_hi_norm() { let n: usize = 8; + let module: Module = Module::::new(n); let log_base2k: usize = 17; let cols: usize = 5; let log_k: usize = cols * log_base2k - 5; - let mut a: VecZnx = VecZnx::new(n, 2, cols); + let mut a: VecZnx = VecZnx::new(&module, 2, cols); let mut source = Source::new([0u8; 32]); let raw: &mut [i64] = a.raw_mut(); raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index 4d54ca0..40df3bb 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -3,26 +3,26 @@ pub mod encoding; #[allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)] // Other modules and exports pub mod ffi; +pub mod mat_znx_dft; pub mod module; pub mod sampling; +pub mod scalar_znx_dft; pub mod stats; -pub mod svp; pub mod vec_znx; pub mod vec_znx_big; pub mod vec_znx_dft; -pub mod vmp; pub use commons::*; pub use encoding::*; +pub use mat_znx_dft::*; pub use module::*; pub use sampling::*; +pub use scalar_znx_dft::*; #[allow(unused_imports)] pub use stats::*; -pub use svp::*; pub use vec_znx::*; pub use vec_znx_big::*; pub use vec_znx_dft::*; -pub use vmp::*; pub const GALOISGENERATOR: u64 = 5; pub const DEFAULTALIGN: usize = 64; diff --git a/base2k/src/vmp.rs b/base2k/src/mat_znx_dft.rs similarity index 88% rename from base2k/src/vmp.rs rename to base2k/src/mat_znx_dft.rs index f2af561..9466696 100644 --- a/base2k/src/vmp.rs +++ b/base2k/src/mat_znx_dft.rs @@ -1,7 +1,7 @@ use crate::ffi::vec_znx_big::vec_znx_big_t; use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::ffi::vmp::{self, vmp_pmat_t}; -use crate::{Backend, FFT64, Infos, Module, VecZnx, VecZnxBig, VecZnxDft, VecZnxLayout, alloc_aligned, assert_alignement}; +use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, ZnxInfos, ZnxLayout, alloc_aligned, assert_alignement}; use std::marker::PhantomData; /// Vector Matrix Product Prepared Matrix: a vector of [VecZnx], @@ -10,7 +10,7 @@ use std::marker::PhantomData; /// /// [VmpPMat] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [VmpPMat]. /// See the trait [VmpPMatOps] for additional information. -pub struct VmpPMat { +pub struct MatZnxDft { /// Raw data, is empty if borrowing scratch space. data: Vec, /// Pointer to data. Can point to scratch space. @@ -26,7 +26,7 @@ pub struct VmpPMat { _marker: PhantomData, } -impl Infos for VmpPMat { +impl ZnxInfos for MatZnxDft { fn n(&self) -> usize { self.n } @@ -52,11 +52,11 @@ impl Infos for VmpPMat { } } -impl VmpPMat { - fn new(module: &Module, rows: usize, cols: usize, limbs: usize) -> VmpPMat { - let mut data: Vec = alloc_aligned::(module.bytes_of_vmp_pmat(rows, cols, limbs)); +impl MatZnxDft { + fn new(module: &Module, rows: usize, cols: usize, limbs: usize) -> MatZnxDft { + let mut data: Vec = alloc_aligned::(module.bytes_of_mat_znx_dft(rows, cols, limbs)); let ptr: *mut u8 = data.as_mut_ptr(); - VmpPMat:: { + MatZnxDft:: { data: data, ptr: ptr, n: module.n(), @@ -126,8 +126,8 @@ impl VmpPMat { /// This trait implements methods for vector matrix product, /// that is, multiplying a [VecZnx] with a [VmpPMat]. -pub trait VmpPMatOps { - fn bytes_of_vmp_pmat(&self, rows: usize, cols: usize, limbs: usize) -> usize; +pub trait MatZnxDftOps { + fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, limbs: usize) -> usize; /// Allocates a new [VmpPMat] with the given number of rows and columns. /// @@ -135,7 +135,7 @@ pub trait VmpPMatOps { /// /// * `rows`: number of rows (number of [VecZnxDft]). /// * `size`: number of size (number of size of each [VecZnxDft]). - fn new_vmp_pmat(&self, rows: usize, cols: usize, limbs: usize) -> VmpPMat; + fn new_mat_znx_dft(&self, rows: usize, cols: usize, limbs: usize) -> MatZnxDft; /// Returns the number of bytes needed as scratch space for [VmpPMatOps::vmp_prepare_contiguous]. /// @@ -154,7 +154,7 @@ pub trait VmpPMatOps { /// * `b`: [VmpPMat] on which the values are encoded. /// * `a`: the contiguous array of [i64] of the 3D matrix to encode on the [VmpPMat]. /// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. - fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], buf: &mut [u8]); + fn vmp_prepare_contiguous(&self, b: &mut MatZnxDft, a: &[i64], buf: &mut [u8]); /// Prepares the ith-row of [VmpPMat] from a [VecZnx]. /// @@ -166,7 +166,7 @@ pub trait VmpPMatOps { /// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. /// /// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. - fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]); + fn vmp_prepare_row(&self, b: &mut MatZnxDft, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]); /// Extracts the ith-row of [VmpPMat] into a [VecZnxBig]. /// @@ -175,7 +175,7 @@ pub trait VmpPMatOps { /// * `b`: the [VecZnxBig] to on which to extract the row of the [VmpPMat]. /// * `a`: [VmpPMat] on which the values are encoded. /// * `row_i`: the index of the row to extract. - fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &VmpPMat, row_i: usize); + fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &MatZnxDft, row_i: usize); /// Prepares the ith-row of [VmpPMat] from a [VecZnxDft]. /// @@ -186,7 +186,7 @@ pub trait VmpPMatOps { /// * `row_i`: the index of the row to prepare. /// /// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. - fn vmp_prepare_row_dft(&self, b: &mut VmpPMat, a: &VecZnxDft, row_i: usize); + fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft, a: &VecZnxDft, row_i: usize); /// Extracts the ith-row of [VmpPMat] into a [VecZnxDft]. /// @@ -195,7 +195,7 @@ pub trait VmpPMatOps { /// * `b`: the [VecZnxDft] to on which to extract the row of the [VmpPMat]. /// * `a`: [VmpPMat] on which the values are encoded. /// * `row_i`: the index of the row to extract. - fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &VmpPMat, row_i: usize); + fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &MatZnxDft, row_i: usize); /// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft]. /// @@ -231,7 +231,7 @@ pub trait VmpPMatOps { /// * `a`: the left operand [VecZnx] of the vector matrix product. /// * `b`: the right operand [VmpPMat] of the vector matrix product. /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes]. - fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]); + fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, buf: &mut [u8]); /// Applies the vector matrix product [VecZnxDft] x [VmpPMat] and adds on the receiver. /// @@ -257,7 +257,7 @@ pub trait VmpPMatOps { /// * `a`: the left operand [VecZnx] of the vector matrix product. /// * `b`: the right operand [VmpPMat] of the vector matrix product. /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes]. - fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]); + fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, buf: &mut [u8]); /// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft_to_dft]. /// @@ -294,7 +294,7 @@ pub trait VmpPMatOps { /// * `a`: the left operand [VecZnxDft] of the vector matrix product. /// * `b`: the right operand [VmpPMat] of the vector matrix product. /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. - fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]); + fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &MatZnxDft, buf: &mut [u8]); /// Applies the vector matrix product [VecZnxDft] x [VmpPMat] and adds on top of the receiver instead of overwritting it. /// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. @@ -321,7 +321,7 @@ pub trait VmpPMatOps { /// * `a`: the left operand [VecZnxDft] of the vector matrix product. /// * `b`: the right operand [VmpPMat] of the vector matrix product. /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. - fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]); + fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &MatZnxDft, buf: &mut [u8]); /// Applies the vector matrix product [VecZnxDft] x [VmpPMat] in place. /// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. @@ -347,15 +347,15 @@ pub trait VmpPMatOps { /// * `b`: the input and output of the vector matrix product, as a [VecZnxDft]. /// * `a`: the right operand [VmpPMat] of the vector matrix product. /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. - fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, buf: &mut [u8]); + fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &MatZnxDft, buf: &mut [u8]); } -impl VmpPMatOps for Module { - fn new_vmp_pmat(&self, rows: usize, cols: usize, limbs: usize) -> VmpPMat { - VmpPMat::::new(self, rows, cols, limbs) +impl MatZnxDftOps for Module { + fn new_mat_znx_dft(&self, rows: usize, cols: usize, limbs: usize) -> MatZnxDft { + MatZnxDft::::new(self, rows, cols, limbs) } - fn bytes_of_vmp_pmat(&self, rows: usize, cols: usize, limbs: usize) -> usize { + fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, limbs: usize) -> usize { unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, (limbs * cols) as u64) as usize } } @@ -363,7 +363,7 @@ impl VmpPMatOps for Module { unsafe { vmp::vmp_prepare_tmp_bytes(self.ptr, rows as u64, (size * cols) as u64) as usize } } - fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], tmp_bytes: &mut [u8]) { + fn vmp_prepare_contiguous(&self, b: &mut MatZnxDft, a: &[i64], tmp_bytes: &mut [u8]) { #[cfg(debug_assertions)] { assert_eq!(a.len(), b.n() * b.poly_count()); @@ -382,7 +382,7 @@ impl VmpPMatOps for Module { } } - fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]) { + fn vmp_prepare_row(&self, b: &mut MatZnxDft, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]) { #[cfg(debug_assertions)] { assert_eq!(a.len(), b.limbs() * self.n() * b.cols()); @@ -402,7 +402,7 @@ impl VmpPMatOps for Module { } } - fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &VmpPMat, row_i: usize) { + fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &MatZnxDft, row_i: usize) { #[cfg(debug_assertions)] { assert_eq!(a.n(), b.n()); @@ -421,7 +421,7 @@ impl VmpPMatOps for Module { } } - fn vmp_prepare_row_dft(&self, b: &mut VmpPMat, a: &VecZnxDft, row_i: usize) { + fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft, a: &VecZnxDft, row_i: usize) { #[cfg(debug_assertions)] { assert_eq!(a.n(), b.n()); @@ -439,7 +439,7 @@ impl VmpPMatOps for Module { } } - fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &VmpPMat, row_i: usize) { + fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &MatZnxDft, row_i: usize) { #[cfg(debug_assertions)] { assert_eq!(a.n(), b.n()); @@ -469,7 +469,7 @@ impl VmpPMatOps for Module { } } - fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, tmp_bytes: &mut [u8]) { + fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, tmp_bytes: &mut [u8]) { debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs())); #[cfg(debug_assertions)] { @@ -491,7 +491,7 @@ impl VmpPMatOps for Module { } } - fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, tmp_bytes: &mut [u8]) { + fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, tmp_bytes: &mut [u8]) { debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs())); #[cfg(debug_assertions)] { @@ -525,7 +525,7 @@ impl VmpPMatOps for Module { } } - fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, tmp_bytes: &mut [u8]) { + fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &MatZnxDft, tmp_bytes: &mut [u8]) { debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs())); #[cfg(debug_assertions)] { @@ -546,7 +546,13 @@ impl VmpPMatOps for Module { } } - fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, tmp_bytes: &mut [u8]) { + fn vmp_apply_dft_to_dft_add( + &self, + c: &mut VecZnxDft, + a: &VecZnxDft, + b: &MatZnxDft, + tmp_bytes: &mut [u8], + ) { debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs())); #[cfg(debug_assertions)] { @@ -567,7 +573,7 @@ impl VmpPMatOps for Module { } } - fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, tmp_bytes: &mut [u8]) { + fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &MatZnxDft, tmp_bytes: &mut [u8]) { debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(b.limbs(), b.limbs(), a.rows(), a.limbs())); #[cfg(debug_assertions)] { @@ -592,8 +598,8 @@ impl VmpPMatOps for Module { #[cfg(test)] mod tests { use crate::{ - FFT64, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxLayout, VecZnxOps, VmpPMat, - VmpPMatOps, alloc_aligned, + FFT64, MatZnxDft, MatZnxDftOps, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, + ZnxLayout, alloc_aligned, }; use sampling::source::Source; @@ -608,8 +614,8 @@ mod tests { let mut a_big: VecZnxBig = module.new_vec_znx_big(1, vpmat_size); let mut b_big: VecZnxBig = module.new_vec_znx_big(1, vpmat_size); let mut b_dft: VecZnxDft = module.new_vec_znx_dft(1, vpmat_size); - let mut vmpmat_0: VmpPMat = module.new_vmp_pmat(vpmat_rows, 1, vpmat_size); - let mut vmpmat_1: VmpPMat = module.new_vmp_pmat(vpmat_rows, 1, vpmat_size); + let mut vmpmat_0: MatZnxDft = module.new_mat_znx_dft(vpmat_rows, 1, vpmat_size); + let mut vmpmat_1: MatZnxDft = module.new_mat_znx_dft(vpmat_rows, 1, vpmat_size); let mut tmp_bytes: Vec = alloc_aligned(module.vmp_prepare_tmp_bytes(vpmat_rows, 1, vpmat_size)); @@ -619,15 +625,15 @@ mod tests { module.vec_znx_dft(&mut a_dft, &a); module.vmp_prepare_row(&mut vmpmat_0, &a.raw(), row_i, &mut tmp_bytes); - // Checks that prepare(vmp_pmat, a) = prepare_dft(vmp_pmat, a_dft) + // Checks that prepare(mat_znx_dft, a) = prepare_dft(mat_znx_dft, a_dft) module.vmp_prepare_row_dft(&mut vmpmat_1, &a_dft, row_i); assert_eq!(vmpmat_0.raw(), vmpmat_1.raw()); - // Checks that a_dft = extract_dft(prepare(vmp_pmat, a), b_dft) + // Checks that a_dft = extract_dft(prepare(mat_znx_dft, a), b_dft) module.vmp_extract_row_dft(&mut b_dft, &vmpmat_0, row_i); assert_eq!(a_dft.raw(), b_dft.raw()); - // Checks that a_big = extract(prepare_dft(vmp_pmat, a_dft), b_big) + // Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big) module.vmp_extract_row(&mut b_big, &vmpmat_0, row_i); module.vec_znx_idft(&mut a_big, &a_dft, &mut tmp_bytes); assert_eq!(a_big.raw(), b_big.raw()); diff --git a/base2k/src/sampling.rs b/base2k/src/sampling.rs index b60e420..c415b80 100644 --- a/base2k/src/sampling.rs +++ b/base2k/src/sampling.rs @@ -1,4 +1,4 @@ -use crate::{Backend, Infos, Module, VecZnx, VecZnxLayout}; +use crate::{Backend, Module, VecZnx, ZnxInfos, ZnxLayout}; use rand_distr::{Distribution, Normal}; use sampling::source::Source; diff --git a/base2k/src/svp.rs b/base2k/src/scalar_znx_dft.rs similarity index 87% rename from base2k/src/svp.rs rename to base2k/src/scalar_znx_dft.rs index ba375c7..7457ca2 100644 --- a/base2k/src/svp.rs +++ b/base2k/src/scalar_znx_dft.rs @@ -2,9 +2,9 @@ use std::marker::PhantomData; use crate::ffi::svp::{self, svp_ppol_t}; use crate::ffi::vec_znx_dft::vec_znx_dft_t; -use crate::{Backend, FFT64, Module, VecZnx, VecZnxDft, VecZnxLayout, assert_alignement}; +use crate::{Backend, FFT64, Module, VecZnx, VecZnxDft, ZnxLayout, assert_alignement}; -use crate::{Infos, alloc_aligned, cast_mut}; +use crate::{ZnxInfos, alloc_aligned, cast_mut}; use rand::seq::SliceRandom; use rand_core::RngCore; use rand_distr::{Distribution, weighted::WeightedIndex}; @@ -148,7 +148,7 @@ impl ScalarOps for Module { } } -pub struct SvpPPol { +pub struct ScalarZnxDft { pub n: usize, pub data: Vec, pub ptr: *mut u8, @@ -157,7 +157,7 @@ pub struct SvpPPol { /// A prepared [crate::Scalar] for [SvpPPolOps::svp_apply_dft]. /// An [SvpPPol] an be seen as a [VecZnxDft] of one limb. -impl SvpPPol { +impl ScalarZnxDft { pub fn new(module: &Module) -> Self { module.new_svp_ppol() } @@ -207,9 +207,9 @@ impl SvpPPol { } } -pub trait SvpPPolOps { +pub trait ScalarZnxDftOps { /// Allocates a new [SvpPPol]. - fn new_svp_ppol(&self) -> SvpPPol; + fn new_svp_ppol(&self) -> ScalarZnxDft; /// Returns the minimum number of bytes necessary to allocate /// a new [SvpPPol] through [SvpPPol::from_bytes] ro. @@ -218,26 +218,26 @@ pub trait SvpPPolOps { /// Allocates a new [SvpPPol] from an array of bytes. /// The array of bytes is owned by the [SvpPPol]. /// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol] - fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> SvpPPol; + fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> ScalarZnxDft; /// Allocates a new [SvpPPol] from an array of bytes. /// The array of bytes is borrowed by the [SvpPPol]. /// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol] - fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> SvpPPol; + fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> ScalarZnxDft; /// Prepares a [crate::Scalar] for a [SvpPPolOps::svp_apply_dft]. - fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar); + fn svp_prepare(&self, svp_ppol: &mut ScalarZnxDft, a: &Scalar); /// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of /// the [VecZnxDft] is multiplied with [SvpPPol]. - fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx); + fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &ScalarZnxDft, b: &VecZnx); } -impl SvpPPolOps for Module { - fn new_svp_ppol(&self) -> SvpPPol { +impl ScalarZnxDftOps for Module { + fn new_svp_ppol(&self) -> ScalarZnxDft { let mut data: Vec = alloc_aligned::(self.bytes_of_svp_ppol()); let ptr: *mut u8 = data.as_mut_ptr(); - SvpPPol:: { + ScalarZnxDft:: { data: data, ptr: ptr, n: self.n(), @@ -249,19 +249,19 @@ impl SvpPPolOps for Module { unsafe { svp::bytes_of_svp_ppol(self.ptr) as usize } } - fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> SvpPPol { - SvpPPol::from_bytes(self, bytes) + fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> ScalarZnxDft { + ScalarZnxDft::from_bytes(self, bytes) } - fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> SvpPPol { - SvpPPol::from_bytes_borrow(self, tmp_bytes) + fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> ScalarZnxDft { + ScalarZnxDft::from_bytes_borrow(self, tmp_bytes) } - fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar) { + fn svp_prepare(&self, svp_ppol: &mut ScalarZnxDft, a: &Scalar) { unsafe { svp::svp_prepare(self.ptr, svp_ppol.ptr as *mut svp_ppol_t, a.as_ptr()) } } - fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx) { + fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &ScalarZnxDft, b: &VecZnx) { unsafe { svp::svp_apply_dft( self.ptr, diff --git a/base2k/src/stats.rs b/base2k/src/stats.rs index f72ebaa..44e441f 100644 --- a/base2k/src/stats.rs +++ b/base2k/src/stats.rs @@ -1,4 +1,4 @@ -use crate::{Encoding, Infos, VecZnx}; +use crate::{Encoding, VecZnx, ZnxInfos}; use rug::Float; use rug::float::Round; use rug::ops::{AddAssignRound, DivAssignRound, SubAssignRound}; diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index 9b47eae..89173f0 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -1,8 +1,9 @@ use crate::Backend; +use crate::ZnxBase; use crate::cast_mut; use crate::ffi::vec_znx; use crate::ffi::znx; -use crate::{Infos, Module, VecZnxLayout}; +use crate::{Module, ZnxInfos, ZnxLayout}; use crate::{alloc_aligned, assert_alignement}; use itertools::izip; use std::cmp::min; @@ -35,7 +36,7 @@ pub struct VecZnx { pub ptr: *mut i64, } -impl Infos for VecZnx { +impl ZnxInfos for VecZnx { fn n(&self) -> usize { self.n } @@ -61,7 +62,7 @@ impl Infos for VecZnx { } } -impl VecZnxLayout for VecZnx { +impl ZnxLayout for VecZnx { type Scalar = i64; fn as_ptr(&self) -> *const Self::Scalar { @@ -84,9 +85,12 @@ pub fn copy_vec_znx_from(b: &mut VecZnx, a: &VecZnx) { data_b[..size].copy_from_slice(&data_a[..size]) } -impl VecZnx { +impl ZnxBase for VecZnx { + type Scalar = i64; + /// Allocates a new [VecZnx] composed of #size polynomials of Z\[X\]. - pub fn new(n: usize, cols: usize, limbs: usize) -> Self { + fn new(module: &Module, cols: usize, limbs: usize) -> Self { + let n: usize = module.n(); #[cfg(debug_assertions)] { assert!(n > 0); @@ -94,7 +98,7 @@ impl VecZnx { assert!(cols > 0); assert!(limbs > 0); } - let mut data: Vec = alloc_aligned::(n * cols * limbs); + let mut data: Vec = alloc_aligned::(Self::bytes_of(module, cols, limbs)); let ptr: *mut i64 = data.as_mut_ptr(); Self { n: n, @@ -105,6 +109,57 @@ impl VecZnx { } } + fn bytes_of(module: &Module, cols: usize, limbs: usize) -> usize { + module.n() * cols * limbs * size_of::() + } + + /// Returns a new struct implementing [VecZnx] with the provided data as backing array. + /// + /// The struct will take ownership of buf[..[Self::bytes_of]] + /// + /// User must ensure that data is properly alligned and that + /// the limbs of data is equal to [Self::bytes_of]. + fn from_bytes(module: &Module, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { + let n: usize = module.n(); + #[cfg(debug_assertions)] + { + assert!(cols > 0); + assert!(limbs > 0); + assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs)); + assert_alignement(bytes.as_ptr()); + } + unsafe { + let bytes_i64: &mut [i64] = cast_mut::(bytes); + let ptr: *mut i64 = bytes_i64.as_mut_ptr(); + Self { + n: n, + cols: cols, + limbs: limbs, + data: Vec::from_raw_parts(ptr, bytes.len(), bytes.len()), + ptr: ptr, + } + } + } + + fn from_bytes_borrow(module: &Module, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { + #[cfg(debug_assertions)] + { + assert!(cols > 0); + assert!(limbs > 0); + assert!(bytes.len() >= Self::bytes_of(module, cols, limbs)); + assert_alignement(bytes.as_ptr()); + } + Self { + n: module.n(), + cols: cols, + limbs: limbs, + data: Vec::new(), + ptr: bytes.as_mut_ptr() as *mut i64, + } + } +} + +impl VecZnx { /// Truncates the precision of the [VecZnx] by k bits. /// /// # Arguments @@ -133,54 +188,6 @@ impl VecZnx { } } - fn bytes_of(n: usize, cols: usize, limbs: usize) -> usize { - n * cols * limbs * size_of::() - } - - /// Returns a new struct implementing [VecZnx] with the provided data as backing array. - /// - /// The struct will take ownership of buf[..[Self::bytes_of]] - /// - /// User must ensure that data is properly alligned and that - /// the limbs of data is equal to [Self::bytes_of]. - pub fn from_bytes(n: usize, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { - #[cfg(debug_assertions)] - { - assert!(cols > 0); - assert!(limbs > 0); - assert_eq!(bytes.len(), Self::bytes_of(n, cols, limbs)); - assert_alignement(bytes.as_ptr()); - } - unsafe { - let bytes_i64: &mut [i64] = cast_mut::(bytes); - let ptr: *mut i64 = bytes_i64.as_mut_ptr(); - Self { - n: n, - cols: cols, - limbs: limbs, - data: Vec::from_raw_parts(ptr, bytes.len(), bytes.len()), - ptr: ptr, - } - } - } - - pub fn from_bytes_borrow(n: usize, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { - #[cfg(debug_assertions)] - { - assert!(cols > 0); - assert!(limbs > 0); - assert!(bytes.len() >= Self::bytes_of(n, cols, limbs)); - assert_alignement(bytes.as_ptr()); - } - Self { - n: n, - cols: cols, - limbs: limbs, - data: Vec::new(), - ptr: bytes.as_mut_ptr() as *mut i64, - } - } - pub fn copy_from(&mut self, a: &Self) { copy_vec_znx_from(self, a); } @@ -394,19 +401,19 @@ pub trait VecZnxOps { impl VecZnxOps for Module { fn new_vec_znx(&self, cols: usize, limbs: usize) -> VecZnx { - VecZnx::new(self.n(), cols, limbs) + VecZnx::new(self, cols, limbs) } fn bytes_of_vec_znx(&self, cols: usize, limbs: usize) -> usize { - VecZnx::bytes_of(self.n(), cols, limbs) + VecZnx::bytes_of(self, cols, limbs) } fn new_vec_znx_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnx { - VecZnx::from_bytes(self.n(), cols, limbs, bytes) + VecZnx::from_bytes(self, cols, limbs, bytes) } fn new_vec_znx_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnx { - VecZnx::from_bytes_borrow(self.n(), cols, limbs, tmp_bytes) + VecZnx::from_bytes_borrow(self, cols, limbs, tmp_bytes) } fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize { diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index ac02aab..7a8cc48 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,5 +1,5 @@ use crate::ffi::vec_znx_big::{self, vec_znx_big_t}; -use crate::{Backend, FFT64, Infos, Module, VecZnx, VecZnxDft, VecZnxLayout, alloc_aligned, assert_alignement}; +use crate::{Backend, FFT64, Module, VecZnx, VecZnxDft, ZnxBase, ZnxInfos, ZnxLayout, alloc_aligned, assert_alignement}; use std::marker::PhantomData; pub struct VecZnxBig { @@ -10,16 +10,17 @@ pub struct VecZnxBig { pub limbs: usize, pub _marker: PhantomData, } +impl ZnxBase for VecZnxBig { + type Scalar = u8; -impl VecZnxBig { - pub fn new(module: &Module, cols: usize, limbs: usize) -> Self { + fn new(module: &Module, cols: usize, limbs: usize) -> Self { #[cfg(debug_assertions)] { assert!(cols > 0); assert!(limbs > 0); } - let mut data: Vec = alloc_aligned::(module.bytes_of_vec_znx_big(cols, limbs)); - let ptr: *mut u8 = data.as_mut_ptr(); + let mut data: Vec = alloc_aligned::(Self::bytes_of(module, cols, limbs)); + let ptr: *mut Self::Scalar = data.as_mut_ptr(); Self { data: data, ptr: ptr, @@ -30,15 +31,19 @@ impl VecZnxBig { } } + fn bytes_of(module: &Module, cols: usize, limbs: usize) -> usize { + unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, limbs as u64) as usize * cols } + } + /// Returns a new [VecZnxBig] with the provided data as backing array. /// User must ensure that data is properly alligned and that /// the size of data is at least equal to [Module::bytes_of_vec_znx_big]. - pub fn from_bytes(module: &Module, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { + fn from_bytes(module: &Module, cols: usize, limbs: usize, bytes: &mut [Self::Scalar]) -> Self { #[cfg(debug_assertions)] { assert!(cols > 0); assert!(limbs > 0); - assert_eq!(bytes.len(), module.bytes_of_vec_znx_big(cols, limbs)); + assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs)); assert_alignement(bytes.as_ptr()) }; unsafe { @@ -53,12 +58,12 @@ impl VecZnxBig { } } - pub fn from_bytes_borrow(module: &Module, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { + fn from_bytes_borrow(module: &Module, cols: usize, limbs: usize, bytes: &mut [Self::Scalar]) -> Self { #[cfg(debug_assertions)] { assert!(cols > 0); assert!(limbs > 0); - assert_eq!(bytes.len(), module.bytes_of_vec_znx_big(cols, limbs)); + assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs)); assert_alignement(bytes.as_ptr()); } Self { @@ -70,24 +75,9 @@ impl VecZnxBig { _marker: PhantomData, } } - - pub fn as_vec_znx_dft(&mut self) -> VecZnxDft { - VecZnxDft:: { - data: Vec::new(), - ptr: self.ptr, - n: self.n, - cols: self.cols, - limbs: self.limbs, - _marker: self._marker, - } - } - - pub fn print(&self, n: usize) { - (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])); - } } -impl Infos for VecZnxBig { +impl ZnxInfos for VecZnxBig { fn log_n(&self) -> usize { (usize::BITS - (self.n - 1).leading_zeros()) as _ } @@ -113,7 +103,7 @@ impl Infos for VecZnxBig { } } -impl VecZnxLayout for VecZnxBig { +impl ZnxLayout for VecZnxBig { type Scalar = i64; fn as_ptr(&self) -> *const Self::Scalar { @@ -125,6 +115,12 @@ impl VecZnxLayout for VecZnxBig { } } +impl VecZnxBig { + pub fn print(&self, n: usize) { + (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])); + } +} + pub trait VecZnxBigOps { /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values. fn new_vec_znx_big(&self, cols: usize, limbs: usize) -> VecZnxBig; @@ -220,7 +216,7 @@ impl VecZnxBigOps for Module { } fn bytes_of_vec_znx_big(&self, cols: usize, limbs: usize) -> usize { - unsafe { vec_znx_big::bytes_of_vec_znx_big(self.ptr, limbs as u64) as usize * cols } + VecZnxBig::bytes_of(self, cols, limbs) } /// [VecZnxBig] (3 cols and 4 limbs) diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index 6d3c6f6..7724710 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -1,7 +1,7 @@ use crate::ffi::vec_znx_big::vec_znx_big_t; use crate::ffi::vec_znx_dft; use crate::ffi::vec_znx_dft::{bytes_of_vec_znx_dft, vec_znx_dft_t}; -use crate::{Backend, FFT64, Infos, Module, VecZnxBig, VecZnxLayout, assert_alignement}; +use crate::{Backend, FFT64, Module, VecZnxBig, ZnxBase, ZnxInfos, ZnxLayout, assert_alignement}; use crate::{VecZnx, alloc_aligned}; use std::marker::PhantomData; @@ -14,15 +14,17 @@ pub struct VecZnxDft { pub _marker: PhantomData, } -impl VecZnxDft { - pub fn new(module: &Module, cols: usize, limbs: usize) -> Self { +impl ZnxBase for VecZnxDft { + type Scalar = u8; + + fn new(module: &Module, cols: usize, limbs: usize) -> Self { #[cfg(debug_assertions)] { assert!(cols > 0); assert!(limbs > 0); } - let mut data: Vec = alloc_aligned::(module.bytes_of_vec_znx_dft(cols, limbs)); - let ptr: *mut u8 = data.as_mut_ptr(); + let mut data: Vec = alloc_aligned(Self::bytes_of(module, cols, limbs)); + let ptr: *mut Self::Scalar = data.as_mut_ptr(); Self { data: data, ptr: ptr, @@ -33,19 +35,19 @@ impl VecZnxDft { } } - fn bytes_of(module: &Module, cols: usize, limbs: usize) -> usize { + fn bytes_of(module: &Module, cols: usize, limbs: usize) -> usize { unsafe { bytes_of_vec_znx_dft(module.ptr, limbs as u64) as usize * cols } } /// Returns a new [VecZnxDft] with the provided data as backing array. /// User must ensure that data is properly alligned and that /// the size of data is at least equal to [Module::bytes_of_vec_znx_dft]. - pub fn from_bytes(module: &Module, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { + fn from_bytes(module: &Module, cols: usize, limbs: usize, bytes: &mut [Self::Scalar]) -> Self { #[cfg(debug_assertions)] { assert!(cols > 0); assert!(limbs > 0); - assert_eq!(bytes.len(), module.bytes_of_vec_znx_dft(cols, limbs)); + assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs)); assert_alignement(bytes.as_ptr()) } unsafe { @@ -60,12 +62,12 @@ impl VecZnxDft { } } - pub fn from_bytes_borrow(module: &Module, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { + fn from_bytes_borrow(module: &Module, cols: usize, limbs: usize, bytes: &mut [Self::Scalar]) -> Self { #[cfg(debug_assertions)] { assert!(cols > 0); assert!(limbs > 0); - assert_eq!(bytes.len(), module.bytes_of_vec_znx_dft(cols, limbs)); + assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs)); assert_alignement(bytes.as_ptr()); } Self { @@ -77,12 +79,14 @@ impl VecZnxDft { _marker: PhantomData, } } +} +impl VecZnxDft { /// Cast a [VecZnxDft] into a [VecZnxBig]. /// The returned [VecZnxBig] shares the backing array /// with the original [VecZnxDft]. - pub fn as_vec_znx_big(&mut self) -> VecZnxBig { - VecZnxBig:: { + pub fn as_vec_znx_big(&mut self) -> VecZnxBig { + VecZnxBig:: { data: Vec::new(), ptr: self.ptr, n: self.n, @@ -91,13 +95,9 @@ impl VecZnxDft { _marker: PhantomData, } } - - pub fn print(&self, n: usize) { - (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])); - } } -impl Infos for VecZnxDft { +impl ZnxInfos for VecZnxDft { fn n(&self) -> usize { self.n } @@ -123,7 +123,7 @@ impl Infos for VecZnxDft { } } -impl VecZnxLayout for VecZnxDft { +impl ZnxLayout for VecZnxDft { type Scalar = f64; fn as_ptr(&self) -> *const Self::Scalar { @@ -135,6 +135,12 @@ impl VecZnxLayout for VecZnxDft { } } +impl VecZnxDft { + pub fn print(&self, n: usize) { + (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])); + } +} + pub trait VecZnxDftOps { /// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space. fn new_vec_znx_dft(&self, cols: usize, limbs: usize) -> VecZnxDft; @@ -314,7 +320,7 @@ impl VecZnxDftOps for Module { #[cfg(test)] mod tests { - use crate::{FFT64, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxLayout, VecZnxOps, alloc_aligned}; + use crate::{FFT64, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, ZnxLayout, alloc_aligned}; use itertools::izip; use sampling::source::Source; diff --git a/rlwe/benches/gadget_product.rs b/rlwe/benches/gadget_product.rs index fdd2240..14bb06d 100644 --- a/rlwe/benches/gadget_product.rs +++ b/rlwe/benches/gadget_product.rs @@ -1,4 +1,4 @@ -use base2k::{BACKEND, Module, Sampling, SvpPPolOps, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, alloc_aligned_u8}; +use base2k::{BACKEND, Module, Sampling, ScalarZnxDftOps, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, MatZnxDft, alloc_aligned_u8}; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; use rlwe::{ ciphertext::{Ciphertext, new_gadget_ciphertext}, @@ -16,7 +16,7 @@ fn bench_gadget_product_inplace(c: &mut Criterion) { res_dft_0: &'a mut VecZnxDft, res_dft_1: &'a mut VecZnxDft, a: &'a VecZnx, - b: &'a Ciphertext, + b: &'a Ciphertext, b_cols: usize, tmp_bytes: &'a mut [u8], ) -> Box { @@ -69,13 +69,13 @@ fn bench_gadget_product_inplace(c: &mut Criterion) { let mut source_xe: Source = Source::new([4; 32]); let mut source_xa: Source = Source::new([5; 32]); - let mut sk0_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol(); + let mut sk0_svp_ppol: base2k::ScalarZnxDft = params.module().new_svp_ppol(); params.module().svp_prepare(&mut sk0_svp_ppol, &sk0.0); - let mut sk1_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol(); + let mut sk1_svp_ppol: base2k::ScalarZnxDft = params.module().new_svp_ppol(); params.module().svp_prepare(&mut sk1_svp_ppol, &sk1.0); - let mut gadget_ct: Ciphertext = new_gadget_ciphertext( + let mut gadget_ct: Ciphertext = new_gadget_ciphertext( params.module(), params.log_base2k(), params.cols_q(), diff --git a/rlwe/examples/encryption.rs b/rlwe/examples/encryption.rs index b9d66cd..20a0603 100644 --- a/rlwe/examples/encryption.rs +++ b/rlwe/examples/encryption.rs @@ -1,4 +1,4 @@ -use base2k::{Encoding, SvpPPolOps, VecZnx, alloc_aligned}; +use base2k::{Encoding, ScalarZnxDftOps, VecZnx, alloc_aligned}; use rlwe::{ ciphertext::Ciphertext, elem::ElemCommon, @@ -51,7 +51,7 @@ fn main() { let mut source_xe: Source = Source::new([1; 32]); let mut source_xa: Source = Source::new([2; 32]); - let mut sk_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol(); + let mut sk_svp_ppol: base2k::ScalarZnxDft = params.module().new_svp_ppol(); params.module().svp_prepare(&mut sk_svp_ppol, &sk.0); params.encrypt_rlwe_sk( diff --git a/rlwe/src/automorphism.rs b/rlwe/src/automorphism.rs index 5e5b48a..d76e356 100644 --- a/rlwe/src/automorphism.rs +++ b/rlwe/src/automorphism.rs @@ -7,15 +7,15 @@ use crate::{ parameters::Parameters, }; use base2k::{ - Module, Scalar, ScalarOps, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, - VmpPMatOps, assert_alignement, + Module, Scalar, ScalarOps, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, MatZnxDft, + MatZnxDftOps, assert_alignement, }; use sampling::source::Source; use std::collections::HashMap; /// Stores DFT([-A*AUTO(s, -p) + 2^{-K*i}*s + E, A]) where AUTO(X, p): X^{i} -> X^{i*p} pub struct AutomorphismKey { - pub value: Ciphertext, + pub value: Ciphertext, pub p: i64, } @@ -106,12 +106,12 @@ impl AutomorphismKey { let (sk_out_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_svp_ppol()); let sk_auto: Scalar = module.new_scalar_from_bytes_borrow(sk_auto_bytes); - let mut sk_out: SvpPPol = module.new_svp_ppol_from_bytes_borrow(sk_out_bytes); + let mut sk_out: ScalarZnxDft = module.new_svp_ppol_from_bytes_borrow(sk_out_bytes); let mut keys: Vec = Vec::new(); p.iter().for_each(|pi| { - let mut value: Ciphertext = new_gadget_ciphertext(module, log_base2k, rows, log_q); + let mut value: Ciphertext = new_gadget_ciphertext(module, log_base2k, rows, log_q); let p_inv: i64 = module.galois_element_inv(*pi); @@ -223,7 +223,7 @@ mod test { parameters::{Parameters, ParametersLiteral}, plaintext::Plaintext, }; - use base2k::{BACKEND, Encoding, Module, SvpPPol, SvpPPolOps, VecZnx, VecZnxOps, alloc_aligned}; + use base2k::{BACKEND, Encoding, Module, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxOps, alloc_aligned}; use sampling::source::{Source, new_seed}; #[test] @@ -267,7 +267,7 @@ mod test { let mut sk: SecretKey = SecretKey::new(module); sk.fill_ternary_hw(params.xs(), &mut source_xs); - let mut sk_svp_ppol: SvpPPol = module.new_svp_ppol(); + let mut sk_svp_ppol: ScalarZnxDft = module.new_svp_ppol(); module.svp_prepare(&mut sk_svp_ppol, &sk.0); let p: i64 = -5; diff --git a/rlwe/src/ciphertext.rs b/rlwe/src/ciphertext.rs index 73addb5..bcffeec 100644 --- a/rlwe/src/ciphertext.rs +++ b/rlwe/src/ciphertext.rs @@ -1,6 +1,6 @@ use crate::elem::{Elem, ElemCommon}; use crate::parameters::Parameters; -use base2k::{Infos, Layout, Module, VecZnx, VmpPMat}; +use base2k::{ZnxInfos, Layout, Module, VecZnx, MatZnxDft}; pub struct Ciphertext(pub Elem); @@ -12,7 +12,7 @@ impl Parameters { impl ElemCommon for Ciphertext where - T: Infos, + T: ZnxInfos, { fn n(&self) -> usize { self.elem().n() @@ -78,16 +78,16 @@ pub fn new_rlwe_ciphertext(module: &Module, log_base2k: usize, log_q: usize) -> Ciphertext::::new(module, log_base2k, log_q, rows) } -pub fn new_gadget_ciphertext(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> Ciphertext { +pub fn new_gadget_ciphertext(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> Ciphertext { let cols: usize = (log_q + log_base2k - 1) / log_base2k; - let mut elem: Elem = Elem::::new(module, log_base2k, 2, rows, cols); + let mut elem: Elem = Elem::::new(module, log_base2k, 2, rows, cols); elem.log_q = log_q; Ciphertext(elem) } -pub fn new_rgsw_ciphertext(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> Ciphertext { +pub fn new_rgsw_ciphertext(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> Ciphertext { let cols: usize = (log_q + log_base2k - 1) / log_base2k; - let mut elem: Elem = Elem::::new(module, log_base2k, 4, rows, cols); + let mut elem: Elem = Elem::::new(module, log_base2k, 4, rows, cols); elem.log_q = log_q; Ciphertext(elem) } diff --git a/rlwe/src/decryptor.rs b/rlwe/src/decryptor.rs index 6eeea27..4c1fb7e 100644 --- a/rlwe/src/decryptor.rs +++ b/rlwe/src/decryptor.rs @@ -5,16 +5,16 @@ use crate::{ parameters::Parameters, plaintext::Plaintext, }; -use base2k::{Module, SvpPPol, SvpPPolOps, VecZnx, VecZnxBigOps, VecZnxDft, VecZnxDftOps}; +use base2k::{Module, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBigOps, VecZnxDft, VecZnxDftOps}; use std::cmp::min; pub struct Decryptor { - sk: SvpPPol, + sk: ScalarZnxDft, } impl Decryptor { pub fn new(params: &Parameters, sk: &SecretKey) -> Self { - let mut sk_svp_ppol: SvpPPol = params.module().new_svp_ppol(); + let mut sk_svp_ppol: ScalarZnxDft = params.module().new_svp_ppol(); sk.prepare(params.module(), &mut sk_svp_ppol); Self { sk: sk_svp_ppol } } @@ -32,12 +32,12 @@ impl Parameters { ) } - pub fn decrypt_rlwe(&self, res: &mut Plaintext, ct: &Ciphertext, sk: &SvpPPol, tmp_bytes: &mut [u8]) { + pub fn decrypt_rlwe(&self, res: &mut Plaintext, ct: &Ciphertext, sk: &ScalarZnxDft, tmp_bytes: &mut [u8]) { decrypt_rlwe(self.module(), &mut res.0, &ct.0, sk, tmp_bytes) } } -pub fn decrypt_rlwe(module: &Module, res: &mut Elem, a: &Elem, sk: &SvpPPol, tmp_bytes: &mut [u8]) { +pub fn decrypt_rlwe(module: &Module, res: &mut Elem, a: &Elem, sk: &ScalarZnxDft, tmp_bytes: &mut [u8]) { let cols: usize = a.cols(); assert!( diff --git a/rlwe/src/elem.rs b/rlwe/src/elem.rs index 656cc3a..c6fe59f 100644 --- a/rlwe/src/elem.rs +++ b/rlwe/src/elem.rs @@ -1,4 +1,4 @@ -use base2k::{Infos, Layout, Module, VecZnx, VecZnxOps, VmpPMat, VmpPMatOps}; +use base2k::{ZnxInfos, Layout, Module, VecZnx, VecZnxOps, MatZnxDft, MatZnxDftOps}; pub struct Elem { pub value: Vec, @@ -81,7 +81,7 @@ pub trait ElemCommon { fn at_mut(&mut self, i: usize) -> &mut T; } -impl ElemCommon for Elem { +impl ElemCommon for Elem { fn n(&self) -> usize { self.value[0].n() } @@ -152,11 +152,11 @@ impl Elem { } } -impl Elem { +impl Elem { pub fn new(module: &Module, log_base2k: usize, size: usize, rows: usize, cols: usize) -> Self { assert!(rows > 0); assert!(cols > 0); - let mut value: Vec = Vec::new(); + let mut value: Vec = Vec::new(); (0..size).for_each(|_| value.push(module.new_vmp_pmat(1, rows, cols))); Self { value: value, diff --git a/rlwe/src/encryptor.rs b/rlwe/src/encryptor.rs index bdb383c..7354a0f 100644 --- a/rlwe/src/encryptor.rs +++ b/rlwe/src/encryptor.rs @@ -5,8 +5,8 @@ use crate::parameters::Parameters; use crate::plaintext::Plaintext; use base2k::sampling::Sampling; use base2k::{ - Infos, Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, - VmpPMatOps, + ZnxInfos, Module, Scalar, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, MatZnxDft, + MatZnxDftOps, }; use sampling::source::{Source, new_seed}; @@ -19,7 +19,7 @@ impl Parameters { &self, ct: &mut Ciphertext, pt: Option<&Plaintext>, - sk: &SvpPPol, + sk: &ScalarZnxDft, source_xa: &mut Source, source_xe: &mut Source, tmp_bytes: &mut [u8], @@ -38,7 +38,7 @@ impl Parameters { } pub struct EncryptorSk { - sk: SvpPPol, + sk: ScalarZnxDft, source_xa: Source, source_xe: Source, initialized: bool, @@ -47,7 +47,7 @@ pub struct EncryptorSk { impl EncryptorSk { pub fn new(params: &Parameters, sk: Option<&SecretKey>) -> Self { - let mut sk_svp_ppol: SvpPPol = params.module().new_svp_ppol(); + let mut sk_svp_ppol: ScalarZnxDft = params.module().new_svp_ppol(); let mut initialized: bool = false; if let Some(sk) = sk { sk.prepare(params.module(), &mut sk_svp_ppol); @@ -114,7 +114,7 @@ pub fn encrypt_rlwe_sk( module: &Module, ct: &mut Elem, pt: Option<&VecZnx>, - sk: &SvpPPol, + sk: &ScalarZnxDft, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, @@ -127,7 +127,7 @@ fn encrypt_rlwe_sk_core( module: &Module, ct: &mut Elem, pt: Option<&VecZnx>, - sk: &SvpPPol, + sk: &ScalarZnxDft, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, @@ -217,9 +217,9 @@ pub fn encrypt_grlwe_sk_tmp_bytes(module: &Module, log_base2k: usize, rows: usiz pub fn encrypt_grlwe_sk( module: &Module, - ct: &mut Ciphertext, + ct: &mut Ciphertext, m: &Scalar, - sk: &SvpPPol, + sk: &ScalarZnxDft, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, @@ -258,9 +258,9 @@ pub fn encrypt_rgsw_sk_tmp_bytes(module: &Module, log_base2k: usize, rows: usize pub fn encrypt_rgsw_sk( module: &Module, - ct: &mut Ciphertext, + ct: &mut Ciphertext, m: &Scalar, - sk: &SvpPPol, + sk: &ScalarZnxDft, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, @@ -302,10 +302,10 @@ pub fn encrypt_rgsw_sk( fn encrypt_grlwe_sk_core( module: &Module, log_base2k: usize, - mut ct: [&mut VmpPMat; 2], + mut ct: [&mut MatZnxDft; 2], log_q: usize, m: &Scalar, - sk: &SvpPPol, + sk: &ScalarZnxDft, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, diff --git a/rlwe/src/gadget_product.rs b/rlwe/src/gadget_product.rs index bbf9642..9315cd8 100644 --- a/rlwe/src/gadget_product.rs +++ b/rlwe/src/gadget_product.rs @@ -1,5 +1,5 @@ use crate::{ciphertext::Ciphertext, elem::ElemCommon, parameters::Parameters}; -use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps}; +use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, MatZnxDft, MatZnxDftOps}; use std::cmp::min; pub fn gadget_product_core_tmp_bytes( @@ -34,7 +34,7 @@ pub fn gadget_product_core( res_dft_0: &mut VecZnxDft, res_dft_1: &mut VecZnxDft, a: &VecZnx, - b: &Ciphertext, + b: &Ciphertext, b_cols: usize, tmp_bytes: &mut [u8], ) { @@ -61,7 +61,7 @@ pub fn gadget_product_big( module: &Module, c: &mut Ciphertext, a: &Ciphertext, - b: &Ciphertext, + b: &Ciphertext, tmp_bytes: &mut [u8], ) { let cols: usize = min(c.cols(), a.cols()); @@ -94,7 +94,7 @@ pub fn gadget_product( module: &Module, c: &mut Ciphertext, a: &Ciphertext, - b: &Ciphertext, + b: &Ciphertext, tmp_bytes: &mut [u8], ) { let cols: usize = min(c.cols(), a.cols()); @@ -130,7 +130,7 @@ mod test { plaintext::Plaintext, }; use base2k::{ - BACKEND, Infos, Sampling, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, + BACKEND, ZnxInfos, Sampling, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, MatZnxDft, alloc_aligned_u8, }; use sampling::source::{Source, new_seed}; @@ -175,16 +175,16 @@ mod test { // Two secret keys let mut sk0: SecretKey = SecretKey::new(params.module()); sk0.fill_ternary_hw(params.xs(), &mut source_xs); - let mut sk0_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol(); + let mut sk0_svp_ppol: base2k::ScalarZnxDft = params.module().new_svp_ppol(); params.module().svp_prepare(&mut sk0_svp_ppol, &sk0.0); let mut sk1: SecretKey = SecretKey::new(params.module()); sk1.fill_ternary_hw(params.xs(), &mut source_xs); - let mut sk1_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol(); + let mut sk1_svp_ppol: base2k::ScalarZnxDft = params.module().new_svp_ppol(); params.module().svp_prepare(&mut sk1_svp_ppol, &sk1.0); // The gadget ciphertext - let mut gadget_ct: Ciphertext = new_gadget_ciphertext( + let mut gadget_ct: Ciphertext = new_gadget_ciphertext( params.module(), log_base2k, params.cols_qp(), diff --git a/rlwe/src/key_generator.rs b/rlwe/src/key_generator.rs index 4f62a2c..88a2331 100644 --- a/rlwe/src/key_generator.rs +++ b/rlwe/src/key_generator.rs @@ -1,7 +1,7 @@ use crate::encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes}; use crate::keys::{PublicKey, SecretKey, SwitchingKey}; use crate::parameters::Parameters; -use base2k::{Module, SvpPPol}; +use base2k::{Module, ScalarZnxDft}; use sampling::source::Source; pub struct KeyGenerator {} @@ -16,7 +16,7 @@ impl KeyGenerator { pub fn gen_public_key_thread_safe( &self, params: &Parameters, - sk_ppol: &SvpPPol, + sk_ppol: &ScalarZnxDft, source: &mut Source, tmp_bytes: &mut [u8], ) -> PublicKey { @@ -43,7 +43,7 @@ pub fn gen_switching_key( module: &Module, swk: &mut SwitchingKey, sk_in: &SecretKey, - sk_out: &SvpPPol, + sk_out: &ScalarZnxDft, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, diff --git a/rlwe/src/key_switching.rs b/rlwe/src/key_switching.rs index 4e0001a..e73c7f9 100644 --- a/rlwe/src/key_switching.rs +++ b/rlwe/src/key_switching.rs @@ -1,6 +1,6 @@ use crate::ciphertext::Ciphertext; use crate::elem::ElemCommon; -use base2k::{Module, VecZnx, VecZnxBigOps, VecZnxDftOps, VmpPMat, VmpPMatOps, assert_alignement}; +use base2k::{Module, VecZnx, VecZnxBigOps, VecZnxDftOps, MatZnxDft, MatZnxDftOps, assert_alignement}; use std::cmp::min; pub fn key_switch_tmp_bytes(module: &Module, log_base2k: usize, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize { @@ -16,7 +16,7 @@ pub fn key_switch_rlwe( module: &Module, c: &mut Ciphertext, a: &Ciphertext, - b: &Ciphertext, + b: &Ciphertext, b_cols: usize, tmp_bytes: &mut [u8], ) { @@ -26,7 +26,7 @@ pub fn key_switch_rlwe( pub fn key_switch_rlwe_inplace( module: &Module, a: &mut Ciphertext, - b: &Ciphertext, + b: &Ciphertext, b_cols: usize, tmp_bytes: &mut [u8], ) { @@ -37,7 +37,7 @@ fn key_switch_rlwe_core( module: &Module, c: *mut Ciphertext, a: *const Ciphertext, - b: &Ciphertext, + b: &Ciphertext, b_cols: usize, tmp_bytes: &mut [u8], ) { @@ -74,6 +74,6 @@ fn key_switch_rlwe_core( module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(1), &mut res_big, tmp_bytes); } -pub fn key_switch_grlwe(module: &Module, c: &mut Ciphertext, a: &Ciphertext, b: &Ciphertext) {} +pub fn key_switch_grlwe(module: &Module, c: &mut Ciphertext, a: &Ciphertext, b: &Ciphertext) {} -pub fn key_switch_rgsw(module: &Module, c: &mut Ciphertext, a: &Ciphertext, b: &Ciphertext) {} +pub fn key_switch_rgsw(module: &Module, c: &mut Ciphertext, a: &Ciphertext, b: &Ciphertext) {} diff --git a/rlwe/src/keys.rs b/rlwe/src/keys.rs index da7c412..6017159 100644 --- a/rlwe/src/keys.rs +++ b/rlwe/src/keys.rs @@ -1,7 +1,7 @@ use crate::ciphertext::{Ciphertext, new_gadget_ciphertext}; use crate::elem::{Elem, ElemCommon}; use crate::encryptor::{encrypt_rlwe_sk, encrypt_rlwe_sk_tmp_bytes}; -use base2k::{Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VmpPMat}; +use base2k::{Module, Scalar, ScalarZnxDft, ScalarZnxDftOps, VecZnx, MatZnxDft}; use sampling::source::Source; pub struct SecretKey(pub Scalar); @@ -19,7 +19,7 @@ impl SecretKey { self.0.fill_ternary_hw(hw, source); } - pub fn prepare(&self, module: &Module, sk_ppol: &mut SvpPPol) { + pub fn prepare(&self, module: &Module, sk_ppol: &mut ScalarZnxDft) { module.svp_prepare(sk_ppol, &self.0) } } @@ -34,7 +34,7 @@ impl PublicKey { pub fn gen_thread_safe( &mut self, module: &Module, - sk: &SvpPPol, + sk: &ScalarZnxDft, xe: f64, xa_source: &mut Source, xe_source: &mut Source, @@ -57,7 +57,7 @@ impl PublicKey { } } -pub struct SwitchingKey(pub Ciphertext); +pub struct SwitchingKey(pub Ciphertext); impl SwitchingKey { pub fn new(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> SwitchingKey { diff --git a/rlwe/src/rgsw_product.rs b/rlwe/src/rgsw_product.rs index dc42602..1f76166 100644 --- a/rlwe/src/rgsw_product.rs +++ b/rlwe/src/rgsw_product.rs @@ -1,5 +1,5 @@ use crate::{ciphertext::Ciphertext, elem::ElemCommon, parameters::Parameters}; -use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, assert_alignement}; +use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, MatZnxDft, MatZnxDftOps, assert_alignement}; use std::cmp::min; impl Parameters { @@ -26,7 +26,7 @@ pub fn rgsw_product( module: &Module, c: &mut Ciphertext, a: &Ciphertext, - b: &Ciphertext, + b: &Ciphertext, b_cols: usize, tmp_bytes: &mut [u8], ) { @@ -69,7 +69,7 @@ pub fn rgsw_product( pub fn rgsw_product_inplace( module: &Module, a: &mut Ciphertext, - b: &Ciphertext, + b: &Ciphertext, b_cols: usize, tmp_bytes: &mut [u8], ) { @@ -120,7 +120,7 @@ mod test { plaintext::Plaintext, rgsw_product::rgsw_product_inplace, }; - use base2k::{BACKEND, Encoding, Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxOps, VmpPMat, alloc_aligned}; + use base2k::{BACKEND, Encoding, Module, Scalar, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxOps, MatZnxDft, alloc_aligned}; use sampling::source::{Source, new_seed}; #[test] @@ -164,10 +164,10 @@ mod test { let mut sk: SecretKey = SecretKey::new(module); sk.fill_ternary_hw(params.xs(), &mut source_xs); - let mut sk_svp_ppol: SvpPPol = module.new_svp_ppol(); + let mut sk_svp_ppol: ScalarZnxDft = module.new_svp_ppol(); module.svp_prepare(&mut sk_svp_ppol, &sk.0); - let mut ct_rgsw: Ciphertext = new_rgsw_ciphertext(module, log_base2k, gct_rows, log_qp); + let mut ct_rgsw: Ciphertext = new_rgsw_ciphertext(module, log_base2k, gct_rows, log_qp); let k: i64 = 3; diff --git a/rlwe/src/trace.rs b/rlwe/src/trace.rs index 9e7feb8..005c497 100644 --- a/rlwe/src/trace.rs +++ b/rlwe/src/trace.rs @@ -1,5 +1,5 @@ use crate::{automorphism::AutomorphismKey, ciphertext::Ciphertext, elem::ElemCommon, parameters::Parameters}; -use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMatOps, assert_alignement}; +use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, MatZnxDftOps, assert_alignement}; use std::collections::HashMap; pub fn trace_galois_elements(module: &Module) -> Vec { @@ -115,7 +115,7 @@ mod test { parameters::{DEFAULT_SIGMA, Parameters, ParametersLiteral}, plaintext::Plaintext, }; - use base2k::{BACKEND, Encoding, Module, SvpPPol, SvpPPolOps, VecZnx, alloc_aligned}; + use base2k::{BACKEND, Encoding, Module, ScalarZnxDft, ScalarZnxDftOps, VecZnx, alloc_aligned}; use sampling::source::{Source, new_seed}; use std::collections::HashMap; @@ -160,7 +160,7 @@ mod test { let mut sk: SecretKey = SecretKey::new(module); sk.fill_ternary_hw(params.xs(), &mut source_xs); - let mut sk_svp_ppol: SvpPPol = module.new_svp_ppol(); + let mut sk_svp_ppol: ScalarZnxDft = module.new_svp_ppol(); module.svp_prepare(&mut sk_svp_ppol, &sk.0); let gal_els: Vec = trace_galois_elements(module);