diff --git a/base2k/src/ffi/module.rs b/base2k/src/ffi/module.rs index 755d613..e35d4c0 100644 --- a/base2k/src/ffi/module.rs +++ b/base2k/src/ffi/module.rs @@ -3,8 +3,6 @@ pub struct module_info_t { } pub type module_type_t = ::std::os::raw::c_uint; -pub const module_type_t_FFT64: module_type_t = 0; -pub const module_type_t_NTT120: module_type_t = 1; pub use self::module_type_t as MODULE_TYPE; pub type MODULE = module_info_t; diff --git a/base2k/src/infos.rs b/base2k/src/infos.rs index 08472d9..2445022 100644 --- a/base2k/src/infos.rs +++ b/base2k/src/infos.rs @@ -1,22 +1,43 @@ -use crate::LAYOUT; +#[derive(Copy, Clone)] +#[repr(C)] +pub struct LAYOUT { + /// Ring degree. + n: usize, + /// Number of logical rows in the layout. + rows: usize, + /// Number of polynomials per row. + cols: usize, + /// Number of limbs per polynomial. + size: usize, + /// Whether limbs are interleaved across rows. + /// + /// For example, for (rows, cols, size) = (2, 2, 2): + /// + /// - `true`: layout is ((a0, b0, a1, b1), (c0, d0, c1, d1)) + /// - `false`: layout is ((a0, a1, b0, b1), (c0, c1, d0, d1)) + interleaved: bool, +} pub trait Infos { - /// Returns the ring degree of the receiver. - fn n(&self) -> usize; - /// Returns the base two logarithm of the ring dimension of the receiver. - fn log_n(&self) -> usize; - - /// Returns the number of stacked polynomials. - fn size(&self) -> usize; - - /// Returns the memory layout of the stacked polynomials. + /// Returns the full layout. fn layout(&self) -> LAYOUT; - /// Returns the number of columns of the receiver. - /// This method is equivalent to [Infos::cols]. + /// Returns the ring degree of the polynomials. + fn n(&self) -> usize; + + /// Returns the base two logarithm of the ring dimension of the polynomials. + fn log_n(&self) -> usize; + + /// Returns the number of rows. + fn rows(&self) -> usize; + + /// Returns the number of polynomials in each row. fn cols(&self) -> usize; - /// Returns the number of rows of the receiver.
- fn rows(&self) -> usize; + /// Returns the number of limbs per polynomial. + fn size(&self) -> usize; + + /// Whether limbs are interleaved across rows. + fn interleaved(&self) -> bool; } diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index 7e97b00..ec0d2b7 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -27,13 +27,6 @@ pub use vmp::*; pub const GALOISGENERATOR: u64 = 5; pub const DEFAULTALIGN: usize = 64; -#[derive(Copy, Clone)] -#[repr(u8)] -pub enum LAYOUT { - ROW, - COL, -} - pub fn is_aligned_custom(ptr: *const T, align: usize) -> bool { (ptr as usize) % align == 0 } diff --git a/base2k/src/svp.rs b/base2k/src/svp.rs index 0e85a31..bc37f86 100644 --- a/base2k/src/svp.rs +++ b/base2k/src/svp.rs @@ -119,7 +119,7 @@ impl Scalar { n: self.n, size: 1, // TODO REVIEW IF NEED TO ADD size TO SCALAR cols: 1, - layout: LAYOUT::COL, + layout: LAYOUT::COL(1, 1), data: Vec::new(), ptr: self.ptr, } diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index 7445b5b..71a315e 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -22,15 +22,12 @@ pub struct VecZnx { /// Polynomial degree. pub n: usize, - /// Stack size + /// Number of limbs pub size: usize, - /// Stacking layout + /// Layout pub layout: LAYOUT, - /// Number of columns. - pub cols: usize, - /// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n. pub data: Vec, @@ -38,8 +35,8 @@ pub struct VecZnx { pub ptr: *mut i64, } -pub fn bytes_of_vec_znx(n: usize, size: usize, cols: usize) -> usize { - n * size * cols * 8 +pub fn bytes_of_vec_znx(n: usize, layout: LAYOUT, size: usize) -> usize { + n * layout.size() * size * 8 } impl VecZnx { @@ -49,11 +46,11 @@ impl VecZnx { /// /// User must ensure that data is properly alligned and that /// the size of data is equal to [VecZnx::bytes_of]. 
- pub fn from_bytes(n: usize, size: usize, cols: usize, bytes: &mut [u8]) -> Self { + pub fn from_bytes(n: usize, layout: LAYOUT, size: usize, bytes: &mut [u8]) -> Self { #[cfg(debug_assertions)] { assert!(size > 0); - assert_eq!(bytes.len(), Self::bytes_of(n, size, cols)); + assert_eq!(bytes.len(), Self::bytes_of(n, layout, size)); assert_alignement(bytes.as_ptr()); } unsafe { @@ -62,33 +59,31 @@ impl VecZnx { VecZnx { n: n, size: size, - cols: cols, - layout: LAYOUT::COL, + layout: layout, data: Vec::from_raw_parts(ptr, bytes.len(), bytes.len()), ptr: ptr, } } } - pub fn from_bytes_borrow(n: usize, size: usize, cols: usize, bytes: &mut [u8]) -> Self { + pub fn from_bytes_borrow(n: usize, layout: LAYOUT, size: usize, bytes: &mut [u8]) -> Self { #[cfg(debug_assertions)] { assert!(size > 0); - assert!(bytes.len() >= Self::bytes_of(n, size, cols)); + assert!(bytes.len() >= Self::bytes_of(n, layout, size)); assert_alignement(bytes.as_ptr()); } VecZnx { n: n, size: size, - cols: cols, - layout: LAYOUT::COL, + layout: layout, data: Vec::new(), ptr: bytes.as_mut_ptr() as *mut i64, } } - pub fn bytes_of(n: usize, size: usize, cols: usize) -> usize { - bytes_of_vec_znx(n, size, cols) + pub fn bytes_of(n: usize, layout: LAYOUT, size: usize) -> usize { + bytes_of_vec_znx(n, layout, size) } pub fn copy_from(&mut self, a: &VecZnx) { @@ -99,15 +94,15 @@ impl VecZnx { self.data.len() == 0 } - /// Total size is [VecZnx::n()] * [VecZnx::size()] * [VecZnx::cols()]. + /// Total size is [VecZnx::n()] * [VecZnx::size()] * [VecZnx::size()]. pub fn raw(&self) -> &[i64] { - unsafe { std::slice::from_raw_parts(self.ptr, self.n * self.size * self.cols) } + unsafe { std::slice::from_raw_parts(self.ptr, self.n * self.size * self.size) } } /// Returns a reference to backend slice of the receiver. - /// Total size is [VecZnx::n()] * [VecZnx::size()] * [VecZnx::cols()]. + /// Total size is [VecZnx::n()] * [VecZnx::size()] * [VecZnx::size()]. 
pub fn raw_mut(&mut self) -> &mut [i64] { - unsafe { std::slice::from_raw_parts_mut(self.ptr, self.n * self.size * self.cols) } + unsafe { std::slice::from_raw_parts_mut(self.ptr, self.n * self.size * self.size) } } /// Returns a non-mutable pointer to the backedn slice of the receiver. @@ -124,7 +119,7 @@ impl VecZnx { pub fn at_ptr(&self, i: usize) -> *const i64 { #[cfg(debug_assertions)] { - assert!(i < self.cols); + assert!(i < self.size); } let offset: usize = self.n * self.size * i; self.ptr.wrapping_add(offset) @@ -141,7 +136,7 @@ impl VecZnx { #[cfg(debug_assertions)] { assert!(i < self.size); - assert!(j < self.cols); + assert!(j < self.size); } let offset: usize = self.n * (self.size * j + i); self.ptr.wrapping_add(offset) @@ -157,7 +152,7 @@ impl VecZnx { pub fn at_mut_ptr(&self, i: usize) -> *mut i64 { #[cfg(debug_assertions)] { - assert!(i < self.cols); + assert!(i < self.size); } let offset: usize = self.n * self.size * i; self.ptr.wrapping_add(offset) @@ -174,7 +169,7 @@ impl VecZnx { #[cfg(debug_assertions)] { assert!(i < self.size); - assert!(j < self.cols); + assert!(j < self.size); } let offset: usize = self.n * (self.size * j + i); @@ -189,7 +184,7 @@ impl VecZnx { } pub fn zero(&mut self) { - unsafe { znx::znx_zero_i64_ref((self.n * self.cols * self.size) as u64, self.ptr) } + unsafe { znx::znx_zero_i64_ref((self.n * self.size * self.size) as u64, self.ptr) } } pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) { @@ -204,8 +199,8 @@ impl VecZnx { switch_degree(a, self) } - pub fn print(&self, poly: usize, cols: usize, n: usize) { - (0..cols).for_each(|i| println!("{}: {:?}", i, &self.at_poly(poly, i)[..n])) + pub fn print(&self, poly: usize, size: usize, n: usize) { + (0..size).for_each(|i| println!("{}: {:?}", i, &self.at_poly(poly, i)[..n])) } } @@ -228,9 +223,9 @@ impl Infos for VecZnx { self.layout } - /// Returns the number of cols of the [VecZnx]. 
- fn cols(&self) -> usize { - self.cols + /// Returns the number of size of the [VecZnx]. + fn size(&self) -> usize { + self.size } /// Returns the number of rows of the [VecZnx]. @@ -249,22 +244,22 @@ pub fn copy_vec_znx_from(b: &mut VecZnx, a: &VecZnx) { } impl VecZnx { - /// Allocates a new [VecZnx] composed of #cols polynomials of Z\[X\]. - pub fn new(n: usize, size: usize, cols: usize) -> Self { + /// Allocates a new [VecZnx] composed of #size polynomials of Z\[X\]. + pub fn new(n: usize, size: usize, size: usize) -> Self { #[cfg(debug_assertions)] { assert!(n > 0); assert!(n & (n - 1) == 0); assert!(size > 0); - assert!(cols > 0); + assert!(size <= u8::MAX as usize); + assert!(size > 0); } - let mut data: Vec = alloc_aligned::(n * size * cols); + let mut data: Vec = alloc_aligned::(n * size * size); let ptr: *mut i64 = data.as_mut_ptr(); Self { n: n, + layout: LAYOUT::COL(1, size as u8), size: size, - layout: LAYOUT::COL, - cols: cols, data: data, ptr: ptr, } @@ -283,16 +278,16 @@ impl VecZnx { if !self.borrowing() { self.data - .truncate((self.cols() - k / log_base2k) * self.n() * self.size()); + .truncate((self.size() - k / log_base2k) * self.n() * self.size()); } - self.cols -= k / log_base2k; + self.size -= k / log_base2k; let k_rem: usize = k % log_base2k; if k_rem != 0 { let mask: i64 = ((1 << (log_base2k - k_rem - 1)) - 1) << k_rem; - self.at_mut(self.cols() - 1) + self.at_mut(self.size() - 1) .iter_mut() .for_each(|x: &mut i64| *x &= mask) } @@ -310,9 +305,9 @@ pub fn switch_degree(b: &mut VecZnx, a: &VecZnx) { b.zero(); } - let cols = min(a.cols(), b.cols()); + let size = min(a.size(), b.size()); - (0..cols).for_each(|i| { + (0..size).for_each(|i| { izip!( a.at(i).iter().step_by(gap_in), b.at_mut(i).iter_mut().step_by(gap_out) @@ -345,7 +340,7 @@ fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) { unsafe { znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr()); - (0..a.cols()).rev().for_each(|i| { + 
(0..a.size()).rev().for_each(|i| { znx::znx_normalize( (n * size) as u64, log_base2k as u64, @@ -378,12 +373,12 @@ pub fn rsh(log_base2k: usize, a: &mut VecZnx, k: usize, tmp_bytes: &mut [u8]) { assert_alignement(tmp_bytes.as_ptr()); } - let cols: usize = a.cols(); - let cols_steps: usize = k / log_base2k; + let size: usize = a.size(); + let size_steps: usize = k / log_base2k; - a.raw_mut().rotate_right(n * size * cols_steps); + a.raw_mut().rotate_right(n * size * size_steps); unsafe { - znx::znx_zero_i64_ref((n * size * cols_steps) as u64, a.as_mut_ptr()); + znx::znx_zero_i64_ref((n * size * size_steps) as u64, a.as_mut_ptr()); } let k_rem = k % log_base2k; @@ -397,7 +392,7 @@ pub fn rsh(log_base2k: usize, a: &mut VecZnx, k: usize, tmp_bytes: &mut [u8]) { let log_base2k: usize = log_base2k; - (cols_steps..cols).for_each(|i| { + (size_steps..size).for_each(|i| { izip!(carry_i64.iter_mut(), a.at_mut(i).iter_mut()).for_each(|(ci, xi)| { *xi += *ci << log_base2k; *ci = get_base_k_carry(*xi, k_rem); @@ -417,12 +412,12 @@ pub trait VecZnxOps { /// /// # Arguments /// - /// * `cols`: the number of cols. - fn new_vec_znx(&self, size: usize, cols: usize) -> VecZnx; + /// * `size`: the number of size. + fn new_vec_znx(&self, size: usize, size: usize) -> VecZnx; /// Returns the minimum number of bytes necessary to allocate /// a new [VecZnx] through [VecZnx::from_bytes]. 
- fn bytes_of_vec_znx(&self, size: usize, cols: usize) -> usize; + fn bytes_of_vec_znx(&self, size: usize, size: usize) -> usize; fn vec_znx_normalize_tmp_bytes(&self, size: usize) -> usize; @@ -477,12 +472,12 @@ pub trait VecZnxOps { } impl VecZnxOps for Module { - fn new_vec_znx(&self, size: usize, cols: usize) -> VecZnx { - VecZnx::new(self.n(), size, cols) + fn new_vec_znx(&self, size: usize, size: usize) -> VecZnx { + VecZnx::new(self.n(), size, size) } - fn bytes_of_vec_znx(&self, size: usize, cols: usize) -> usize { - bytes_of_vec_znx(self.n(), size, cols) + fn bytes_of_vec_znx(&self, size: usize, size: usize) -> usize { + bytes_of_vec_znx(self.n(), size, size) } fn vec_znx_normalize_tmp_bytes(&self, size: usize) -> usize { @@ -502,13 +497,13 @@ impl VecZnxOps for Module { vec_znx::vec_znx_add( self.ptr, c.as_mut_ptr(), - c.cols() as u64, + c.size() as u64, (n * c.size()) as u64, a.as_ptr(), - a.cols() as u64, + a.size() as u64, (n * a.size()) as u64, b.as_ptr(), - b.cols() as u64, + b.size() as u64, (n * b.size()) as u64, ) } @@ -526,13 +521,13 @@ impl VecZnxOps for Module { vec_znx::vec_znx_add( self.ptr, b.as_mut_ptr(), - b.cols() as u64, + b.size() as u64, (n * b.size()) as u64, a.as_ptr(), - a.cols() as u64, + a.size() as u64, (n * a.size()) as u64, b.as_ptr(), - b.cols() as u64, + b.size() as u64, (n * b.size()) as u64, ) } @@ -551,13 +546,13 @@ impl VecZnxOps for Module { vec_znx::vec_znx_sub( self.ptr, c.as_mut_ptr(), - c.cols() as u64, + c.size() as u64, (n * c.size()) as u64, a.as_ptr(), - a.cols() as u64, + a.size() as u64, (n * a.size()) as u64, b.as_ptr(), - b.cols() as u64, + b.size() as u64, (n * b.size()) as u64, ) } @@ -575,13 +570,13 @@ impl VecZnxOps for Module { vec_znx::vec_znx_sub( self.ptr, b.as_mut_ptr(), - b.cols() as u64, + b.size() as u64, (n * b.size()) as u64, a.as_ptr(), - a.cols() as u64, + a.size() as u64, (n * a.size()) as u64, b.as_ptr(), - b.cols() as u64, + b.size() as u64, (n * b.size()) as u64, ) } @@ -599,13 +594,13 @@ 
impl VecZnxOps for Module { vec_znx::vec_znx_sub( self.ptr, b.as_mut_ptr(), - b.cols() as u64, + b.size() as u64, (n * b.size()) as u64, b.as_ptr(), - b.cols() as u64, + b.size() as u64, (n * b.size()) as u64, a.as_ptr(), - a.cols() as u64, + a.size() as u64, (n * a.size()) as u64, ) } @@ -622,10 +617,10 @@ impl VecZnxOps for Module { vec_znx::vec_znx_negate( self.ptr, b.as_mut_ptr(), - b.cols() as u64, + b.size() as u64, (n * b.size()) as u64, a.as_ptr(), - a.cols() as u64, + a.size() as u64, (n * a.size()) as u64, ) } @@ -641,10 +636,10 @@ impl VecZnxOps for Module { vec_znx::vec_znx_negate( self.ptr, a.as_mut_ptr(), - a.cols() as u64, + a.size() as u64, (n * a.size()) as u64, a.as_ptr(), - a.cols() as u64, + a.size() as u64, (n * a.size()) as u64, ) } @@ -662,10 +657,10 @@ impl VecZnxOps for Module { self.ptr, k, b.as_mut_ptr(), - b.cols() as u64, + b.size() as u64, (n * b.size()) as u64, a.as_ptr(), - a.cols() as u64, + a.size() as u64, (n * a.size()) as u64, ) } @@ -682,27 +677,27 @@ impl VecZnxOps for Module { self.ptr, k, a.as_mut_ptr(), - a.cols() as u64, + a.size() as u64, (n * a.size()) as u64, a.as_ptr(), - a.cols() as u64, + a.size() as u64, (n * a.size()) as u64, ) } } - /// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each cols. + /// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each size. /// /// # Arguments /// /// * `a`: input. /// * `b`: output. /// * `k`: the power to which to map each coefficients. - /// * `a_cols`: the number of a_cols on which to apply the mapping. + /// * `a_size`: the number of a_size on which to apply the mapping. /// /// # Panics /// - /// The method will panic if the argument `a` is greater than `a.cols()`. + /// The method will panic if the argument `a` is greater than `a.size()`. 
fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx) { let n: usize = self.n(); #[cfg(debug_assertions)] @@ -715,26 +710,26 @@ impl VecZnxOps for Module { self.ptr, k, b.as_mut_ptr(), - b.cols() as u64, + b.size() as u64, (n * b.size()) as u64, a.as_ptr(), - a.cols() as u64, + a.size() as u64, (n * a.size()) as u64, ); } } - /// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each cols. + /// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each size. /// /// # Arguments /// /// * `a`: input and output. /// * `k`: the power to which to map each coefficients. - /// * `a_cols`: the number of cols on which to apply the mapping. + /// * `a_size`: the number of size on which to apply the mapping. /// /// # Panics /// - /// The method will panic if the argument `cols` is greater than `self.cols()`. + /// The method will panic if the argument `size` is greater than `self.size()`. fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx) { let n: usize = self.n(); #[cfg(debug_assertions)] @@ -746,10 +741,10 @@ impl VecZnxOps for Module { self.ptr, k, a.as_mut_ptr(), - a.cols() as u64, + a.size() as u64, (n * a.size()) as u64, a.as_ptr(), - a.cols() as u64, + a.size() as u64, (n * a.size()) as u64, ); } diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index 8b31ea6..b512fd8 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -113,10 +113,6 @@ impl Infos for VecZnxDft { self.n } - fn size(&self) -> usize { - self.size - } - fn layout(&self) -> LAYOUT { self.layout } diff --git a/base2k/src/vmp.rs b/base2k/src/vmp.rs index 7d6c26f..b04232d 100644 --- a/base2k/src/vmp.rs +++ b/base2k/src/vmp.rs @@ -5,27 +5,25 @@ use crate::{BACKEND, Infos, LAYOUT, Module, VecZnx, VecZnxBig, VecZnxDft, alloc_ /// Vector Matrix Product Prepared Matrix: a vector of [VecZnx], /// stored as a 3D matrix in the DFT domain in a single contiguous array. 
-/// Each row of the [VmpPMat] can be seen as a [VecZnxDft]. +/// Each col of the [VmpPMat] can be seen as a collection of [VecZnxDft]. /// -/// The backend array of [VmpPMat] is allocate in C, -/// and thus must be manually freed. -/// -/// [VmpPMat] is used to permform a vector matrix product between a [VecZnx] and a [VmpPMat]. +/// [VmpPMat] is used to perform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [VmpPMat]. /// See the trait [VmpPMatOps] for additional information. pub struct VmpPMat { /// Raw data, is empty if borrowing scratch space. data: Vec, /// Pointer to data. Can point to scratch space. ptr: *mut u8, - /// The number of [VecZnxDft]. + /// The size of the decomposition basis (i.e. nb. [VecZnxDft]). rows: usize, - /// The number of cols in each [VecZnxDft]. + /// The size of each [VecZnxDft]. cols: usize, /// The ring degree of each [VecZnxDft]. n: usize, - /// The number of stacked [VmpPMat], must be a square. - size: usize, - /// The memory layout of the stacked [VmpPMat]. + /// 1st dim: the number of stacked [VecZnxDft] per decomposition basis (row-dimension). + /// A value greater than one makes it possible to compute a sum of [VecZnx] x [VmpPMat]. + /// 2nd dim: the number of stacked [VecZnxDft] (col-dimension). + /// A value greater than one makes it possible to compute multiple [VecZnx] x [VmpPMat] in parallel. layout: LAYOUT, /// The backend fft or ntt.
backend: BACKEND, @@ -531,6 +529,7 @@ impl VmpPMatOps for Module { #[cfg(debug_assertions)] { assert_alignement(tmp_bytes.as_ptr()); + assert_eq!(a.size()*a.size(), b.size()); } unsafe { vmp::vmp_apply_dft( @@ -539,7 +538,7 @@ impl VmpPMatOps for Module { c.cols() as u64, a.as_ptr(), a.cols() as u64, - a.n() as u64, + (a.n()*a.size()) as u64, b.as_ptr() as *const vmp_pmat_t, b.rows() as u64, b.cols() as u64, @@ -561,7 +560,7 @@ impl VmpPMatOps for Module { c.cols() as u64, a.as_ptr(), a.cols() as u64, - a.n() as u64, + (a.n()*a.size()) as u64, b.as_ptr() as *const vmp_pmat_t, b.rows() as u64, b.cols() as u64,