mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
added debug checks for alignement
This commit is contained in:
@@ -35,14 +35,18 @@ pub use vmp::*;
|
||||
pub const GALOISGENERATOR: u64 = 5;
|
||||
pub const DEFAULTALIGN: usize = 64;
|
||||
|
||||
/// Returns `true` when the address held by `ptr` is a multiple of `align` bytes.
///
/// The pointer is never dereferenced; only its numeric address is inspected,
/// so null and dangling pointers are accepted.
///
/// An `align` of zero is not a meaningful alignment: it is reported as
/// "not aligned" (`false`) instead of panicking on a remainder-by-zero.
pub fn is_aligned_custom<T>(ptr: *const T, align: usize) -> bool {
    // Short-circuit before the modulo so that `align == 0` cannot trigger
    // the "remainder with a divisor of zero" panic.
    align != 0 && (ptr as usize) % align == 0
}
|
||||
|
||||
fn is_aligned<T>(ptr: *const T) -> bool {
|
||||
pub fn is_aligned<T>(ptr: *const T) -> bool {
|
||||
is_aligned_custom(ptr, DEFAULTALIGN)
|
||||
}
|
||||
|
||||
pub fn assert_alignement<T>(ptr: *const T) {
|
||||
assert!(is_aligned(ptr), "invalid alignement: ensure passed bytes have been allocated with [alloc_aligned_u8] or [alloc_aligned]")
|
||||
}
|
||||
|
||||
pub fn cast<T, V>(data: &[T]) -> &[V] {
|
||||
let ptr: *const V = data.as_ptr() as *const V;
|
||||
let len: usize = data.len() / std::mem::size_of::<V>();
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::ffi::svp::{self, bytes_of_svp_ppol};
|
||||
use crate::{alias_mut_slice_to_vec, is_aligned, Module, VecZnxApi, VecZnxDft};
|
||||
use crate::ffi::svp;
|
||||
use crate::{alias_mut_slice_to_vec, assert_alignement, Module, VecZnxApi, VecZnxDft};
|
||||
|
||||
use crate::{alloc_aligned, cast, Infos};
|
||||
use rand::seq::SliceRandom;
|
||||
@@ -37,7 +37,10 @@ impl Scalar {
|
||||
n,
|
||||
size
|
||||
);
|
||||
debug_assert!(is_aligned(buf.as_ptr()));
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(buf.as_ptr())
|
||||
}
|
||||
self.0 = alias_mut_slice_to_vec(cast::<u8, i64>(&buf[..size]))
|
||||
}
|
||||
|
||||
@@ -75,7 +78,10 @@ impl SvpPPol {
|
||||
}
|
||||
|
||||
pub fn from_bytes(size: usize, bytes: &mut [u8]) -> SvpPPol {
|
||||
debug_assert!(is_aligned(bytes.as_ptr()));
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(bytes.as_ptr())
|
||||
}
|
||||
debug_assert!(bytes.len() << 3 >= size);
|
||||
SvpPPol(bytes.as_mut_ptr() as *mut svp::svp_ppol_t, size)
|
||||
}
|
||||
|
||||
@@ -2,8 +2,7 @@ use crate::cast_mut;
|
||||
use crate::ffi::vec_znx;
|
||||
use crate::ffi::znx;
|
||||
use crate::ffi::znx::znx_zero_i64_ref;
|
||||
use crate::is_aligned;
|
||||
use crate::{alias_mut_slice_to_vec, alloc_aligned};
|
||||
use crate::{alias_mut_slice_to_vec, alloc_aligned, assert_alignement};
|
||||
use crate::{Infos, Module};
|
||||
use itertools::izip;
|
||||
use std::cmp::min;
|
||||
@@ -137,7 +136,10 @@ impl VecZnxApi for VecZnxBorrow {
|
||||
cols,
|
||||
size
|
||||
);
|
||||
debug_assert!(is_aligned(bytes.as_ptr()));
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(bytes.as_ptr())
|
||||
}
|
||||
VecZnxBorrow {
|
||||
n: n,
|
||||
cols: cols,
|
||||
@@ -237,7 +239,10 @@ impl VecZnxApi for VecZnx {
|
||||
cols,
|
||||
size
|
||||
);
|
||||
debug_assert!(is_aligned(bytes.as_ptr()));
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(bytes.as_ptr())
|
||||
}
|
||||
VecZnx {
|
||||
n: n,
|
||||
data: alias_mut_slice_to_vec(cast_mut(&mut bytes[..size])),
|
||||
@@ -410,7 +415,10 @@ fn normalize<T: VecZnxCommon>(log_base2k: usize, a: &mut T, tmp_bytes: &mut [u8]
|
||||
tmp_bytes.len(),
|
||||
n
|
||||
);
|
||||
debug_assert!(is_aligned(tmp_bytes.as_ptr()));
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr())
|
||||
}
|
||||
|
||||
let carry_i64: &mut [i64] = cast_mut(tmp_bytes);
|
||||
|
||||
@@ -439,7 +447,10 @@ pub fn rsh<T: VecZnxCommon>(log_base2k: usize, a: &mut T, k: usize, tmp_bytes: &
|
||||
n
|
||||
);
|
||||
|
||||
debug_assert!(is_aligned(tmp_bytes.as_ptr()));
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr())
|
||||
}
|
||||
|
||||
let cols: usize = a.cols();
|
||||
let cols_steps: usize = k / log_base2k;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::ffi::vec_znx_big;
|
||||
use crate::ffi::vec_znx_dft;
|
||||
use crate::{is_aligned, Infos, Module, VecZnxApi, VecZnxDft};
|
||||
use crate::{assert_alignement, Infos, Module, VecZnxApi, VecZnxDft};
|
||||
|
||||
pub struct VecZnxBig(pub *mut vec_znx_big::vec_znx_bigcoeff_t, pub usize);
|
||||
|
||||
@@ -8,10 +8,13 @@ impl VecZnxBig {
|
||||
/// Returns a new [VecZnxBig] with the provided data as backing array.
|
||||
/// User must ensure that data is properly alligned and that
|
||||
/// the size of data is at least equal to [Module::bytes_of_vec_znx_big].
|
||||
pub fn from_bytes(cols: usize, data: &mut [u8]) -> VecZnxBig {
|
||||
debug_assert!(is_aligned(data.as_ptr()));
|
||||
pub fn from_bytes(cols: usize, bytes: &mut [u8]) -> VecZnxBig {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(bytes.as_ptr())
|
||||
};
|
||||
VecZnxBig(
|
||||
data.as_mut_ptr() as *mut vec_znx_big::vec_znx_bigcoeff_t,
|
||||
bytes.as_mut_ptr() as *mut vec_znx_big::vec_znx_bigcoeff_t,
|
||||
cols,
|
||||
)
|
||||
}
|
||||
@@ -101,7 +104,10 @@ impl VecZnxBigOps for Module {
|
||||
bytes.len(),
|
||||
<Module as VecZnxBigOps>::bytes_of_vec_znx_big(self, cols)
|
||||
);
|
||||
debug_assert!(is_aligned(bytes.as_ptr()));
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(bytes.as_ptr())
|
||||
}
|
||||
VecZnxBig::from_bytes(cols, bytes)
|
||||
}
|
||||
|
||||
@@ -185,13 +191,16 @@ impl VecZnxBigOps for Module {
|
||||
a: &VecZnxBig,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
assert!(
|
||||
debug_assert!(
|
||||
tmp_bytes.len() >= <Module as VecZnxBigOps>::vec_znx_big_normalize_tmp_bytes(self),
|
||||
"invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_normalize_tmp_bytes()={}",
|
||||
tmp_bytes.len(),
|
||||
<Module as VecZnxBigOps>::vec_znx_big_normalize_tmp_bytes(self)
|
||||
);
|
||||
debug_assert!(is_aligned(tmp_bytes.as_ptr()));
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr())
|
||||
}
|
||||
unsafe {
|
||||
vec_znx_big::vec_znx_big_normalize_base2k(
|
||||
self.0,
|
||||
@@ -220,13 +229,16 @@ impl VecZnxBigOps for Module {
|
||||
a_range_step: usize,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
assert!(
|
||||
debug_assert!(
|
||||
tmp_bytes.len() >= <Module as VecZnxBigOps>::vec_znx_big_range_normalize_base2k_tmp_bytes(self),
|
||||
"invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_range_normalize_base2k_tmp_bytes()={}",
|
||||
tmp_bytes.len(),
|
||||
<Module as VecZnxBigOps>::vec_znx_big_range_normalize_base2k_tmp_bytes(self)
|
||||
);
|
||||
debug_assert!(is_aligned(tmp_bytes.as_ptr()));
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr())
|
||||
}
|
||||
unsafe {
|
||||
vec_znx_big::vec_znx_big_range_normalize_base2k(
|
||||
self.0,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::ffi::vec_znx_big;
|
||||
use crate::ffi::vec_znx_dft;
|
||||
use crate::ffi::vec_znx_dft::bytes_of_vec_znx_dft;
|
||||
use crate::{is_aligned, Infos, Module, VecZnxApi, VecZnxBig};
|
||||
use crate::{assert_alignement, Infos, Module, VecZnxApi, VecZnxBig};
|
||||
|
||||
pub struct VecZnxDft(pub *mut vec_znx_dft::vec_znx_dft_t, pub usize);
|
||||
|
||||
@@ -10,7 +10,10 @@ impl VecZnxDft {
|
||||
/// User must ensure that data is properly aligned and that
|
||||
/// the size of data is at least equal to [Module::bytes_of_vec_znx_dft].
|
||||
pub fn from_bytes(cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft {
|
||||
debug_assert!(is_aligned(tmp_bytes.as_ptr()));
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr())
|
||||
}
|
||||
VecZnxDft(
|
||||
tmp_bytes.as_mut_ptr() as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
cols,
|
||||
@@ -84,7 +87,10 @@ impl VecZnxDftOps for Module {
|
||||
tmp_bytes.len(),
|
||||
<Module as VecZnxDftOps>::bytes_of_vec_znx_dft(self, cols)
|
||||
);
|
||||
debug_assert!(is_aligned(tmp_bytes.as_ptr()));
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr())
|
||||
}
|
||||
VecZnxDft::from_bytes(cols, tmp_bytes)
|
||||
}
|
||||
|
||||
@@ -157,7 +163,10 @@ impl VecZnxDftOps for Module {
|
||||
tmp_bytes.len(),
|
||||
<Module as VecZnxDftOps>::vec_znx_idft_tmp_bytes(self)
|
||||
);
|
||||
debug_assert!(is_aligned(tmp_bytes.as_ptr()));
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr())
|
||||
}
|
||||
unsafe {
|
||||
vec_znx_dft::vec_znx_idft(
|
||||
self.0,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::ffi::vmp;
|
||||
use crate::{Infos, Module, VecZnxApi, VecZnxDft};
|
||||
use crate::{assert_alignement, Infos, Module, VecZnxApi, VecZnxDft};
|
||||
|
||||
/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
|
||||
/// stored as a 3D matrix in the DFT domain in a single contiguous array.
|
||||
@@ -109,7 +109,7 @@ pub trait VmpPMatOps {
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use base2k::{Module, VmpPMat, VmpPMatOps, FFT64, Free};
|
||||
/// use base2k::{Module, VmpPMat, VmpPMatOps, FFT64, Free, alloc_aligned};
|
||||
/// use std::cmp::min;
|
||||
///
|
||||
/// let n: usize = 1024;
|
||||
@@ -119,7 +119,7 @@ pub trait VmpPMatOps {
|
||||
///
|
||||
/// let mut b_mat: Vec<i64> = vec![0i64;n * cols * rows];
|
||||
///
|
||||
/// let mut buf: Vec<u8> = vec![u8::default(); module.vmp_prepare_tmp_bytes(rows, cols)];
|
||||
/// let mut buf: Vec<u8> = alloc_aligned(module.vmp_prepare_tmp_bytes(rows, cols));
|
||||
///
|
||||
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
|
||||
/// module.vmp_prepare_contiguous(&mut vmp_pmat, &b_mat, &mut buf);
|
||||
@@ -140,7 +140,7 @@ pub trait VmpPMatOps {
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use base2k::{Module, FFT64, VmpPMat, VmpPMatOps, VecZnx, VecZnxApi, VecZnxOps, Free};
|
||||
/// use base2k::{Module, FFT64, VmpPMat, VmpPMatOps, VecZnx, VecZnxApi, VecZnxOps, Free, alloc_aligned};
|
||||
/// use std::cmp::min;
|
||||
///
|
||||
/// let n: usize = 1024;
|
||||
@@ -155,7 +155,7 @@ pub trait VmpPMatOps {
|
||||
///
|
||||
/// let slices: Vec<&[i64]> = vecznx.iter().map(|v| v.data.as_slice()).collect();
|
||||
///
|
||||
/// let mut buf: Vec<u8> = vec![u8::default(); module.vmp_prepare_tmp_bytes(rows, cols)];
|
||||
/// let mut buf: Vec<u8> = alloc_aligned(module.vmp_prepare_tmp_bytes(rows, cols));
|
||||
///
|
||||
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
|
||||
/// module.vmp_prepare_dblptr(&mut vmp_pmat, &slices, &mut buf);
|
||||
@@ -177,7 +177,7 @@ pub trait VmpPMatOps {
|
||||
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use base2k::{Module, FFT64, VmpPMat, VmpPMatOps, VecZnx, VecZnxApi, VecZnxOps, Free};
|
||||
/// use base2k::{Module, FFT64, VmpPMat, VmpPMatOps, VecZnx, VecZnxApi, VecZnxOps, Free, alloc_aligned};
|
||||
/// use std::cmp::min;
|
||||
///
|
||||
/// let n: usize = 1024;
|
||||
@@ -187,7 +187,7 @@ pub trait VmpPMatOps {
|
||||
///
|
||||
/// let vecznx = module.new_vec_znx(cols);
|
||||
///
|
||||
/// let mut buf: Vec<u8> = vec![u8::default(); module.vmp_prepare_tmp_bytes(rows, cols)];
|
||||
/// let mut buf: Vec<u8> = alloc_aligned(module.vmp_prepare_tmp_bytes(rows, cols));
|
||||
///
|
||||
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
|
||||
/// module.vmp_prepare_row(&mut vmp_pmat, vecznx.raw(), 0, &mut buf);
|
||||
@@ -240,7 +240,7 @@ pub trait VmpPMatOps {
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use base2k::{Module, VecZnx, VecZnxOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, FFT64, Free, VecZnxApi};
|
||||
/// use base2k::{Module, VecZnx, VecZnxOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, FFT64, Free, VecZnxApi, alloc_aligned};
|
||||
///
|
||||
/// let n = 1024;
|
||||
///
|
||||
@@ -251,9 +251,7 @@ pub trait VmpPMatOps {
|
||||
/// let cols: usize = cols + 1;
|
||||
/// let c_cols: usize = cols;
|
||||
/// let a_cols: usize = cols;
|
||||
/// let tmp_bytes: usize = module.vmp_apply_dft_tmp_bytes(c_cols, a_cols, rows, cols);
|
||||
///
|
||||
/// let mut buf: Vec<u8> = vec![0; tmp_bytes];
|
||||
/// let mut buf: Vec<u8> = alloc_aligned(module.vmp_apply_dft_tmp_bytes(c_cols, a_cols, rows, cols));
|
||||
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
|
||||
///
|
||||
/// let a: VecZnx = module.new_vec_znx(cols);
|
||||
@@ -316,7 +314,7 @@ pub trait VmpPMatOps {
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use base2k::{Module, VecZnx, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, FFT64, Free};
|
||||
/// use base2k::{Module, VecZnx, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, FFT64, Free, alloc_aligned};
|
||||
///
|
||||
/// let n = 1024;
|
||||
///
|
||||
@@ -327,14 +325,12 @@ pub trait VmpPMatOps {
|
||||
/// let cols: usize = cols + 1;
|
||||
/// let c_cols: usize = cols;
|
||||
/// let a_cols: usize = cols;
|
||||
/// let tmp_bytes: usize = module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, rows, cols);
|
||||
///
|
||||
/// let mut buf: Vec<u8> = vec![0; tmp_bytes];
|
||||
/// let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, rows, cols));
|
||||
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
|
||||
///
|
||||
/// let a_dft: VecZnxDft = module.new_vec_znx_dft(cols);
|
||||
/// let mut c_dft: VecZnxDft = module.new_vec_znx_dft(cols);
|
||||
/// module.vmp_apply_dft_to_dft(&mut c_dft, &a_dft, &vmp_pmat, &mut buf);
|
||||
/// module.vmp_apply_dft_to_dft(&mut c_dft, &a_dft, &vmp_pmat, &mut tmp_bytes);
|
||||
///
|
||||
/// a_dft.free();
|
||||
/// c_dft.free();
|
||||
@@ -369,8 +365,8 @@ pub trait VmpPMatOps {
|
||||
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use base2k::{Module, VecZnx, VecZnxOps, VecZnxDft, VmpPMat, VmpPMatOps, FFT64, Free, VecZnxApi, VecZnxDftOps};
|
||||
/// ```rust
|
||||
/// use base2k::{Module, VecZnx, VecZnxOps, VecZnxDft, VmpPMat, VmpPMatOps, FFT64, Free, VecZnxApi, VecZnxDftOps,alloc_aligned};
|
||||
///
|
||||
/// let n = 1024;
|
||||
///
|
||||
@@ -379,14 +375,12 @@ pub trait VmpPMatOps {
|
||||
///
|
||||
/// let rows: usize = cols;
|
||||
/// let cols: usize = cols + 1;
|
||||
/// let tmp_bytes: usize = module.vmp_apply_dft_to_dft_tmp_bytes(cols, cols, rows, cols);
|
||||
///
|
||||
/// let mut buf: Vec<u8> = vec![0; tmp_bytes];
|
||||
/// let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vmp_apply_dft_to_dft_tmp_bytes(cols, cols, rows, cols));
|
||||
/// let a: VecZnx = module.new_vec_znx(cols);
|
||||
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
|
||||
///
|
||||
/// let mut c_dft: VecZnxDft = module.new_vec_znx_dft(cols);
|
||||
/// module.vmp_apply_dft_to_dft_inplace(&mut c_dft, &vmp_pmat, &mut buf);
|
||||
/// module.vmp_apply_dft_to_dft_inplace(&mut c_dft, &vmp_pmat, &mut tmp_bytes);
|
||||
///
|
||||
/// c_dft.free();
|
||||
/// vmp_pmat.free();
|
||||
@@ -411,9 +405,13 @@ impl VmpPMatOps for Module {
|
||||
unsafe { vmp::vmp_prepare_tmp_bytes(self.0, rows as u64, cols as u64) as usize }
|
||||
}
|
||||
|
||||
fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], buf: &mut [u8]) {
|
||||
fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], tmp_bytes: &mut [u8]) {
|
||||
debug_assert_eq!(a.len(), b.n * b.rows * b.cols);
|
||||
debug_assert!(buf.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols()));
|
||||
debug_assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols()));
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_prepare_contiguous(
|
||||
self.0,
|
||||
@@ -421,12 +419,12 @@ impl VmpPMatOps for Module {
|
||||
a.as_ptr(),
|
||||
b.rows() as u64,
|
||||
b.cols() as u64,
|
||||
buf.as_mut_ptr(),
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]) {
|
||||
fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], tmp_bytes: &mut [u8]) {
|
||||
let ptrs: Vec<*const i64> = a.iter().map(|v| v.as_ptr()).collect();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
@@ -434,7 +432,8 @@ impl VmpPMatOps for Module {
|
||||
a.iter().for_each(|ai| {
|
||||
debug_assert_eq!(ai.len(), b.n * b.cols);
|
||||
});
|
||||
debug_assert!(buf.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols()));
|
||||
debug_assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols()));
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_prepare_dblptr(
|
||||
@@ -443,14 +442,18 @@ impl VmpPMatOps for Module {
|
||||
ptrs.as_ptr(),
|
||||
b.rows() as u64,
|
||||
b.cols() as u64,
|
||||
buf.as_mut_ptr(),
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, buf: &mut [u8]) {
|
||||
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]) {
|
||||
debug_assert_eq!(a.len(), b.cols() * self.n());
|
||||
debug_assert!(buf.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols()));
|
||||
debug_assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols()));
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_prepare_row(
|
||||
self.0,
|
||||
@@ -459,7 +462,7 @@ impl VmpPMatOps for Module {
|
||||
row_i as u64,
|
||||
b.rows() as u64,
|
||||
b.cols() as u64,
|
||||
buf.as_mut_ptr(),
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -487,11 +490,15 @@ impl VmpPMatOps for Module {
|
||||
c: &mut VecZnxDft,
|
||||
a: &T,
|
||||
b: &VmpPMat,
|
||||
buf: &mut [u8],
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
debug_assert!(
|
||||
buf.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())
|
||||
tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())
|
||||
);
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft(
|
||||
self.0,
|
||||
@@ -503,7 +510,7 @@ impl VmpPMatOps for Module {
|
||||
b.data(),
|
||||
b.rows() as u64,
|
||||
b.cols() as u64,
|
||||
buf.as_mut_ptr(),
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -526,11 +533,21 @@ impl VmpPMatOps for Module {
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]) {
|
||||
fn vmp_apply_dft_to_dft(
|
||||
&self,
|
||||
c: &mut VecZnxDft,
|
||||
a: &VecZnxDft,
|
||||
b: &VmpPMat,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
debug_assert!(
|
||||
buf.len()
|
||||
tmp_bytes.len()
|
||||
>= self.vmp_apply_dft_to_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())
|
||||
);
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft(
|
||||
self.0,
|
||||
@@ -541,16 +558,20 @@ impl VmpPMatOps for Module {
|
||||
b.data(),
|
||||
b.rows() as u64,
|
||||
b.cols() as u64,
|
||||
buf.as_mut_ptr(),
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, buf: &mut [u8]) {
|
||||
fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, tmp_bytes: &mut [u8]) {
|
||||
debug_assert!(
|
||||
buf.len()
|
||||
tmp_bytes.len()
|
||||
>= self.vmp_apply_dft_to_dft_tmp_bytes(b.cols(), b.cols(), a.rows(), a.cols())
|
||||
);
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft(
|
||||
self.0,
|
||||
@@ -561,7 +582,7 @@ impl VmpPMatOps for Module {
|
||||
a.data(),
|
||||
a.rows() as u64,
|
||||
a.cols() as u64,
|
||||
buf.as_mut_ptr(),
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user