improved alligned vec allocation & fixed vec_znx calls, fixed auto dft test

This commit is contained in:
Jean-Philippe Bossuat
2025-04-26 11:23:47 +02:00
parent 2a96f89047
commit 82082db727
6 changed files with 113 additions and 67 deletions

View File

@@ -53,7 +53,6 @@ impl<B: Backend> Infos for VmpPMat<B> {
}
impl VmpPMat<FFT64> {
fn new(module: &Module<FFT64>, rows: usize, cols: usize, limbs: usize) -> VmpPMat<FFT64> {
let mut data: Vec<u8> = alloc_aligned::<u8>(module.bytes_of_vmp_pmat(rows, cols, limbs));
let ptr: *mut u8 = data.as_mut_ptr();
@@ -352,21 +351,19 @@ pub trait VmpPMatOps<B: Backend> {
}
impl VmpPMatOps<FFT64> for Module<FFT64> {
fn new_vmp_pmat(&self, rows: usize, cols: usize, limbs: usize) -> VmpPMat<FFT64> {
VmpPMat::<FFT64>::new(self, rows, cols, limbs)
}
fn bytes_of_vmp_pmat(&self, rows: usize, cols: usize, limbs: usize) -> usize {
unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, (limbs* cols) as u64) as usize }
unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, (limbs * cols) as u64) as usize }
}
fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize {
unsafe { vmp::vmp_prepare_tmp_bytes(self.ptr, rows as u64, (size * cols) as u64) as usize }
unsafe { vmp::vmp_prepare_tmp_bytes(self.ptr, rows as u64, (size * cols) as u64) as usize }
}
fn vmp_prepare_contiguous(&self, b: &mut VmpPMat<FFT64>, a: &[i64], tmp_bytes: &mut [u8]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), b.n() * b.poly_count());
@@ -379,7 +376,7 @@ impl VmpPMatOps<FFT64> for Module<FFT64> {
b.as_mut_ptr() as *mut vmp_pmat_t,
a.as_ptr(),
b.rows() as u64,
(b.limbs()*b.cols()) as u64,
(b.limbs() * b.cols()) as u64,
tmp_bytes.as_mut_ptr(),
);
}
@@ -387,7 +384,7 @@ impl VmpPMatOps<FFT64> for Module<FFT64> {
fn vmp_prepare_row(&self, b: &mut VmpPMat<FFT64>, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]) {
#[cfg(debug_assertions)]
{
{
assert_eq!(a.len(), b.limbs() * self.n() * b.cols());
assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.limbs()));
assert_alignement(tmp_bytes.as_ptr());
@@ -399,7 +396,7 @@ impl VmpPMatOps<FFT64> for Module<FFT64> {
a.as_ptr(),
row_i as u64,
b.rows() as u64,
(b.limbs()*b.cols()) as u64,
(b.limbs() * b.cols()) as u64,
tmp_bytes.as_mut_ptr(),
);
}
@@ -419,7 +416,7 @@ impl VmpPMatOps<FFT64> for Module<FFT64> {
a.as_ptr() as *const vmp_pmat_t,
row_i as u64,
a.rows() as u64,
(a.limbs()*a.cols()) as u64,
(a.limbs() * a.cols()) as u64,
);
}
}