mirror of https://github.com/arnaucube/poulpy.git

Commit: refactoring of vec_znx
@@ -22,7 +22,7 @@ pub struct MatZnxDft<B: Backend> {
     /// Number of cols
     cols: usize,
     /// The number of small polynomials
-    limbs: usize,
+    size: usize,
     _marker: PhantomData<B>,
 }
 
@@ -31,10 +31,6 @@ impl<B: Backend> ZnxInfos for MatZnxDft<B> {
         self.n
     }
 
-    fn log_n(&self) -> usize {
-        (usize::BITS - (self.n() - 1).leading_zeros()) as _
-    }
-
     fn rows(&self) -> usize {
         self.rows
     }
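Aside: the removed log_n() is the usual bit trick for ceil(log2(n)), which equals log2(n) exactly here because ring degrees are powers of two. A standalone sketch of the same computation (illustration only, not part of the commit; after the refactor the accessor presumably lives elsewhere, e.g. as a provided method of ZnxInfos):

// Standalone sketch of the removed helper: ceil(log2(n)) via leading_zeros,
// which equals log2(n) exactly when n is a power of two.
fn log_n(n: usize) -> usize {
    debug_assert!(n.is_power_of_two());
    (usize::BITS - (n - 1).leading_zeros()) as usize
}

// log_n(1) == 0, log_n(1024) == 10.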
@@ -43,18 +39,14 @@ impl<B: Backend> ZnxInfos for MatZnxDft<B> {
         self.cols
     }
 
-    fn limbs(&self) -> usize {
-        self.limbs
-    }
-
-    fn poly_count(&self) -> usize {
-        self.rows * self.cols * self.limbs
+    fn size(&self) -> usize {
+        self.size
     }
 }
 
 impl MatZnxDft<FFT64> {
-    fn new(module: &Module<FFT64>, rows: usize, cols: usize, limbs: usize) -> MatZnxDft<FFT64> {
-        let mut data: Vec<u8> = alloc_aligned::<u8>(module.bytes_of_mat_znx_dft(rows, cols, limbs));
+    fn new(module: &Module<FFT64>, rows: usize, cols: usize, size: usize) -> MatZnxDft<FFT64> {
+        let mut data: Vec<u8> = alloc_aligned::<u8>(module.bytes_of_mat_znx_dft(rows, cols, size));
         let ptr: *mut u8 = data.as_mut_ptr();
         MatZnxDft::<FFT64> {
             data: data,
@@ -62,7 +54,7 @@ impl MatZnxDft<FFT64> {
             n: module.n(),
             rows: rows,
             cols: cols,
-            limbs: limbs,
+            size: size,
             _marker: PhantomData,
         }
     }
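Aside: under the new naming, `size` (formerly `limbs`) is the number of small polynomials per VecZnxDft entry, and the constructor sizes one aligned byte buffer for the whole rows x cols matrix in a single allocation. A hypothetical usage sketch through the public trait method (the dimensions are illustrative assumptions, not taken from the commit):

// Hypothetical usage of the refactored constructor via MatZnxDftOps;
// 4 rows x 2 cols, each entry a VecZnxDft made of 6 small polynomials.
fn alloc_example(module: &Module<FFT64>) -> MatZnxDft<FFT64> {
    let (rows, cols, size) = (4, 2, 6); // `size` was called `limbs` before
    module.new_mat_znx_dft(rows, cols, size)
}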
@@ -115,7 +107,7 @@ impl MatZnxDft<FFT64> {
 
     fn at_block(&self, row: usize, col: usize, blk: usize) -> &[f64] {
         let nrows: usize = self.rows();
-        let nsize: usize = self.limbs();
+        let nsize: usize = self.size();
         if col == (nsize - 1) && (nsize & 1 == 1) {
             &self.raw()[blk * nrows * nsize * 8 + col * nrows * 8 + row * 8..]
         } else {
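Aside: the factor of 8 in the offset arithmetic appears to be the block width, in f64 values, of the backend's vmp_pmat layout, and the branch shown handles an odd trailing column that gets its own stride (the even-column branch falls outside this hunk). The visible index computation, spelled out as a sketch:

// Offset arithmetic of the visible branch only (sketch; the even-column
// branch is not shown in this hunk and is omitted here).
fn odd_tail_block_offset(row: usize, col: usize, blk: usize, nrows: usize, nsize: usize) -> usize {
    blk * nrows * nsize * 8 // skip `blk` full 8-wide blocks of the matrix
        + col * nrows * 8   // skip the columns packed before the odd tail
        + row * 8           // select the row's block within the column
}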
@@ -127,7 +119,7 @@ impl MatZnxDft<FFT64> {
 /// This trait implements methods for vector matrix product,
 /// that is, multiplying a [VecZnx] with a [VmpPMat].
 pub trait MatZnxDftOps<B: Backend> {
-    fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, limbs: usize) -> usize;
+    fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> usize;
 
     /// Allocates a new [VmpPMat] with the given number of rows and columns.
     ///
@@ -135,7 +127,7 @@ pub trait MatZnxDftOps<B: Backend> {
     ///
     /// * `rows`: number of rows (number of [VecZnxDft]).
     /// * `size`: the size (number of small polynomials) of each [VecZnxDft].
-    fn new_mat_znx_dft(&self, rows: usize, cols: usize, limbs: usize) -> MatZnxDft<B>;
+    fn new_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> MatZnxDft<B>;
 
     /// Returns the number of bytes needed as scratch space for [VmpPMatOps::vmp_prepare_contiguous].
     ///
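For reference, the three dimensions after the rename (informal summary; the doc comments above remain the authoritative description):

// Dimension glossary (informal):
//   rows -> number of VecZnxDft entries per column of the matrix
//   cols -> number of columns
//   size -> small polynomials (formerly "limbs") per VecZnxDft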
@@ -351,12 +343,12 @@ pub trait MatZnxDftOps<B: Backend> {
 }
 
 impl MatZnxDftOps<FFT64> for Module<FFT64> {
-    fn new_mat_znx_dft(&self, rows: usize, cols: usize, limbs: usize) -> MatZnxDft<FFT64> {
-        MatZnxDft::<FFT64>::new(self, rows, cols, limbs)
+    fn new_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> MatZnxDft<FFT64> {
+        MatZnxDft::<FFT64>::new(self, rows, cols, size)
     }
 
-    fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, limbs: usize) -> usize {
-        unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, (limbs * cols) as u64) as usize }
+    fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> usize {
+        unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, (size * cols) as u64) as usize }
    }
 
     fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize {
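Aside: bytes_of_mat_znx_dft forwards to the backend with ncols = size * cols, i.e. the native vmp_pmat only ever sees a flattened rows x (size * cols) matrix. A back-of-envelope size check (sketch; it assumes one f64 per coefficient and no padding, which the real backend layout need not satisfy):

// Rough lower bound on the buffer size; the backend may add alignment
// or block padding on top of this.
fn approx_mat_znx_dft_bytes(n: usize, rows: usize, cols: usize, size: usize) -> usize {
    n * rows * cols * size * core::mem::size_of::<f64>()
}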
@@ -367,7 +359,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.len(), b.n() * b.poly_count());
-            assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.limbs()));
+            assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.size()));
             assert_alignement(tmp_bytes.as_ptr());
         }
         unsafe {
@@ -376,7 +368,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
                 b.as_mut_ptr() as *mut vmp_pmat_t,
                 a.as_ptr(),
                 b.rows() as u64,
-                (b.limbs() * b.cols()) as u64,
+                (b.size() * b.cols()) as u64,
                 tmp_bytes.as_mut_ptr(),
             );
         }
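Aside: the caller-side pattern for the prepare call above is to query the scratch size, allocate aligned scratch, then prepare. A sketch (method names are the ones in this diff; the argument order of vmp_prepare_contiguous is inferred from the assertions in the hunks above):

// Caller-side sketch for vmp_prepare_contiguous; per the debug assertion,
// `coeffs` must hold n * poly_count() i64 coefficients.
fn prepare_example(module: &Module<FFT64>, mat: &mut MatZnxDft<FFT64>, coeffs: &[i64]) {
    let mut tmp: Vec<u8> =
        alloc_aligned::<u8>(module.vmp_prepare_tmp_bytes(mat.rows(), mat.cols(), mat.size()));
    module.vmp_prepare_contiguous(mat, coeffs, &mut tmp);
}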
@@ -385,8 +377,8 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
     fn vmp_prepare_row(&self, b: &mut MatZnxDft<FFT64>, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]) {
         #[cfg(debug_assertions)]
         {
-            assert_eq!(a.len(), b.limbs() * self.n() * b.cols());
-            assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.limbs()));
+            assert_eq!(a.len(), b.size() * self.n() * b.cols());
+            assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.size()));
             assert_alignement(tmp_bytes.as_ptr());
         }
         unsafe {
@@ -396,7 +388,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
                 a.as_ptr(),
                 row_i as u64,
                 b.rows() as u64,
-                (b.limbs() * b.cols()) as u64,
+                (b.size() * b.cols()) as u64,
                 tmp_bytes.as_mut_ptr(),
             );
         }
@@ -406,7 +398,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), b.n());
-            assert_eq!(a.limbs(), b.limbs());
+            assert_eq!(a.size(), b.size());
             assert_eq!(a.cols(), b.cols());
         }
         unsafe {
@@ -416,7 +408,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
                 a.as_ptr() as *const vmp_pmat_t,
                 row_i as u64,
                 a.rows() as u64,
-                (a.limbs() * a.cols()) as u64,
+                (a.size() * a.cols()) as u64,
             );
         }
     }
@@ -425,7 +417,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), b.n());
-            assert_eq!(a.limbs(), b.limbs());
+            assert_eq!(a.size(), b.size());
         }
         unsafe {
             vmp::vmp_prepare_row_dft(
@@ -434,7 +426,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
                 a.ptr as *const vec_znx_dft_t,
                 row_i as u64,
                 b.rows() as u64,
-                b.limbs() as u64,
+                b.size() as u64,
             );
         }
     }
@@ -443,7 +435,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), b.n());
-            assert_eq!(a.limbs(), b.limbs());
+            assert_eq!(a.size(), b.size());
         }
         unsafe {
             vmp::vmp_extract_row_dft(
@@ -452,7 +444,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
                 a.as_ptr() as *const vmp_pmat_t,
                 row_i as u64,
                 a.rows() as u64,
-                a.limbs() as u64,
+                a.size() as u64,
             );
         }
     }
@@ -470,7 +462,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
     }
 
     fn vmp_apply_dft(&self, c: &mut VecZnxDft<FFT64>, a: &VecZnx, b: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
-        debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs()));
+        debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size()));
         #[cfg(debug_assertions)]
         {
             assert_alignement(tmp_bytes.as_ptr());
@@ -479,20 +471,20 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
             vmp::vmp_apply_dft(
                 self.ptr,
                 c.ptr as *mut vec_znx_dft_t,
-                c.limbs() as u64,
+                c.size() as u64,
                 a.as_ptr(),
-                a.limbs() as u64,
+                a.size() as u64,
                 (a.n() * a.cols()) as u64,
                 b.as_ptr() as *const vmp_pmat_t,
                 b.rows() as u64,
-                b.limbs() as u64,
+                b.size() as u64,
                 tmp_bytes.as_mut_ptr(),
             )
         }
     }
 
     fn vmp_apply_dft_add(&self, c: &mut VecZnxDft<FFT64>, a: &VecZnx, b: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
-        debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs()));
+        debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size()));
         #[cfg(debug_assertions)]
         {
             assert_alignement(tmp_bytes.as_ptr());
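Aside: vmp_apply_dft takes a time-domain VecZnx `a`, multiplies it by the prepared matrix `b`, and writes a DFT-domain result into `c`; the scratch requirement is exactly the debug_assert above. A caller-side sketch (names and signatures taken from this diff):

// Caller-side sketch for vmp_apply_dft, with scratch sized by
// vmp_apply_dft_tmp_bytes as the debug assertion requires.
fn apply_example(
    module: &Module<FFT64>,
    c: &mut VecZnxDft<FFT64>,
    a: &VecZnx,
    mat: &MatZnxDft<FFT64>,
) {
    let mut tmp: Vec<u8> = alloc_aligned::<u8>(
        module.vmp_apply_dft_tmp_bytes(c.size(), a.size(), mat.rows(), mat.size()),
    );
    module.vmp_apply_dft(c, a, mat, &mut tmp);
}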
@@ -501,13 +493,13 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
             vmp::vmp_apply_dft_add(
                 self.ptr,
                 c.ptr as *mut vec_znx_dft_t,
-                c.limbs() as u64,
+                c.size() as u64,
                 a.as_ptr(),
-                a.limbs() as u64,
-                (a.n() * a.limbs()) as u64,
+                a.size() as u64,
+                (a.n() * a.size()) as u64,
                 b.as_ptr() as *const vmp_pmat_t,
                 b.rows() as u64,
-                b.limbs() as u64,
+                b.size() as u64,
                 tmp_bytes.as_mut_ptr(),
             )
         }
@@ -526,7 +518,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
     }
 
     fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft<FFT64>, a: &VecZnxDft<FFT64>, b: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
-        debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs()));
+        debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size()));
         #[cfg(debug_assertions)]
         {
             assert_alignement(tmp_bytes.as_ptr());
@@ -535,12 +527,12 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
             vmp::vmp_apply_dft_to_dft(
                 self.ptr,
                 c.ptr as *mut vec_znx_dft_t,
-                c.limbs() as u64,
+                c.size() as u64,
                 a.ptr as *const vec_znx_dft_t,
-                a.limbs() as u64,
+                a.size() as u64,
                 b.as_ptr() as *const vmp_pmat_t,
                 b.rows() as u64,
-                b.limbs() as u64,
+                b.size() as u64,
                 tmp_bytes.as_mut_ptr(),
             )
         }
@@ -553,7 +545,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
         b: &MatZnxDft<FFT64>,
         tmp_bytes: &mut [u8],
     ) {
-        debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs()));
+        debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size()));
         #[cfg(debug_assertions)]
         {
             assert_alignement(tmp_bytes.as_ptr());
@@ -562,19 +554,19 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
             vmp::vmp_apply_dft_to_dft_add(
                 self.ptr,
                 c.ptr as *mut vec_znx_dft_t,
-                c.limbs() as u64,
+                c.size() as u64,
                 a.ptr as *const vec_znx_dft_t,
-                a.limbs() as u64,
+                a.size() as u64,
                 b.as_ptr() as *const vmp_pmat_t,
                 b.rows() as u64,
-                b.limbs() as u64,
+                b.size() as u64,
                 tmp_bytes.as_mut_ptr(),
             )
         }
     }
 
     fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft<FFT64>, a: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
-        debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(b.limbs(), b.limbs(), a.rows(), a.limbs()));
+        debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(b.size(), b.size(), a.rows(), a.size()));
         #[cfg(debug_assertions)]
         {
             assert_alignement(tmp_bytes.as_ptr());
@@ -583,12 +575,12 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
             vmp::vmp_apply_dft_to_dft(
                 self.ptr,
                 b.ptr as *mut vec_znx_dft_t,
-                b.limbs() as u64,
+                b.size() as u64,
                 b.ptr as *mut vec_znx_dft_t,
-                b.limbs() as u64,
+                b.size() as u64,
                 a.as_ptr() as *const vmp_pmat_t,
                 a.rows() as u64,
-                a.limbs() as u64,
+                a.size() as u64,
                 tmp_bytes.as_mut_ptr(),
             )
         }
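Aside: the in-place variant above passes `b` as both source and destination of the underlying vmp_apply_dft_to_dft call, so the backend must tolerate that aliasing; accordingly, the scratch sizing uses b.size() for both the input and output slots. A caller-side sketch (names and signatures taken from this diff):

// Caller-side sketch for the in-place product b <- b x mat; the same
// vector is both input and output of the underlying backend call.
fn apply_inplace_example(module: &Module<FFT64>, b: &mut VecZnxDft<FFT64>, mat: &MatZnxDft<FFT64>) {
    let mut tmp: Vec<u8> = alloc_aligned::<u8>(
        module.vmp_apply_dft_to_dft_tmp_bytes(b.size(), b.size(), mat.rows(), mat.size()),
    );
    module.vmp_apply_dft_to_dft_inplace(b, mat, &mut tmp);
}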