refactoring of vec_znx

Jean-Philippe Bossuat
2025-04-28 10:33:15 +02:00
parent 39bbe5b917
commit 2f9a1cf6d9
13 changed files with 1218 additions and 738 deletions


@@ -22,7 +22,7 @@ pub struct MatZnxDft<B: Backend> {
     /// Number of cols
     cols: usize,
     /// The number of small polynomials
-    limbs: usize,
+    size: usize,
     _marker: PhantomData<B>,
 }
@@ -31,10 +31,6 @@ impl<B: Backend> ZnxInfos for MatZnxDft<B> {
         self.n
     }
-    fn log_n(&self) -> usize {
-        (usize::BITS - (self.n() - 1).leading_zeros()) as _
-    }
     fn rows(&self) -> usize {
         self.rows
     }
@@ -43,18 +39,14 @@ impl<B: Backend> ZnxInfos for MatZnxDft<B> {
         self.cols
     }
-    fn limbs(&self) -> usize {
-        self.limbs
-    }
-    fn poly_count(&self) -> usize {
-        self.rows * self.cols * self.limbs
-    }
+    fn size(&self) -> usize {
+        self.size
+    }
 }
 
 impl MatZnxDft<FFT64> {
-    fn new(module: &Module<FFT64>, rows: usize, cols: usize, limbs: usize) -> MatZnxDft<FFT64> {
-        let mut data: Vec<u8> = alloc_aligned::<u8>(module.bytes_of_mat_znx_dft(rows, cols, limbs));
+    fn new(module: &Module<FFT64>, rows: usize, cols: usize, size: usize) -> MatZnxDft<FFT64> {
+        let mut data: Vec<u8> = alloc_aligned::<u8>(module.bytes_of_mat_znx_dft(rows, cols, size));
         let ptr: *mut u8 = data.as_mut_ptr();
         MatZnxDft::<FFT64> {
             data: data,
@@ -62,7 +54,7 @@ impl MatZnxDft<FFT64> {
             n: module.n(),
             rows: rows,
             cols: cols,
-            limbs: limbs,
+            size: size,
             _marker: PhantomData,
         }
     }
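For orientation, the renamed field ties the accessors together: a [MatZnxDft] with `rows` rows and `cols` columns stores `size` small polynomials per entry, so the removed `poly_count` helper amounted to rows * cols * size. A minimal sketch of that relation under the new naming (not part of the commit; assumes the ZnxInfos trait is in scope):

// Sketch only: restates the removed poly_count() body with the new name.
fn total_poly_count<B: Backend>(m: &MatZnxDft<B>) -> usize {
    m.rows() * m.cols() * m.size()
}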
@@ -115,7 +107,7 @@ impl MatZnxDft<FFT64> {
     fn at_block(&self, row: usize, col: usize, blk: usize) -> &[f64] {
         let nrows: usize = self.rows();
-        let nsize: usize = self.limbs();
+        let nsize: usize = self.size();
         if col == (nsize - 1) && (nsize & 1 == 1) {
             &self.raw()[blk * nrows * nsize * 8 + col * nrows * 8 + row * 8..]
         } else {
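To make the indexing above concrete, a worked instance of the first branch (hypothetical numbers):

// With nrows = 4, nsize = 3, blk = 5, row = 1, col = 2:
// col == nsize - 1 and nsize is odd, so the slice starts at byte offset
//   blk * nrows * nsize * 8 + col * nrows * 8 + row * 8
// = 5 * 4 * 3 * 8 + 2 * 4 * 8 + 1 * 8
// = 480 + 64 + 8 = 552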
@@ -127,7 +119,7 @@ impl MatZnxDft<FFT64> {
 /// This trait implements methods for the vector-matrix product,
 /// that is, multiplying a [VecZnx] with a [VmpPMat].
 pub trait MatZnxDftOps<B: Backend> {
-    fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, limbs: usize) -> usize;
+    fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> usize;
 
     /// Allocates a new [VmpPMat] with the given number of rows and columns.
     ///
@@ -135,7 +127,7 @@ pub trait MatZnxDftOps<B: Backend> {
     ///
     /// * `rows`: number of rows (number of [VecZnxDft]).
     /// * `size`: number of small polynomials in each [VecZnxDft].
-    fn new_mat_znx_dft(&self, rows: usize, cols: usize, limbs: usize) -> MatZnxDft<B>;
+    fn new_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> MatZnxDft<B>;
 
     /// Returns the number of bytes needed as scratch space for [VmpPMatOps::vmp_prepare_contiguous].
     ///
@@ -351,12 +343,12 @@ pub trait MatZnxDftOps<B: Backend> {
 }
 
 impl MatZnxDftOps<FFT64> for Module<FFT64> {
-    fn new_mat_znx_dft(&self, rows: usize, cols: usize, limbs: usize) -> MatZnxDft<FFT64> {
-        MatZnxDft::<FFT64>::new(self, rows, cols, limbs)
+    fn new_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> MatZnxDft<FFT64> {
+        MatZnxDft::<FFT64>::new(self, rows, cols, size)
     }
-    fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, limbs: usize) -> usize {
-        unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, (limbs * cols) as u64) as usize }
+    fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> usize {
+        unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, (size * cols) as u64) as usize }
     }
     fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize {
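Both implementations flatten the (cols, size) axes: the backend vmp_pmat only knows a row count and a column count, so the wrapper hands it size * cols as the column dimension. A hedged illustration (hypothetical numbers; `module: Module<FFT64>` is an assumed binding, not from the diff):

// Allocates a rows x (size * cols) = 4 x 6 vmp_pmat under the hood.
let (rows, cols, size) = (4usize, 2usize, 3usize);
let mut pmat: MatZnxDft<FFT64> = module.new_mat_znx_dft(rows, cols, size);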
@@ -367,7 +359,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.len(), b.n() * b.poly_count());
-            assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.limbs()));
+            assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.size()));
             assert_alignement(tmp_bytes.as_ptr());
         }
         unsafe {
@@ -376,7 +368,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
                 b.as_mut_ptr() as *mut vmp_pmat_t,
                 a.as_ptr(),
                 b.rows() as u64,
-                (b.limbs() * b.cols()) as u64,
+                (b.size() * b.cols()) as u64,
                 tmp_bytes.as_mut_ptr(),
             );
         }
@@ -385,8 +377,8 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
     fn vmp_prepare_row(&self, b: &mut MatZnxDft<FFT64>, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]) {
         #[cfg(debug_assertions)]
         {
-            assert_eq!(a.len(), b.limbs() * self.n() * b.cols());
-            assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.limbs()));
+            assert_eq!(a.len(), b.size() * self.n() * b.cols());
+            assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.size()));
             assert_alignement(tmp_bytes.as_ptr());
         }
         unsafe {
@@ -396,7 +388,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
                 a.as_ptr(),
                 row_i as u64,
                 b.rows() as u64,
-                (b.limbs() * b.cols()) as u64,
+                (b.size() * b.cols()) as u64,
                 tmp_bytes.as_mut_ptr(),
             );
         }
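A minimal usage sketch for row preparation, reusing the hypothetical `module` and `pmat` bindings from the note above; `row_coeffs: &[i64]` is likewise assumed, with the length the assert checks (size * n * cols):

// Scratch buffer sized and aligned as the debug asserts require.
let mut tmp: Vec<u8> =
    alloc_aligned::<u8>(module.vmp_prepare_tmp_bytes(pmat.rows(), pmat.cols(), pmat.size()));
// Encode row 0 of the prepared matrix from its coefficient representation.
module.vmp_prepare_row(&mut pmat, row_coeffs, 0, &mut tmp);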
@@ -406,7 +398,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), b.n());
-            assert_eq!(a.limbs(), b.limbs());
+            assert_eq!(a.size(), b.size());
             assert_eq!(a.cols(), b.cols());
         }
         unsafe {
@@ -416,7 +408,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
                 a.as_ptr() as *const vmp_pmat_t,
                 row_i as u64,
                 a.rows() as u64,
-                (a.limbs() * a.cols()) as u64,
+                (a.size() * a.cols()) as u64,
             );
         }
     }
@@ -425,7 +417,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), b.n());
-            assert_eq!(a.limbs(), b.limbs());
+            assert_eq!(a.size(), b.size());
         }
         unsafe {
             vmp::vmp_prepare_row_dft(
@@ -434,7 +426,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
                 a.ptr as *const vec_znx_dft_t,
                 row_i as u64,
                 b.rows() as u64,
-                b.limbs() as u64,
+                b.size() as u64,
             );
         }
     }
@@ -443,7 +435,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), b.n());
-            assert_eq!(a.limbs(), b.limbs());
+            assert_eq!(a.size(), b.size());
         }
         unsafe {
             vmp::vmp_extract_row_dft(
@@ -452,7 +444,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
                 a.as_ptr() as *const vmp_pmat_t,
                 row_i as u64,
                 a.rows() as u64,
-                a.limbs() as u64,
+                a.size() as u64,
             );
         }
     }
@@ -470,7 +462,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
     }
 
     fn vmp_apply_dft(&self, c: &mut VecZnxDft<FFT64>, a: &VecZnx, b: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
-        debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs()));
+        debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size()));
         #[cfg(debug_assertions)]
         {
             assert_alignement(tmp_bytes.as_ptr());
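The debug_assert spells out the scratch contract for the apply family: tmp_bytes is sized from the output size, the input size, and the matrix's rows and size. A hedged sketch (`c`, `a`, `b`, and `module` are assumed bindings, not from the diff):

// Size the scratch from the same quadruple the assert checks.
let mut tmp: Vec<u8> =
    alloc_aligned::<u8>(module.vmp_apply_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size()));
// c <- a x b, with a taken to the DFT domain internally.
module.vmp_apply_dft(&mut c, &a, &b, &mut tmp);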
@@ -479,20 +471,20 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
             vmp::vmp_apply_dft(
                 self.ptr,
                 c.ptr as *mut vec_znx_dft_t,
-                c.limbs() as u64,
+                c.size() as u64,
                 a.as_ptr(),
-                a.limbs() as u64,
+                a.size() as u64,
                 (a.n() * a.cols()) as u64,
                 b.as_ptr() as *const vmp_pmat_t,
                 b.rows() as u64,
-                b.limbs() as u64,
+                b.size() as u64,
                 tmp_bytes.as_mut_ptr(),
             )
         }
     }
 
     fn vmp_apply_dft_add(&self, c: &mut VecZnxDft<FFT64>, a: &VecZnx, b: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
-        debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs()));
+        debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size()));
         #[cfg(debug_assertions)]
         {
             assert_alignement(tmp_bytes.as_ptr());
@@ -501,13 +493,13 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
             vmp::vmp_apply_dft_add(
                 self.ptr,
                 c.ptr as *mut vec_znx_dft_t,
-                c.limbs() as u64,
+                c.size() as u64,
                 a.as_ptr(),
-                a.limbs() as u64,
-                (a.n() * a.limbs()) as u64,
+                a.size() as u64,
+                (a.n() * a.size()) as u64,
                 b.as_ptr() as *const vmp_pmat_t,
                 b.rows() as u64,
-                b.limbs() as u64,
+                b.size() as u64,
                 tmp_bytes.as_mut_ptr(),
             )
         }
@@ -526,7 +518,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
     }
 
     fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft<FFT64>, a: &VecZnxDft<FFT64>, b: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
-        debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs()));
+        debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size()));
         #[cfg(debug_assertions)]
         {
             assert_alignement(tmp_bytes.as_ptr());
@@ -535,12 +527,12 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
             vmp::vmp_apply_dft_to_dft(
                 self.ptr,
                 c.ptr as *mut vec_znx_dft_t,
-                c.limbs() as u64,
+                c.size() as u64,
                 a.ptr as *const vec_znx_dft_t,
-                a.limbs() as u64,
+                a.size() as u64,
                 b.as_ptr() as *const vmp_pmat_t,
                 b.rows() as u64,
-                b.limbs() as u64,
+                b.size() as u64,
                 tmp_bytes.as_mut_ptr(),
             )
         }
@@ -553,7 +545,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
         b: &MatZnxDft<FFT64>,
         tmp_bytes: &mut [u8],
     ) {
-        debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs()));
+        debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size()));
         #[cfg(debug_assertions)]
         {
             assert_alignement(tmp_bytes.as_ptr());
@@ -562,19 +554,19 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
             vmp::vmp_apply_dft_to_dft_add(
                 self.ptr,
                 c.ptr as *mut vec_znx_dft_t,
-                c.limbs() as u64,
+                c.size() as u64,
                 a.ptr as *const vec_znx_dft_t,
-                a.limbs() as u64,
+                a.size() as u64,
                 b.as_ptr() as *const vmp_pmat_t,
                 b.rows() as u64,
-                b.limbs() as u64,
+                b.size() as u64,
                 tmp_bytes.as_mut_ptr(),
             )
         }
     }
 
     fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft<FFT64>, a: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
-        debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(b.limbs(), b.limbs(), a.rows(), a.limbs()));
+        debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(b.size(), b.size(), a.rows(), a.size()));
         #[cfg(debug_assertions)]
         {
             assert_alignement(tmp_bytes.as_ptr());
@@ -583,12 +575,12 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
             vmp::vmp_apply_dft_to_dft(
                 self.ptr,
                 b.ptr as *mut vec_znx_dft_t,
-                b.limbs() as u64,
+                b.size() as u64,
                 b.ptr as *mut vec_znx_dft_t,
-                b.limbs() as u64,
+                b.size() as u64,
                 a.as_ptr() as *const vmp_pmat_t,
                 a.rows() as u64,
-                a.limbs() as u64,
+                a.size() as u64,
                 tmp_bytes.as_mut_ptr(),
            )
        }
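One detail worth flagging: the in-place variant has no dedicated backend entry point; as the hunk above shows, it calls vmp_apply_dft_to_dft with b.ptr passed as both the result and the input vector, so it relies on the backend accepting an aliased res == a. A hedged call-site sketch (`module`, `b_vec`, `pmat`, and `tmp` are assumed bindings):

// b_vec is read as the input vector and overwritten with the product.
module.vmp_apply_dft_to_dft_inplace(&mut b_vec, &pmat, &mut tmp);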