diff --git a/base2k/src/vector_matrix_product.rs b/base2k/src/vector_matrix_product.rs index 0136511..dc098bf 100644 --- a/base2k/src/vector_matrix_product.rs +++ b/base2k/src/vector_matrix_product.rs @@ -6,76 +6,90 @@ use crate::ffi::vmp::{ use crate::{Module, VecZnx, VecZnxDft}; use std::cmp::min; +/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx], +/// stored as a 3D matrix in the DFT domain in a single contiguous array. pub struct VmpPMat { + /// The pointer to the C memory. pub data: *mut vmp_pmat_t, + /// The number of [VecZnx]. pub rows: usize, + /// The number of limbs in each [VecZnx]. pub cols: usize, + /// The ring degree of each [VecZnx]. pub n: usize, } impl VmpPMat { + + /// Returns the pointer to the [vmp_pmat_t]. pub fn data(&self) -> *mut vmp_pmat_t { self.data } + /// Returns the number of rows of the [VmpPMat]. + /// The number of rows (i.e. of [VecZnx]) of the [VmpPMat]. pub fn rows(&self) -> usize { self.rows } + /// Returns the number of cols of the [VmpPMat]. + /// The number of cols refers to the number of limbs + /// of the prepared [VecZnx]. pub fn cols(&self) -> usize { self.cols } + /// Returns the ring dimension of the [VmpPMat]. pub fn n(&self) -> usize { self.n } - pub fn as_f64(&self) -> &[f64] { - let ptr: *const f64 = self.data as *const f64; - let len: usize = (self.rows() * self.cols() * self.n() * 8) / std::mem::size_of::(); - unsafe { &std::slice::from_raw_parts(ptr, len) } - } - - pub fn get_addr(&self, row: usize, col: usize, blk: usize) -> &[f64] { - let nrows: usize = self.rows(); - let ncols: usize = self.cols(); - if col == (ncols - 1) && (ncols & 1 == 1) { - &self.as_f64()[blk * nrows * ncols * 8 + col * nrows * 8 + row * 8..] - } else { - &self.as_f64()[blk * nrows * ncols * 8 - + (col / 2) * (2 * nrows) * 8 - + row * 2 * 8 - + (col % 2) * 8..] 
- } - } - - pub fn at(&self, row: usize, col: usize) -> Vec { - //assert!(row <= self.rows && col <= self.cols); - - let mut res: Vec = vec![f64::default(); self.n]; + /// Returns a copy of the backend array at index (i, j) of the [VmpPMat]. + /// When using FFT64 as backend, T should be f64. + /// When using NTT120 as backend, T should be i64. + pub fn at(&self, row: usize, col: usize) -> Vec { + let mut res: Vec = vec![T::default(); self.n]; if self.n < 8 { res.copy_from_slice( - &self.as_f64()[(row + col * self.rows()) * self.n() + &self.get_backend_array::()[(row + col * self.rows()) * self.n() ..(row + col * self.rows()) * (self.n() + 1)], ); } else { (0..self.n >> 3).for_each(|blk| { - res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.get_addr(row, col, blk)[..8]); + res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.get_array(row, col, blk)[..8]); }); } res } - pub fn at_mut(&self, row: usize, col: usize) -> &mut [f64] { - assert!(row <= self.rows && col <= self.cols); - let idx: usize = col * (self.n / 2 * self.rows) + row * (self.n >> 1); - let ptr: *mut f64 = self.data as *mut f64; - let len: usize = (self.rows() * self.cols() * self.n() * 8) / std::mem::size_of::(); - unsafe { &mut std::slice::from_raw_parts_mut(ptr, len)[idx..idx + self.n] } + /// When using FFT64 as backend, T should be f64. + /// When using NTT120 as backend, T should be i64. + fn get_array(&self, row: usize, col: usize, blk: usize) -> &[T] { + let nrows: usize = self.rows(); + let ncols: usize = self.cols(); + if col == (ncols - 1) && (ncols & 1 == 1) { + &self.get_backend_array::()[blk * nrows * ncols * 8 + col * nrows * 8 + row * 8..] + } else { + &self.get_backend_array::()[blk * nrows * ncols * 8 + + (col / 2) * (2 * nrows) * 8 + + row * 2 * 8 + + (col % 2) * 8..] + } } + /// Returns a non-mutable reference of T to the entire contiguous array of the [VmpPMat]. + /// When using FFT64 as backend, T should be f64. + /// When using NTT120 as backend, T should be i64. 
+ /// The length of the returned array is rows * cols * n. + pub fn get_backend_array(&self) -> &[T] { + let ptr: *const T = self.data as *const T; + let len: usize = (self.rows() * self.cols() * self.n() * 8) / std::mem::size_of::(); + unsafe { &std::slice::from_raw_parts(ptr, len) } + } + + /// frees the memory and self destructs. pub fn delete(self) { unsafe { delete_vmp_pmat(self.data) }; drop(self); @@ -83,6 +97,8 @@ impl VmpPMat { } impl Module { + + /// Allocates a new [VmpPMat] with the given number of rows and columns. pub fn new_vmp_pmat(&self, rows: usize, cols: usize) -> VmpPMat { unsafe { VmpPMat { @@ -94,10 +110,25 @@ impl Module { } } + /// Returns the number of bytes needed as scratch space for [Self::vmp_prepare_contiguous]. pub fn vmp_prepare_contiguous_tmp_bytes(&self, rows: usize, cols: usize) -> usize { unsafe { vmp_prepare_contiguous_tmp_bytes(self.0, rows as u64, cols as u64) as usize } } + /// Prepares a [VmpPMat] given a contiguous array of [i64]. + /// The helper struct [Matrix3D] can be used to construct the + /// appropriate contiguous array. + /// + /// # Example + /// ``` + /// let mut b_mat: Matrix3D<i64> = Matrix3D::new(rows, cols, n); + /// + /// (0..min(rows, cols)).for_each(|i| { + /// b_mat.at_mut(i, i)[1] = 1 as i64; + /// }); + /// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols); + /// module.vmp_prepare_contiguous(&mut vmp_pmat, &b_mat.data, &mut buf); + /// ``` pub fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], buf: &mut [u8]) { unsafe { vmp_prepare_contiguous( @@ -231,6 +262,14 @@ impl Module { } } +/// A helper structure that stores a 3D matrix as a contiguous array. +/// To be passed to [Module::vmp_prepare_contiguous]. +/// +/// rows: index of the i-th base2K power. +/// cols: index of the j-th limb of the i-th row. +/// n : polynomial degree. +/// +/// A [Matrix3D] can be seen as a vector of [VecZnx]. 
pub struct Matrix3D { pub data: Vec, pub rows: usize, @@ -239,6 +278,16 @@ pub struct Matrix3D { } impl Matrix3D { + /// Allocates a new [Matrix3D] with the respective dimensions. + /// + /// # Example + /// ``` + /// let rows = 5; // #decomp + /// let cols = 5; // #limbs + /// let n = 1024; // #coeffs + /// + /// let mut mat = Matrix3D::<i64>::new(rows, cols, n); + /// ``` pub fn new(rows: usize, cols: usize, n: usize) -> Self { let size = rows * cols * n; Self { @@ -249,18 +298,60 @@ impl Matrix3D { } } + /// Returns a non-mutable reference to the entry (row, col) of the [Matrix3D]. + /// The returned array is of size n. + /// + /// # Example + /// ``` + /// let rows = 5; // #decomp + /// let cols = 5; // #limbs + /// let n = 1024; // #coeffs + /// + /// let mut mat = Matrix3D::<i64>::new(rows, cols, n); + /// + /// let elem: &[i64] = mat.at(4, 4); // size n + /// ``` pub fn at(&self, row: usize, col: usize) -> &[T] { assert!(row <= self.rows && col <= self.cols); let idx: usize = row * (self.n * self.cols) + col * self.n; &self.data[idx..idx + self.n] } + /// Returns a mutable reference of the array at the (row, col) entry of the [Matrix3D]. + /// The returned array is of size n. + /// + /// # Example + /// ``` + /// let rows = 5; // #decomp + /// let cols = 5; // #limbs + /// let n = 1024; // #coeffs + /// + /// let mut mat = Matrix3D::<i64>::new(rows, cols, n); + /// + /// let elem: &mut [i64] = mat.at_mut(4, 4); // size n + /// ``` pub fn at_mut(&mut self, row: usize, col: usize) -> &mut [T] { assert!(row <= self.rows && col <= self.cols); let idx: usize = row * (self.n * self.cols) + col * self.n; &mut self.data[idx..idx + self.n] } + /// Sets the entry \[row\] of the [Matrix3D]. + /// Typically this is used to assign a [VecZnx] to the i-th row + /// of the [Matrix3D]. 
+ /// + /// # Example + /// ``` + /// let rows = 5; // #decomp + /// let cols = 5; // #limbs + /// let n = 1024; // #coeffs + /// + /// let mut mat = Matrix3D::<i64>::new(rows, cols, n); + /// + /// let a: VecZnx = VecZnx::new(n, cols); + /// + /// mat.set_row(1, &a.data); + /// ``` pub fn set_row(&mut self, row: usize, a: &[T]) { assert!( row < self.rows,