added debug checks for alignment

Jean-Philippe Bossuat
2025-02-25 15:04:56 +01:00
parent 871b85e471
commit 483a142ab0
8 changed files with 140 additions and 80 deletions
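
The pattern added throughout is: in debug builds, assert that every scratch buffer handed across the FFI boundary to the C backend is properly aligned. A minimal sketch of what the `assert_alignement` helper could look like (its actual definition lives in one of the other changed files of this commit; the 64-byte alignment constant below is an assumption for illustration):

```rust
/// Alignment, in bytes, expected by the C backend for scratch buffers.
/// NOTE: 64 is an assumed value for illustration; the real constant is
/// defined elsewhere in this commit.
pub const ALIGNMENT: usize = 64;

/// Panics if `ptr` is not `ALIGNMENT`-byte aligned. Intended to be called
/// behind `#[cfg(debug_assertions)]` so release builds pay no cost.
pub fn assert_alignement<T>(ptr: *const T) {
    assert_eq!(
        (ptr as usize) % ALIGNMENT,
        0,
        "misaligned pointer {:p}: expected {}-byte alignment",
        ptr,
        ALIGNMENT
    );
}
```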


@@ -1,5 +1,5 @@
use crate::ffi::vmp;
-use crate::{Infos, Module, VecZnxApi, VecZnxDft};
+use crate::{assert_alignement, Infos, Module, VecZnxApi, VecZnxDft};
/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
/// stored as a 3D matrix in the DFT domain in a single contiguous array.
@@ -109,7 +109,7 @@ pub trait VmpPMatOps {
///
/// # Example
/// ```
-/// use base2k::{Module, VmpPMat, VmpPMatOps, FFT64, Free};
+/// use base2k::{Module, VmpPMat, VmpPMatOps, FFT64, Free, alloc_aligned};
/// use std::cmp::min;
///
/// let n: usize = 1024;
@@ -119,7 +119,7 @@ pub trait VmpPMatOps {
///
/// let mut b_mat: Vec<i64> = vec![0i64;n * cols * rows];
///
-/// let mut buf: Vec<u8> = vec![u8::default(); module.vmp_prepare_tmp_bytes(rows, cols)];
+/// let mut buf: Vec<u8> = alloc_aligned(module.vmp_prepare_tmp_bytes(rows, cols));
///
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
/// module.vmp_prepare_contiguous(&mut vmp_pmat, &b_mat, &mut buf);
@@ -140,7 +140,7 @@ pub trait VmpPMatOps {
///
/// # Example
/// ```
-/// use base2k::{Module, FFT64, VmpPMat, VmpPMatOps, VecZnx, VecZnxApi, VecZnxOps, Free};
+/// use base2k::{Module, FFT64, VmpPMat, VmpPMatOps, VecZnx, VecZnxApi, VecZnxOps, Free, alloc_aligned};
/// use std::cmp::min;
///
/// let n: usize = 1024;
@@ -155,7 +155,7 @@ pub trait VmpPMatOps {
///
/// let slices: Vec<&[i64]> = vecznx.iter().map(|v| v.data.as_slice()).collect();
///
-/// let mut buf: Vec<u8> = vec![u8::default(); module.vmp_prepare_tmp_bytes(rows, cols)];
+/// let mut buf: Vec<u8> = alloc_aligned(module.vmp_prepare_tmp_bytes(rows, cols));
///
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
/// module.vmp_prepare_dblptr(&mut vmp_pmat, &slices, &mut buf);
@@ -177,7 +177,7 @@ pub trait VmpPMatOps {
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
///
/// # Example
/// ```
-/// use base2k::{Module, FFT64, VmpPMat, VmpPMatOps, VecZnx, VecZnxApi, VecZnxOps, Free};
+/// use base2k::{Module, FFT64, VmpPMat, VmpPMatOps, VecZnx, VecZnxApi, VecZnxOps, Free, alloc_aligned};
/// use std::cmp::min;
///
/// let n: usize = 1024;
@@ -187,7 +187,7 @@ pub trait VmpPMatOps {
///
/// let vecznx = module.new_vec_znx(cols);
///
-/// let mut buf: Vec<u8> = vec![u8::default(); module.vmp_prepare_tmp_bytes(rows, cols)];
+/// let mut buf: Vec<u8> = alloc_aligned(module.vmp_prepare_tmp_bytes(rows, cols));
///
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
/// module.vmp_prepare_row(&mut vmp_pmat, vecznx.raw(), 0, &mut buf);
@@ -240,7 +240,7 @@ pub trait VmpPMatOps {
///
/// # Example
/// ```
/// use base2k::{Module, VecZnx, VecZnxOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, FFT64, Free, VecZnxApi};
/// use base2k::{Module, VecZnx, VecZnxOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, FFT64, Free, VecZnxApi, alloc_aligned};
///
/// let n = 1024;
///
@@ -251,9 +251,7 @@ pub trait VmpPMatOps {
/// let cols: usize = cols + 1;
/// let c_cols: usize = cols;
/// let a_cols: usize = cols;
-/// let tmp_bytes: usize = module.vmp_apply_dft_tmp_bytes(c_cols, a_cols, rows, cols);
-///
-/// let mut buf: Vec<u8> = vec![0; tmp_bytes];
+/// let mut buf: Vec<u8> = alloc_aligned(module.vmp_apply_dft_tmp_bytes(c_cols, a_cols, rows, cols));
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
///
/// let a: VecZnx = module.new_vec_znx(cols);
@@ -316,7 +314,7 @@ pub trait VmpPMatOps {
///
/// # Example
/// ```
-/// use base2k::{Module, VecZnx, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, FFT64, Free};
+/// use base2k::{Module, VecZnx, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, FFT64, Free, alloc_aligned};
///
/// let n = 1024;
///
@@ -327,14 +325,12 @@ pub trait VmpPMatOps {
/// let cols: usize = cols + 1;
/// let c_cols: usize = cols;
/// let a_cols: usize = cols;
-/// let tmp_bytes: usize = module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, rows, cols);
-///
-/// let mut buf: Vec<u8> = vec![0; tmp_bytes];
+/// let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, rows, cols));
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
///
/// let a_dft: VecZnxDft = module.new_vec_znx_dft(cols);
/// let mut c_dft: VecZnxDft = module.new_vec_znx_dft(cols);
-/// module.vmp_apply_dft_to_dft(&mut c_dft, &a_dft, &vmp_pmat, &mut buf);
+/// module.vmp_apply_dft_to_dft(&mut c_dft, &a_dft, &vmp_pmat, &mut tmp_bytes);
///
/// a_dft.free();
/// c_dft.free();
@@ -369,8 +365,8 @@ pub trait VmpPMatOps {
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
///
/// # Example
-/// ```
-/// use base2k::{Module, VecZnx, VecZnxOps, VecZnxDft, VmpPMat, VmpPMatOps, FFT64, Free, VecZnxApi, VecZnxDftOps};
+/// ```rust
+/// use base2k::{Module, VecZnx, VecZnxOps, VecZnxDft, VmpPMat, VmpPMatOps, FFT64, Free, VecZnxApi, VecZnxDftOps, alloc_aligned};
///
/// let n = 1024;
///
@@ -379,14 +375,12 @@ pub trait VmpPMatOps {
///
/// let rows: usize = cols;
/// let cols: usize = cols + 1;
-/// let tmp_bytes: usize = module.vmp_apply_dft_to_dft_tmp_bytes(cols, cols, rows, cols);
-///
-/// let mut buf: Vec<u8> = vec![0; tmp_bytes];
+/// let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vmp_apply_dft_to_dft_tmp_bytes(cols, cols, rows, cols));
/// let a: VecZnx = module.new_vec_znx(cols);
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
///
/// let mut c_dft: VecZnxDft = module.new_vec_znx_dft(cols);
-/// module.vmp_apply_dft_to_dft_inplace(&mut c_dft, &vmp_pmat, &mut buf);
+/// module.vmp_apply_dft_to_dft_inplace(&mut c_dft, &vmp_pmat, &mut tmp_bytes);
///
/// c_dft.free();
/// vmp_pmat.free();
@@ -411,9 +405,13 @@ impl VmpPMatOps for Module {
unsafe { vmp::vmp_prepare_tmp_bytes(self.0, rows as u64, cols as u64) as usize }
}
-fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], buf: &mut [u8]) {
+fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], tmp_bytes: &mut [u8]) {
debug_assert_eq!(a.len(), b.n * b.rows * b.cols);
-debug_assert!(buf.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols()));
+debug_assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols()));
+#[cfg(debug_assertions)]
+{
+assert_alignement(tmp_bytes.as_ptr());
+}
unsafe {
vmp::vmp_prepare_contiguous(
self.0,
@@ -421,12 +419,12 @@ impl VmpPMatOps for Module {
a.as_ptr(),
b.rows() as u64,
b.cols() as u64,
-buf.as_mut_ptr(),
+tmp_bytes.as_mut_ptr(),
);
}
}
-fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]) {
+fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], tmp_bytes: &mut [u8]) {
let ptrs: Vec<*const i64> = a.iter().map(|v| v.as_ptr()).collect();
#[cfg(debug_assertions)]
{
@@ -434,7 +432,8 @@ impl VmpPMatOps for Module {
a.iter().for_each(|ai| {
debug_assert_eq!(ai.len(), b.n * b.cols);
});
-debug_assert!(buf.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols()));
+debug_assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols()));
+assert_alignement(tmp_bytes.as_ptr());
}
unsafe {
vmp::vmp_prepare_dblptr(
@@ -443,14 +442,18 @@ impl VmpPMatOps for Module {
ptrs.as_ptr(),
b.rows() as u64,
b.cols() as u64,
-buf.as_mut_ptr(),
+tmp_bytes.as_mut_ptr(),
);
}
}
-fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, buf: &mut [u8]) {
+fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]) {
debug_assert_eq!(a.len(), b.cols() * self.n());
-debug_assert!(buf.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols()));
+debug_assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols()));
+#[cfg(debug_assertions)]
+{
+assert_alignement(tmp_bytes.as_ptr());
+}
unsafe {
vmp::vmp_prepare_row(
self.0,
@@ -459,7 +462,7 @@ impl VmpPMatOps for Module {
row_i as u64,
b.rows() as u64,
b.cols() as u64,
-buf.as_mut_ptr(),
+tmp_bytes.as_mut_ptr(),
);
}
}
@@ -487,11 +490,15 @@ impl VmpPMatOps for Module {
c: &mut VecZnxDft,
a: &T,
b: &VmpPMat,
-buf: &mut [u8],
+tmp_bytes: &mut [u8],
) {
debug_assert!(
-buf.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())
+tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())
);
+#[cfg(debug_assertions)]
+{
+assert_alignement(tmp_bytes.as_ptr());
+}
unsafe {
vmp::vmp_apply_dft(
self.0,
@@ -503,7 +510,7 @@ impl VmpPMatOps for Module {
b.data(),
b.rows() as u64,
b.cols() as u64,
-buf.as_mut_ptr(),
+tmp_bytes.as_mut_ptr(),
)
}
}
@@ -526,11 +533,21 @@ impl VmpPMatOps for Module {
}
}
-fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]) {
+fn vmp_apply_dft_to_dft(
+&self,
+c: &mut VecZnxDft,
+a: &VecZnxDft,
+b: &VmpPMat,
+tmp_bytes: &mut [u8],
+) {
debug_assert!(
-buf.len()
+tmp_bytes.len()
>= self.vmp_apply_dft_to_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())
);
+#[cfg(debug_assertions)]
+{
+assert_alignement(tmp_bytes.as_ptr());
+}
unsafe {
vmp::vmp_apply_dft_to_dft(
self.0,
@@ -541,16 +558,20 @@ impl VmpPMatOps for Module {
b.data(),
b.rows() as u64,
b.cols() as u64,
-buf.as_mut_ptr(),
+tmp_bytes.as_mut_ptr(),
)
}
}
-fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, buf: &mut [u8]) {
+fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, tmp_bytes: &mut [u8]) {
debug_assert!(
-buf.len()
+tmp_bytes.len()
>= self.vmp_apply_dft_to_dft_tmp_bytes(b.cols(), b.cols(), a.rows(), a.cols())
);
+#[cfg(debug_assertions)]
+{
+assert_alignement(tmp_bytes.as_ptr());
+}
unsafe {
vmp::vmp_apply_dft_to_dft(
self.0,
@@ -561,7 +582,7 @@ impl VmpPMatOps for Module {
a.data(),
a.rows() as u64,
a.cols() as u64,
-buf.as_mut_ptr(),
+tmp_bytes.as_mut_ptr(),
)
}
}
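
For downstream code, the practical consequence is that scratch buffers should come from `alloc_aligned` rather than `vec![...]`: a plain `Vec<u8>` only guarantees the 1-byte alignment of `u8`, so it can trip the new debug assertions. A usage sketch mirroring the updated doctests (the `Module` constructor is not visible in this diff, so `Module::new::<FFT64>(n)` is an assumed name, and freeing the module is omitted since the diff does not show that step):

```rust
use base2k::{alloc_aligned, Free, Module, VmpPMat, VmpPMatOps, FFT64};

fn main() {
    let n: usize = 1024;
    // Assumed constructor; the diff does not show how a Module is built.
    let module = Module::new::<FFT64>(n);

    let rows: usize = 4;
    let cols: usize = 5;

    // Aligned scratch space, sized by the library's own query function.
    // `vec![0u8; ...]` would compile, but in debug builds it may now fail
    // the assert_alignement check inside vmp_prepare_contiguous.
    let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vmp_prepare_tmp_bytes(rows, cols));

    let b_mat: Vec<i64> = vec![0i64; n * rows * cols];
    let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
    module.vmp_prepare_contiguous(&mut vmp_pmat, &b_mat, &mut tmp_bytes);

    vmp_pmat.free();
}
```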