Added VecZnxBorrow

This commit is contained in:
Jean-Philippe Bossuat
2025-02-14 18:26:54 +01:00
parent 68e61dc0e3
commit 67d8fd31b7
12 changed files with 605 additions and 595 deletions

View File

@@ -1,6 +1,5 @@
use crate::ffi::vmp;
use crate::{Infos, Module, VecZnx, VecZnxDft};
use std::cmp::min;
use crate::{Infos, Module, VecZnx, VecZnxApi, VecZnxDft};
/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
/// stored as a 3D matrix in the DFT domain in a single contiguous array.
@@ -110,7 +109,7 @@ pub trait VmpPMatOps {
///
/// # Example
/// ```
/// use base2k::{Module, Matrix3D, VmpPMat, VmpPMatOps, FFT64, Free};
/// use base2k::{Module, VmpPMat, VmpPMatOps, FFT64, Free};
/// use std::cmp::min;
///
/// let n: usize = 1024;
@@ -118,17 +117,12 @@ pub trait VmpPMatOps {
/// let rows = 5;
/// let cols = 6;
///
/// let mut b_mat: Matrix3D<i64> = Matrix3D::new(rows, cols, n);
///
/// // Populates the i-th row of b_math with X^1 * 2^(i * log_w) (here log_w is undefined)
/// (0..min(rows, cols)).for_each(|i| {
/// b_mat.at_mut(i, i)[1] = 1 as i64;
/// });
/// let mut b_mat: Vec<i64> = vec![0i64;n * cols * rows];
///
/// let mut buf: Vec<u8> = vec![u8::default(); module.vmp_prepare_tmp_bytes(rows, cols)];
///
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
/// module.vmp_prepare_contiguous(&mut vmp_pmat, &b_mat.data, &mut buf);
/// module.vmp_prepare_contiguous(&mut vmp_pmat, &b_mat, &mut buf);
///
/// vmp_pmat.free() // don't forget to free the memory once vmp_pmat is not needed anymore.
/// ```
@@ -146,7 +140,7 @@ pub trait VmpPMatOps {
///
/// # Example
/// ```
/// use base2k::{Module, FFT64, Matrix3D, VmpPMat, VmpPMatOps, VecZnx, VecZnxOps, Free};
/// use base2k::{Module, FFT64, VmpPMat, VmpPMatOps, VecZnx, VecZnxApi, VecZnxOps, Free};
/// use std::cmp::min;
///
/// let n: usize = 1024;
@@ -159,17 +153,15 @@ pub trait VmpPMatOps {
/// vecznx.push(module.new_vec_znx(cols));
/// });
///
/// let dble: Vec<&[i64]> = vecznx.iter().map(|v| v.data.as_slice()).collect();
///
/// let mut buf: Vec<u8> = vec![u8::default(); module.vmp_prepare_tmp_bytes(rows, cols)];
///
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
/// module.vmp_prepare_dblptr(&mut vmp_pmat, &dble, &mut buf);
/// module.vmp_prepare_dblptr(&mut vmp_pmat, &vecznx, &mut buf);
///
/// vmp_pmat.free();
/// module.free();
/// ```
fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]);
fn vmp_prepare_dblptr<T: VecZnxApi + Infos>(&self, b: &mut VmpPMat, a: &Vec<T>, buf: &mut [u8]);
/// Prepares the ith-row of [VmpPMat] from a vector of [VecZnx].
///
@@ -183,7 +175,7 @@ pub trait VmpPMatOps {
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
/// /// # Example
/// ```
/// use base2k::{Module, FFT64, Matrix3D, VmpPMat, VmpPMatOps, VecZnx, VecZnxOps, Free};
/// use base2k::{Module, FFT64, VmpPMat, VmpPMatOps, VecZnx, VecZnxOps, Free};
/// use std::cmp::min;
///
/// let n: usize = 1024;
@@ -191,7 +183,7 @@ pub trait VmpPMatOps {
/// let rows: usize = 5;
/// let cols: usize = 6;
///
/// let vecznx = vec![0i64; cols*n];
/// let vecznx = module.new_vec_znx(cols);
///
/// let mut buf: Vec<u8> = vec![u8::default(); module.vmp_prepare_tmp_bytes(rows, cols)];
///
@@ -201,7 +193,13 @@ pub trait VmpPMatOps {
/// vmp_pmat.free();
/// module.free();
/// ```
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]);
fn vmp_prepare_row<T: VecZnxApi + Infos>(
&self,
b: &mut VmpPMat,
a: &T,
row_i: usize,
tmp_bytes: &mut [u8],
);
/// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft].
///
@@ -246,7 +244,7 @@ pub trait VmpPMatOps {
///
/// # Example
/// ```
/// use base2k::{Module, VecZnx, VecZnxOps, VecZnxDft, VmpPMat, VmpPMatOps, FFT64, Free};
/// use base2k::{Module, VecZnx, VecZnxOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, FFT64, Free, VecZnxApi};
///
/// let n = 1024;
///
@@ -270,7 +268,13 @@ pub trait VmpPMatOps {
/// vmp_pmat.free();
/// module.free();
/// ```
fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]);
fn vmp_apply_dft<T: VecZnxApi + Infos>(
&self,
c: &mut VecZnxDft,
a: &T,
b: &VmpPMat,
buf: &mut [u8],
);
/// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft_to_dft].
///
@@ -316,7 +320,7 @@ pub trait VmpPMatOps {
///
/// # Example
/// ```
/// use base2k::{Module, VecZnx, VecZnxDft, VmpPMat, VmpPMatOps, FFT64, Free};
/// use base2k::{Module, VecZnx, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, FFT64, Free};
///
/// let n = 1024;
///
@@ -370,7 +374,7 @@ pub trait VmpPMatOps {
///
/// # Example
/// ```
/// use base2k::{Module, VecZnx, VecZnxOps, VecZnxDft, VmpPMat, VmpPMatOps, FFT64, Free};
/// use base2k::{Module, VecZnx, VecZnxOps, VecZnxDft, VmpPMat, VmpPMatOps, FFT64, Free, VecZnxApi, VecZnxDftOps};
///
/// let n = 1024;
///
@@ -424,7 +428,12 @@ impl VmpPMatOps for Module {
}
}
fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]) {
fn vmp_prepare_dblptr<T: VecZnxApi + Infos>(
&self,
b: &mut VmpPMat,
a: &Vec<T>,
buf: &mut [u8],
) {
let ptrs: Vec<*const i64> = a.iter().map(|v| v.as_ptr()).collect();
unsafe {
vmp::vmp_prepare_dblptr(
@@ -438,7 +447,13 @@ impl VmpPMatOps for Module {
}
}
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, buf: &mut [u8]) {
fn vmp_prepare_row<T: VecZnxApi + Infos>(
&self,
b: &mut VmpPMat,
a: &T,
row_i: usize,
buf: &mut [u8],
) {
unsafe {
vmp::vmp_prepare_row(
self.0,
@@ -470,7 +485,13 @@ impl VmpPMatOps for Module {
}
}
fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]) {
fn vmp_apply_dft<T: VecZnxApi + Infos>(
&self,
c: &mut VecZnxDft,
a: &T,
b: &VmpPMat,
buf: &mut [u8],
) {
unsafe {
vmp::vmp_apply_dft(
self.0,
@@ -537,135 +558,3 @@ impl VmpPMatOps for Module {
}
}
}
/// A helper struture that stores a 3D matrix as a contiguous array.
/// To be passed to [VmpPMatOps::vmp_prepare_contiguous].
///
/// rows: index of the i-th base2K power.
/// cols: index of the j-th limb of the i-th row.
/// n : polynomial degree.
///
/// A [Matrix3D] can be seen as a vector of [VecZnx].
pub struct Matrix3D<T> {
pub data: Vec<T>,
pub rows: usize,
pub cols: usize,
pub n: usize,
}
impl<T: Default + Clone + std::marker::Copy> Matrix3D<T> {
/// Allocates a new [Matrix3D] with the respective dimensions.
///
/// # Arguments
///
/// * `rows`: the number of rows of the matrix.
/// * `cols`: the number of cols of the matrix.
/// # `n`: the size of each entry of the matrix.
///
/// # Example
/// ```
/// use base2k::Matrix3D;
///
/// let rows = 5; // #decomp
/// let cols = 5; // #limbs
/// let n = 1024; // #coeffs
///
/// let mut mat = Matrix3D::<i64>::new(rows, cols, n);
/// ```
pub fn new(rows: usize, cols: usize, n: usize) -> Self {
let size = rows * cols * n;
Self {
data: vec![T::default(); size],
rows,
cols,
n,
}
}
/// Returns a non-mutable reference to the entry (row, col) of the [Matrix3D].
/// The returned array is of size n.
///
/// # Arguments
///
/// * `row`: the index of the row.
/// * `col`: the index of the col.
///
/// # Example
/// ```
/// use base2k::Matrix3D;
///
/// let rows = 5; // #decomp
/// let cols = 5; // #limbs
/// let n = 1024; // #coeffs
///
/// let mut mat = Matrix3D::<i64>::new(rows, cols, n);
///
/// let elem: &[i64] = mat.at(4, 4); // size n
/// ```
pub fn at(&self, row: usize, col: usize) -> &[T] {
assert!(row < self.rows && col < self.cols);
let idx: usize = row * (self.n * self.cols) + col * self.n;
&self.data[idx..idx + self.n]
}
/// Returns a mutable reference of the array at the (row, col) entry of the [Matrix3D].
/// The returned array is of size n.
///
/// # Arguments
///
/// * `row`: the index of the row.
/// * `col`: the index of the col.
///
/// # Example
/// ```
/// use base2k::Matrix3D;
///
/// let rows = 5; // #decomp
/// let cols = 5; // #limbs
/// let n = 1024; // #coeffs
///
/// let mut mat = Matrix3D::<i64>::new(rows, cols, n);
///
/// let elem: &mut [i64] = mat.at_mut(4, 4); // size n
/// ```
pub fn at_mut(&mut self, row: usize, col: usize) -> &mut [T] {
assert!(row < self.rows && col < self.cols);
let idx: usize = row * (self.n * self.cols) + col * self.n;
&mut self.data[idx..idx + self.n]
}
/// Sets the entry \[row\] of the [Matrix3D].
/// Typicall this is used to assign a [VecZnx] to the i-th row
/// of the [Matrix3D].
///
/// # Arguments
///
/// * `row`: the index of the row.
/// * `a`: the data to encode onthe row.
///
/// # Example
/// ```
/// use base2k::{Matrix3D, VecZnx};
///
/// let rows = 5; // #decomp
/// let cols = 5; // #limbs
/// let n = 1024; // #coeffs
///
/// let mut mat = Matrix3D::<i64>::new(rows, cols, n);
///
/// let a: VecZnx = VecZnx::new(n, cols);
///
/// mat.set_row(1, &a.data);
/// ```
pub fn set_row(&mut self, row: usize, a: &[T]) {
assert!(
row < self.rows,
"invalid argument row: row={} > self.rows={}",
row,
self.rows
);
let idx: usize = row * (self.n * self.cols);
let size: usize = min(a.len(), self.cols * self.n);
self.data[idx..idx + size].copy_from_slice(&a[..size]);
}
}