mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 21:26:41 +01:00
added debug checks for alignement
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
use crate::ffi::vmp;
|
||||
use crate::{Infos, Module, VecZnxApi, VecZnxDft};
|
||||
use crate::{assert_alignement, Infos, Module, VecZnxApi, VecZnxDft};
|
||||
|
||||
/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
|
||||
/// stored as a 3D matrix in the DFT domain in a single contiguous array.
|
||||
@@ -109,7 +109,7 @@ pub trait VmpPMatOps {
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use base2k::{Module, VmpPMat, VmpPMatOps, FFT64, Free};
|
||||
/// use base2k::{Module, VmpPMat, VmpPMatOps, FFT64, Free, alloc_aligned};
|
||||
/// use std::cmp::min;
|
||||
///
|
||||
/// let n: usize = 1024;
|
||||
@@ -119,7 +119,7 @@ pub trait VmpPMatOps {
|
||||
///
|
||||
/// let mut b_mat: Vec<i64> = vec![0i64;n * cols * rows];
|
||||
///
|
||||
/// let mut buf: Vec<u8> = vec![u8::default(); module.vmp_prepare_tmp_bytes(rows, cols)];
|
||||
/// let mut buf: Vec<u8> = alloc_aligned(module.vmp_prepare_tmp_bytes(rows, cols));
|
||||
///
|
||||
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
|
||||
/// module.vmp_prepare_contiguous(&mut vmp_pmat, &b_mat, &mut buf);
|
||||
@@ -140,7 +140,7 @@ pub trait VmpPMatOps {
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use base2k::{Module, FFT64, VmpPMat, VmpPMatOps, VecZnx, VecZnxApi, VecZnxOps, Free};
|
||||
/// use base2k::{Module, FFT64, VmpPMat, VmpPMatOps, VecZnx, VecZnxApi, VecZnxOps, Free, alloc_aligned};
|
||||
/// use std::cmp::min;
|
||||
///
|
||||
/// let n: usize = 1024;
|
||||
@@ -155,7 +155,7 @@ pub trait VmpPMatOps {
|
||||
///
|
||||
/// let slices: Vec<&[i64]> = vecznx.iter().map(|v| v.data.as_slice()).collect();
|
||||
///
|
||||
/// let mut buf: Vec<u8> = vec![u8::default(); module.vmp_prepare_tmp_bytes(rows, cols)];
|
||||
/// let mut buf: Vec<u8> = alloc_aligned(module.vmp_prepare_tmp_bytes(rows, cols));
|
||||
///
|
||||
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
|
||||
/// module.vmp_prepare_dblptr(&mut vmp_pmat, &slices, &mut buf);
|
||||
@@ -177,7 +177,7 @@ pub trait VmpPMatOps {
|
||||
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
|
||||
/// /// # Example
|
||||
/// ```
|
||||
/// use base2k::{Module, FFT64, VmpPMat, VmpPMatOps, VecZnx, VecZnxApi, VecZnxOps, Free};
|
||||
/// use base2k::{Module, FFT64, VmpPMat, VmpPMatOps, VecZnx, VecZnxApi, VecZnxOps, Free, alloc_aligned};
|
||||
/// use std::cmp::min;
|
||||
///
|
||||
/// let n: usize = 1024;
|
||||
@@ -187,7 +187,7 @@ pub trait VmpPMatOps {
|
||||
///
|
||||
/// let vecznx = module.new_vec_znx(cols);
|
||||
///
|
||||
/// let mut buf: Vec<u8> = vec![u8::default(); module.vmp_prepare_tmp_bytes(rows, cols)];
|
||||
/// let mut buf: Vec<u8> = alloc_aligned(module.vmp_prepare_tmp_bytes(rows, cols));
|
||||
///
|
||||
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
|
||||
/// module.vmp_prepare_row(&mut vmp_pmat, vecznx.raw(), 0, &mut buf);
|
||||
@@ -240,7 +240,7 @@ pub trait VmpPMatOps {
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use base2k::{Module, VecZnx, VecZnxOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, FFT64, Free, VecZnxApi};
|
||||
/// use base2k::{Module, VecZnx, VecZnxOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, FFT64, Free, VecZnxApi, alloc_aligned};
|
||||
///
|
||||
/// let n = 1024;
|
||||
///
|
||||
@@ -251,9 +251,7 @@ pub trait VmpPMatOps {
|
||||
/// let cols: usize = cols + 1;
|
||||
/// let c_cols: usize = cols;
|
||||
/// let a_cols: usize = cols;
|
||||
/// let tmp_bytes: usize = module.vmp_apply_dft_tmp_bytes(c_cols, a_cols, rows, cols);
|
||||
///
|
||||
/// let mut buf: Vec<u8> = vec![0; tmp_bytes];
|
||||
/// let mut buf: Vec<u8> = alloc_aligned(module.vmp_apply_dft_tmp_bytes(c_cols, a_cols, rows, cols));
|
||||
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
|
||||
///
|
||||
/// let a: VecZnx = module.new_vec_znx(cols);
|
||||
@@ -316,7 +314,7 @@ pub trait VmpPMatOps {
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use base2k::{Module, VecZnx, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, FFT64, Free};
|
||||
/// use base2k::{Module, VecZnx, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, FFT64, Free, alloc_aligned};
|
||||
///
|
||||
/// let n = 1024;
|
||||
///
|
||||
@@ -327,14 +325,12 @@ pub trait VmpPMatOps {
|
||||
/// let cols: usize = cols + 1;
|
||||
/// let c_cols: usize = cols;
|
||||
/// let a_cols: usize = cols;
|
||||
/// let tmp_bytes: usize = module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, rows, cols);
|
||||
///
|
||||
/// let mut buf: Vec<u8> = vec![0; tmp_bytes];
|
||||
/// let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, rows, cols));
|
||||
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
|
||||
///
|
||||
/// let a_dft: VecZnxDft = module.new_vec_znx_dft(cols);
|
||||
/// let mut c_dft: VecZnxDft = module.new_vec_znx_dft(cols);
|
||||
/// module.vmp_apply_dft_to_dft(&mut c_dft, &a_dft, &vmp_pmat, &mut buf);
|
||||
/// module.vmp_apply_dft_to_dft(&mut c_dft, &a_dft, &vmp_pmat, &mut tmp_bytes);
|
||||
///
|
||||
/// a_dft.free();
|
||||
/// c_dft.free();
|
||||
@@ -369,8 +365,8 @@ pub trait VmpPMatOps {
|
||||
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use base2k::{Module, VecZnx, VecZnxOps, VecZnxDft, VmpPMat, VmpPMatOps, FFT64, Free, VecZnxApi, VecZnxDftOps};
|
||||
/// ```rust
|
||||
/// use base2k::{Module, VecZnx, VecZnxOps, VecZnxDft, VmpPMat, VmpPMatOps, FFT64, Free, VecZnxApi, VecZnxDftOps,alloc_aligned};
|
||||
///
|
||||
/// let n = 1024;
|
||||
///
|
||||
@@ -379,14 +375,12 @@ pub trait VmpPMatOps {
|
||||
///
|
||||
/// let rows: usize = cols;
|
||||
/// let cols: usize = cols + 1;
|
||||
/// let tmp_bytes: usize = module.vmp_apply_dft_to_dft_tmp_bytes(cols, cols, rows, cols);
|
||||
///
|
||||
/// let mut buf: Vec<u8> = vec![0; tmp_bytes];
|
||||
/// let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vmp_apply_dft_to_dft_tmp_bytes(cols, cols, rows, cols));
|
||||
/// let a: VecZnx = module.new_vec_znx(cols);
|
||||
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
|
||||
///
|
||||
/// let mut c_dft: VecZnxDft = module.new_vec_znx_dft(cols);
|
||||
/// module.vmp_apply_dft_to_dft_inplace(&mut c_dft, &vmp_pmat, &mut buf);
|
||||
/// module.vmp_apply_dft_to_dft_inplace(&mut c_dft, &vmp_pmat, &mut tmp_bytes);
|
||||
///
|
||||
/// c_dft.free();
|
||||
/// vmp_pmat.free();
|
||||
@@ -411,9 +405,13 @@ impl VmpPMatOps for Module {
|
||||
unsafe { vmp::vmp_prepare_tmp_bytes(self.0, rows as u64, cols as u64) as usize }
|
||||
}
|
||||
|
||||
fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], buf: &mut [u8]) {
|
||||
fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], tmp_bytes: &mut [u8]) {
|
||||
debug_assert_eq!(a.len(), b.n * b.rows * b.cols);
|
||||
debug_assert!(buf.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols()));
|
||||
debug_assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols()));
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_prepare_contiguous(
|
||||
self.0,
|
||||
@@ -421,12 +419,12 @@ impl VmpPMatOps for Module {
|
||||
a.as_ptr(),
|
||||
b.rows() as u64,
|
||||
b.cols() as u64,
|
||||
buf.as_mut_ptr(),
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]) {
|
||||
fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], tmp_bytes: &mut [u8]) {
|
||||
let ptrs: Vec<*const i64> = a.iter().map(|v| v.as_ptr()).collect();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
@@ -434,7 +432,8 @@ impl VmpPMatOps for Module {
|
||||
a.iter().for_each(|ai| {
|
||||
debug_assert_eq!(ai.len(), b.n * b.cols);
|
||||
});
|
||||
debug_assert!(buf.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols()));
|
||||
debug_assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols()));
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_prepare_dblptr(
|
||||
@@ -443,14 +442,18 @@ impl VmpPMatOps for Module {
|
||||
ptrs.as_ptr(),
|
||||
b.rows() as u64,
|
||||
b.cols() as u64,
|
||||
buf.as_mut_ptr(),
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, buf: &mut [u8]) {
|
||||
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]) {
|
||||
debug_assert_eq!(a.len(), b.cols() * self.n());
|
||||
debug_assert!(buf.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols()));
|
||||
debug_assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols()));
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_prepare_row(
|
||||
self.0,
|
||||
@@ -459,7 +462,7 @@ impl VmpPMatOps for Module {
|
||||
row_i as u64,
|
||||
b.rows() as u64,
|
||||
b.cols() as u64,
|
||||
buf.as_mut_ptr(),
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -487,11 +490,15 @@ impl VmpPMatOps for Module {
|
||||
c: &mut VecZnxDft,
|
||||
a: &T,
|
||||
b: &VmpPMat,
|
||||
buf: &mut [u8],
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
debug_assert!(
|
||||
buf.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())
|
||||
tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())
|
||||
);
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft(
|
||||
self.0,
|
||||
@@ -503,7 +510,7 @@ impl VmpPMatOps for Module {
|
||||
b.data(),
|
||||
b.rows() as u64,
|
||||
b.cols() as u64,
|
||||
buf.as_mut_ptr(),
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -526,11 +533,21 @@ impl VmpPMatOps for Module {
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]) {
|
||||
fn vmp_apply_dft_to_dft(
|
||||
&self,
|
||||
c: &mut VecZnxDft,
|
||||
a: &VecZnxDft,
|
||||
b: &VmpPMat,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
debug_assert!(
|
||||
buf.len()
|
||||
tmp_bytes.len()
|
||||
>= self.vmp_apply_dft_to_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())
|
||||
);
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft(
|
||||
self.0,
|
||||
@@ -541,16 +558,20 @@ impl VmpPMatOps for Module {
|
||||
b.data(),
|
||||
b.rows() as u64,
|
||||
b.cols() as u64,
|
||||
buf.as_mut_ptr(),
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, buf: &mut [u8]) {
|
||||
fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, tmp_bytes: &mut [u8]) {
|
||||
debug_assert!(
|
||||
buf.len()
|
||||
tmp_bytes.len()
|
||||
>= self.vmp_apply_dft_to_dft_tmp_bytes(b.cols(), b.cols(), a.rows(), a.cols())
|
||||
);
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft(
|
||||
self.0,
|
||||
@@ -561,7 +582,7 @@ impl VmpPMatOps for Module {
|
||||
a.data(),
|
||||
a.rows() as u64,
|
||||
a.cols() as u64,
|
||||
buf.as_mut_ptr(),
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user