diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index 5385a5b..cb9dfa8 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -1,6 +1,6 @@ use base2k::{ Encoding, FFT64, Infos, Module, Sampling, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, - VecZnxDftOps, VecZnxOps, alloc_aligned, + VecZnxDftOps, VecZnxLayout, VecZnxOps, alloc_aligned, }; use itertools::izip; use sampling::source::Source; diff --git a/base2k/examples/vector_matrix_product.rs b/base2k/examples/vector_matrix_product.rs index 4e8b97e..8d4a33d 100644 --- a/base2k/examples/vector_matrix_product.rs +++ b/base2k/examples/vector_matrix_product.rs @@ -1,6 +1,6 @@ use base2k::{ - Encoding, FFT64, Infos, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, - alloc_aligned, + Encoding, FFT64, Infos, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxLayout, VecZnxOps, VmpPMat, + VmpPMatOps, alloc_aligned, }; fn main() { diff --git a/base2k/src/commons.rs b/base2k/src/commons.rs new file mode 100644 index 0000000..ef7a410 --- /dev/null +++ b/base2k/src/commons.rs @@ -0,0 +1,70 @@ +pub trait Infos { + /// Returns the ring degree of the polynomials. + fn n(&self) -> usize; + + /// Returns the base two logarithm of the ring dimension of the polynomials. + fn log_n(&self) -> usize; + + /// Returns the number of rows. + fn rows(&self) -> usize; + + /// Returns the number of polynomials in each row. + fn cols(&self) -> usize; + + /// Returns the number of limbs per polynomial. + fn limbs(&self) -> usize; + + /// Returns the total number of small polynomials. + fn poly_count(&self) -> usize; +} + +pub trait VecZnxLayout: Infos { + type Scalar; + + fn as_ptr(&self) -> *const Self::Scalar; + fn as_mut_ptr(&mut self) -> *mut Self::Scalar; + + fn raw(&self) -> &[Self::Scalar] { + unsafe { std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) } + } + + fn raw_mut(&mut self) -> &mut [Self::Scalar] { + unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) } + } + + fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar { + #[cfg(debug_assertions)] + { + assert!(i < self.cols()); + assert!(j < self.limbs()); + } + let offset = self.n() * (j * self.cols() + i); + unsafe { self.as_ptr().add(offset) } + } + + fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar { + #[cfg(debug_assertions)] + { + assert!(i < self.cols()); + assert!(j < self.limbs()); + } + let offset = self.n() * (j * self.cols() + i); + unsafe { self.as_mut_ptr().add(offset) } + } + + fn at_poly(&self, i: usize, j: usize) -> &[Self::Scalar] { + unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) } + } + + fn at_poly_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] { + unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) } + } + + fn at_limb(&self, j: usize) -> &[Self::Scalar] { + unsafe { std::slice::from_raw_parts(self.at_ptr(0, j), self.n() * self.cols()) } + } + + fn at_limb_mut(&mut self, j: usize) -> &mut [Self::Scalar] { + unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(0, j), self.n() * self.cols()) } + } +} diff --git a/base2k/src/encoding.rs b/base2k/src/encoding.rs index d4085cb..5944f3c 100644 --- a/base2k/src/encoding.rs +++ b/base2k/src/encoding.rs @@ -1,5 +1,5 @@ use crate::ffi::znx::znx_zero_i64_ref; -use crate::{Infos, VecZnx}; +use crate::{Infos, VecZnx, VecZnxLayout}; use itertools::izip; use rug::{Assign, Float}; use std::cmp::min; @@ -262,7 +262,7 @@ fn decode_coeff_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i #[cfg(test)] mod tests { - use crate::{Encoding, Infos, VecZnx}; + use crate::{Encoding, Infos, VecZnx, VecZnxLayout}; use itertools::izip; use sampling::source::Source; diff --git a/base2k/src/infos.rs b/base2k/src/infos.rs deleted file mode 100644 index 764a7fe..0000000 --- a/base2k/src/infos.rs +++ /dev/null @@ -1,19 +0,0 @@ -pub trait Infos { - /// Returns the ring degree of the polynomials. - fn n(&self) -> usize; - - /// Returns the base two logarithm of the ring dimension of the polynomials. - fn log_n(&self) -> usize; - - /// Returns the number of rows. - fn rows(&self) -> usize; - - /// Returns the number of polynomials in each row. - fn cols(&self) -> usize; - - /// Returns the number of limbs per polynomial. - fn limbs(&self) -> usize; - - /// Returns the total number of small polynomials. - fn poly_count(&self) -> usize; -} diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index 83c937a..4d54ca0 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -1,8 +1,8 @@ +pub mod commons; pub mod encoding; #[allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)] // Other modules and exports pub mod ffi; -pub mod infos; pub mod module; pub mod sampling; pub mod stats; @@ -12,8 +12,8 @@ pub mod vec_znx_big; pub mod vec_znx_dft; pub mod vmp; +pub use commons::*; pub use encoding::*; -pub use infos::*; pub use module::*; pub use sampling::*; #[allow(unused_imports)] diff --git a/base2k/src/sampling.rs b/base2k/src/sampling.rs index db9a79b..b60e420 100644 --- a/base2k/src/sampling.rs +++ b/base2k/src/sampling.rs @@ -1,4 +1,4 @@ -use crate::{Backend, Infos, Module, VecZnx}; +use crate::{Backend, Infos, Module, VecZnx, VecZnxLayout}; use rand_distr::{Distribution, Normal}; use sampling::source::Source; diff --git a/base2k/src/svp.rs b/base2k/src/svp.rs index e293668..ba375c7 100644 --- a/base2k/src/svp.rs +++ b/base2k/src/svp.rs @@ -2,7 +2,7 @@ use std::marker::PhantomData; use crate::ffi::svp::{self, svp_ppol_t}; use crate::ffi::vec_znx_dft::vec_znx_dft_t; -use crate::{Backend, FFT64, Module, VecZnx, VecZnxDft, assert_alignement}; +use crate::{Backend, FFT64, Module, VecZnx, VecZnxDft, VecZnxLayout, assert_alignement}; use crate::{Infos, alloc_aligned, cast_mut}; use rand::seq::SliceRandom; diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index a6d5858..9b47eae 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -2,7 +2,7 @@ use crate::Backend; use crate::cast_mut; use crate::ffi::vec_znx; use crate::ffi::znx; -use crate::{Infos, Module}; +use crate::{Infos, Module, VecZnxLayout}; use crate::{alloc_aligned, assert_alignement}; use itertools::izip; use std::cmp::min; @@ -35,157 +35,6 @@ pub struct VecZnx { pub ptr: *mut i64, } -pub fn bytes_of_vec_znx(n: usize, cols: usize, limbs: usize) -> usize { - n * cols * limbs * size_of::() -} - -impl VecZnx { - /// Returns a new struct implementing [VecZnx] with the provided data as backing array. - /// - /// The struct will take ownership of buf[..[Self::bytes_of]] - /// - /// User must ensure that data is properly alligned and that - /// the limbs of data is equal to [Self::bytes_of]. - pub fn from_bytes(n: usize, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { - #[cfg(debug_assertions)] - { - assert!(cols > 0); - assert!(limbs > 0); - assert_eq!(bytes.len(), Self::bytes_of(n, cols, limbs)); - assert_alignement(bytes.as_ptr()); - } - unsafe { - let bytes_i64: &mut [i64] = cast_mut::(bytes); - let ptr: *mut i64 = bytes_i64.as_mut_ptr(); - Self { - n: n, - cols: cols, - limbs: limbs, - data: Vec::from_raw_parts(ptr, bytes.len(), bytes.len()), - ptr: ptr, - } - } - } - - pub fn from_bytes_borrow(n: usize, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { - #[cfg(debug_assertions)] - { - assert!(cols > 0); - assert!(limbs > 0); - assert!(bytes.len() >= Self::bytes_of(n, cols, limbs)); - assert_alignement(bytes.as_ptr()); - } - Self { - n: n, - cols: cols, - limbs: limbs, - data: Vec::new(), - ptr: bytes.as_mut_ptr() as *mut i64, - } - } - - pub fn bytes_of(n: usize, cols: usize, limbs: usize) -> usize { - bytes_of_vec_znx(n, cols, limbs) - } - - pub fn copy_from(&mut self, a: &Self) { - copy_vec_znx_from(self, a); - } - - pub fn borrowing(&self) -> bool { - self.data.len() == 0 - } - - /// Total limbs is [Self::n()] * [Self::poly_count()]. - pub fn raw(&self) -> &[i64] { - unsafe { std::slice::from_raw_parts(self.ptr, self.n * self.poly_count()) } - } - - /// Returns a reference to backend slice of the receiver. - /// Total size is [Self::n()] * [Self::poly_count()]. - pub fn raw_mut(&mut self) -> &mut [i64] { - unsafe { std::slice::from_raw_parts_mut(self.ptr, self.n * self.poly_count()) } - } - - /// Returns a non-mutable pointer to the backedn slice of the receiver. - pub fn as_ptr(&self) -> *const i64 { - self.ptr - } - - /// Returns a mutable pointer to the backedn slice of the receiver. - pub fn as_mut_ptr(&mut self) -> *mut i64 { - self.ptr - } - - /// Returns a non-mutable pointer starting a the (i, j)-th small poly. - pub fn at_ptr(&self, i: usize, j: usize) -> *const i64 { - #[cfg(debug_assertions)] - { - assert!(i < self.cols()); - assert!(j < self.limbs()); - } - let offset: usize = self.n * (j * self.cols() + i); - self.ptr.wrapping_add(offset) - } - - /// Returns a non-mutable reference to the i-th limb. - /// The returned array is of size [Self::n()] * [Self::cols()]. - pub fn at_limb(&self, i: usize) -> &[i64] { - unsafe { std::slice::from_raw_parts(self.at_ptr(0, i), self.n * self.cols()) } - } - - /// Returns a non-mutable reference to the (i, j)-th poly. - /// The returned array is of size [Self::n()]. - pub fn at_poly(&self, i: usize, j: usize) -> &[i64] { - unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n) } - } - - /// Returns a mutable pointer starting a the (i, j)-th small poly. - pub fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut i64 { - #[cfg(debug_assertions)] - { - assert!(i < self.cols()); - assert!(j < self.limbs()); - } - - let offset: usize = self.n * (j * self.cols() + i); - self.ptr.wrapping_add(offset) - } - - /// Returns a mutable reference to the i-th limb. - /// The returned array is of size [Self::n()] * [Self::cols()]. - pub fn at_limb_mut(&mut self, i: usize) -> &mut [i64] { - unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(0, i), self.n * self.cols()) } - } - - /// Returns a mutable reference to the (i, j)-th poly. - /// The returned array is of size [Self::n()]. - pub fn at_poly_mut(&mut self, i: usize, j: usize) -> &mut [i64] { - unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n) } - } - - pub fn zero(&mut self) { - unsafe { znx::znx_zero_i64_ref((self.n * self.poly_count()) as u64, self.ptr) } - } - - pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) { - normalize(log_base2k, self, carry) - } - - pub fn rsh(&mut self, log_base2k: usize, k: usize, carry: &mut [u8]) { - rsh(log_base2k, self, k, carry) - } - - pub fn switch_degree(&self, a: &mut Self) { - switch_degree(a, self) - } - - // Prints the first `n` coefficients of each limb - pub fn print(&self, n: usize) { - (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])) - } -} - impl Infos for VecZnx { fn n(&self) -> usize { self.n @@ -212,6 +61,18 @@ impl Infos for VecZnx { } } +impl VecZnxLayout for VecZnx { + type Scalar = i64; + + fn as_ptr(&self) -> *const Self::Scalar { + self.ptr + } + + fn as_mut_ptr(&mut self) -> *mut Self::Scalar { + self.ptr + } +} + /// Copies the coefficients of `a` on the receiver. /// Copy is done with the minimum size matching both backing arrays. /// Panics if the cols do not match. @@ -271,6 +132,83 @@ impl VecZnx { .for_each(|x: &mut i64| *x &= mask) } } + + fn bytes_of(n: usize, cols: usize, limbs: usize) -> usize { + n * cols * limbs * size_of::() + } + + /// Returns a new struct implementing [VecZnx] with the provided data as backing array. + /// + /// The struct will take ownership of buf[..[Self::bytes_of]] + /// + /// User must ensure that data is properly alligned and that + /// the limbs of data is equal to [Self::bytes_of]. + pub fn from_bytes(n: usize, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { + #[cfg(debug_assertions)] + { + assert!(cols > 0); + assert!(limbs > 0); + assert_eq!(bytes.len(), Self::bytes_of(n, cols, limbs)); + assert_alignement(bytes.as_ptr()); + } + unsafe { + let bytes_i64: &mut [i64] = cast_mut::(bytes); + let ptr: *mut i64 = bytes_i64.as_mut_ptr(); + Self { + n: n, + cols: cols, + limbs: limbs, + data: Vec::from_raw_parts(ptr, bytes.len(), bytes.len()), + ptr: ptr, + } + } + } + + pub fn from_bytes_borrow(n: usize, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { + #[cfg(debug_assertions)] + { + assert!(cols > 0); + assert!(limbs > 0); + assert!(bytes.len() >= Self::bytes_of(n, cols, limbs)); + assert_alignement(bytes.as_ptr()); + } + Self { + n: n, + cols: cols, + limbs: limbs, + data: Vec::new(), + ptr: bytes.as_mut_ptr() as *mut i64, + } + } + + pub fn copy_from(&mut self, a: &Self) { + copy_vec_znx_from(self, a); + } + + pub fn borrowing(&self) -> bool { + self.data.len() == 0 + } + + pub fn zero(&mut self) { + unsafe { znx::znx_zero_i64_ref((self.n * self.poly_count()) as u64, self.ptr) } + } + + pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) { + normalize(log_base2k, self, carry) + } + + pub fn rsh(&mut self, log_base2k: usize, k: usize, carry: &mut [u8]) { + rsh(log_base2k, self, k, carry) + } + + pub fn switch_degree(&self, a: &mut Self) { + switch_degree(a, self) + } + + // Prints the first `n` coefficients of each limb + pub fn print(&self, n: usize) { + (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])) + } } pub fn switch_degree(b: &mut VecZnx, a: &VecZnx) { @@ -395,6 +333,9 @@ pub trait VecZnxOps { /// * `limbs`: the number of limbs per polynomial (a.k.a small polynomials). fn new_vec_znx(&self, cols: usize, limbs: usize) -> VecZnx; + fn new_vec_znx_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnx; + fn new_vec_znx_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnx; + /// Returns the minimum number of bytes necessary to allocate /// a new [VecZnx] through [VecZnx::from_bytes]. fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize; @@ -457,7 +398,15 @@ impl VecZnxOps for Module { } fn bytes_of_vec_znx(&self, cols: usize, limbs: usize) -> usize { - bytes_of_vec_znx(self.n(), cols, limbs) + VecZnx::bytes_of(self.n(), cols, limbs) + } + + fn new_vec_znx_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnx { + VecZnx::from_bytes(self.n(), cols, limbs, bytes) + } + + fn new_vec_znx_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnx { + VecZnx::from_bytes_borrow(self.n(), cols, limbs, tmp_bytes) } fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize { diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index e1f656f..ac02aab 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,5 +1,5 @@ use crate::ffi::vec_znx_big::{self, vec_znx_big_t}; -use crate::{Backend, FFT64, Infos, Module, VecZnx, VecZnxDft, alloc_aligned, assert_alignement}; +use crate::{Backend, FFT64, Infos, Module, VecZnx, VecZnxDft, VecZnxLayout, alloc_aligned, assert_alignement}; use std::marker::PhantomData; pub struct VecZnxBig { @@ -12,7 +12,6 @@ pub struct VecZnxBig { } impl VecZnxBig { - pub fn new(module: &Module, cols: usize, limbs: usize) -> Self { #[cfg(debug_assertions)] { @@ -83,72 +82,6 @@ impl VecZnxBig { } } - /// Returns a non-mutable pointer to the backedn slice of the receiver. - pub fn as_ptr(&self) -> *const i64 { - self.ptr as *const i64 - } - - /// Returns a mutable pointer to the backedn slice of the receiver. - pub fn as_mut_ptr(&mut self) -> *mut i64 { - self.ptr as *mut i64 - } - - pub fn raw(&self) -> &[i64] { - unsafe { &std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) } - } - - pub fn raw_mut(&mut self) -> &mut [i64] { - let ptr: *mut i64 = self.ptr as *mut i64; - let size: usize = self.n() * self.poly_count(); - unsafe { std::slice::from_raw_parts_mut(ptr, size) } - } - - pub fn at_ptr(&self, i: usize, j: usize) -> *const i64 { - #[cfg(debug_assertions)] - { - assert!(i < self.cols()); - assert!(j < self.limbs()); - } - let offset: usize = self.n * (j * self.cols() + i); - self.as_ptr().wrapping_add(offset) - } - - /// Returns a non-mutable reference to the i-th limb. - /// The returned array is of size [Self::n()] * [Self::cols()]. - pub fn at_limb(&self, i: usize) -> &[i64] { - unsafe { std::slice::from_raw_parts(self.at_ptr(0, i), self.n * self.cols()) } - } - - /// Returns a non-mutable reference to the (i, j)-th poly. - /// The returned array is of size [Self::n()]. - pub fn at_poly(&self, i: usize, j: usize) -> &[i64] { - unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n) } - } - - /// Returns a mutable pointer starting a the (i, j)-th small poly. - pub fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut i64 { - #[cfg(debug_assertions)] - { - assert!(i < self.cols()); - assert!(j < self.limbs()); - } - - let offset: usize = self.n * (j * self.cols() + i); - self.as_mut_ptr().wrapping_add(offset) - } - - /// Returns a mutable reference to the i-th limb. - /// The returned array is of size [Self::n()] * [Self::cols()]. - pub fn at_limb_mut(&mut self, i: usize) -> &mut [i64] { - unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(0, i), self.n * self.cols()) } - } - - /// Returns a mutable reference to the (i, j)-th poly. - /// The returned array is of size [Self::n()]. - pub fn at_poly_mut(&mut self, i: usize, j: usize) -> &mut [i64] { - unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n) } - } - pub fn print(&self, n: usize) { (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])); } @@ -180,6 +113,18 @@ impl Infos for VecZnxBig { } } +impl VecZnxLayout for VecZnxBig { + type Scalar = i64; + + fn as_ptr(&self) -> *const Self::Scalar { + self.ptr as *const Self::Scalar + } + + fn as_mut_ptr(&mut self) -> *mut Self::Scalar { + self.ptr as *mut Self::Scalar + } +} + pub trait VecZnxBigOps { /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values. fn new_vec_znx_big(&self, cols: usize, limbs: usize) -> VecZnxBig; diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index d984cdd..6d3c6f6 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -1,8 +1,8 @@ use crate::ffi::vec_znx_big::vec_znx_big_t; use crate::ffi::vec_znx_dft; use crate::ffi::vec_znx_dft::{bytes_of_vec_znx_dft, vec_znx_dft_t}; -use crate::{Backend, FFT64, Infos, Module, VecZnxBig, assert_alignement}; -use crate::{DEFAULTALIGN, VecZnx, alloc_aligned}; +use crate::{Backend, FFT64, Infos, Module, VecZnxBig, VecZnxLayout, assert_alignement}; +use crate::{VecZnx, alloc_aligned}; use std::marker::PhantomData; pub struct VecZnxDft { @@ -32,6 +32,11 @@ impl VecZnxDft { _marker: PhantomData, } } + + fn bytes_of(module: &Module, cols: usize, limbs: usize) -> usize { + unsafe { bytes_of_vec_znx_dft(module.ptr, limbs as u64) as usize * cols } + } + /// Returns a new [VecZnxDft] with the provided data as backing array. /// User must ensure that data is properly alligned and that /// the size of data is at least equal to [Module::bytes_of_vec_znx_dft]. @@ -87,72 +92,6 @@ impl VecZnxDft { } } - /// Returns a non-mutable pointer to the backedn slice of the receiver. - pub fn as_ptr(&self) -> *const f64 { - self.ptr as *const f64 - } - - /// Returns a mutable pointer to the backedn slice of the receiver. - pub fn as_mut_ptr(&mut self) -> *mut f64 { - self.ptr as *mut f64 - } - - pub fn raw(&self) -> &[f64] { - unsafe { &std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) } - } - - pub fn raw_mut(&mut self) -> &mut [f64] { - let ptr: *mut f64 = self.ptr as *mut f64; - let size: usize = self.n() * self.poly_count(); - unsafe { std::slice::from_raw_parts_mut(ptr, size) } - } - - pub fn at_ptr(&self, i: usize, j: usize) -> *const f64 { - #[cfg(debug_assertions)] - { - assert!(i < self.cols()); - assert!(j < self.limbs()); - } - let offset: usize = self.n * (j * self.cols() + i); - self.as_ptr().wrapping_add(offset) - } - - /// Returns a non-mutable reference to the i-th limb. - /// The returned array is of size [Self::n()] * [Self::cols()]. - pub fn at_limb(&self, i: usize) -> &[f64] { - unsafe { std::slice::from_raw_parts(self.at_ptr(0, i), self.n * self.cols()) } - } - - /// Returns a non-mutable reference to the (i, j)-th poly. - /// The returned array is of size [Self::n()]. - pub fn at_poly(&self, i: usize, j: usize) -> &[f64] { - unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n) } - } - - /// Returns a mutable pointer starting a the (i, j)-th small poly. - pub fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut f64 { - #[cfg(debug_assertions)] - { - assert!(i < self.cols()); - assert!(j < self.limbs()); - } - - let offset: usize = self.n * (j * self.cols() + i); - self.as_mut_ptr().wrapping_add(offset) - } - - /// Returns a mutable reference to the i-th limb. - /// The returned array is of size [Self::n()] * [Self::cols()]. - pub fn at_limb_mut(&mut self, i: usize) -> &mut [f64] { - unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(0, i), self.n * self.cols()) } - } - - /// Returns a mutable reference to the (i, j)-th poly. - /// The returned array is of size [Self::n()]. - pub fn at_poly_mut(&mut self, i: usize, j: usize) -> &mut [f64] { - unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n) } - } - pub fn print(&self, n: usize) { (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])); } @@ -184,6 +123,18 @@ impl Infos for VecZnxDft { } } +impl VecZnxLayout for VecZnxDft { + type Scalar = f64; + + fn as_ptr(&self) -> *const Self::Scalar { + self.ptr as *const Self::Scalar + } + + fn as_mut_ptr(&mut self) -> *mut Self::Scalar { + self.ptr as *mut Self::Scalar + } +} + pub trait VecZnxDftOps { /// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space. fn new_vec_znx_dft(&self, cols: usize, limbs: usize) -> VecZnxDft; @@ -257,7 +208,7 @@ impl VecZnxDftOps for Module { } fn bytes_of_vec_znx_dft(&self, cols: usize, limbs: usize) -> usize { - unsafe { bytes_of_vec_znx_dft(self.ptr, limbs as u64) as usize * cols } + VecZnxDft::bytes_of(&self, cols, limbs) } fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft) { @@ -363,7 +314,7 @@ impl VecZnxDftOps for Module { #[cfg(test)] mod tests { - use crate::{FFT64, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, alloc_aligned}; + use crate::{FFT64, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxLayout, VecZnxOps, alloc_aligned}; use itertools::izip; use sampling::source::Source; diff --git a/base2k/src/vmp.rs b/base2k/src/vmp.rs index f868a06..f2af561 100644 --- a/base2k/src/vmp.rs +++ b/base2k/src/vmp.rs @@ -1,7 +1,7 @@ use crate::ffi::vec_znx_big::vec_znx_big_t; use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::ffi::vmp::{self, vmp_pmat_t}; -use crate::{Backend, FFT64, Infos, Module, VecZnx, VecZnxBig, VecZnxDft, alloc_aligned, assert_alignement}; +use crate::{Backend, FFT64, Infos, Module, VecZnx, VecZnxBig, VecZnxDft, VecZnxLayout, alloc_aligned, assert_alignement}; use std::marker::PhantomData; /// Vector Matrix Product Prepared Matrix: a vector of [VecZnx], @@ -592,8 +592,8 @@ impl VmpPMatOps for Module { #[cfg(test)] mod tests { use crate::{ - FFT64, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, - alloc_aligned, + FFT64, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxLayout, VecZnxOps, VmpPMat, + VmpPMatOps, alloc_aligned, }; use sampling::source::Source;