Applied discussed changes, everything working, but still to discuss

This commit is contained in:
Jean-Philippe Bossuat
2025-05-01 10:33:19 +02:00
parent 4e6fce3458
commit ca5e6d46c9
14 changed files with 710 additions and 508 deletions

View File

@@ -1,4 +1,4 @@
use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize};
use crate::znx_base::{GetZnxBase, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize};
use crate::{Backend, FFT64, Module, alloc_aligned};
use std::marker::PhantomData;
@@ -10,6 +10,8 @@ use std::marker::PhantomData;
/// See the trait [MatZnxDftOps] for additional information.
pub struct MatZnxDft<B: Backend> {
pub inner: ZnxBase,
pub cols_in: usize,
pub cols_out: usize,
_marker: PhantomData<B>,
}
@@ -35,18 +37,54 @@ impl ZnxLayout for MatZnxDft<FFT64> {
type Scalar = f64;
}
impl<B: Backend> ZnxAlloc<B> for MatZnxDft<B> {
type Scalar = u8;
impl<B: Backend> MatZnxDft<B> {
pub fn new(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
let bytes: Vec<u8> = alloc_aligned(Self::bytes_of(module, rows, cols_in, cols_out, size));
Self::from_bytes(module, rows, cols_in, cols_out, size, bytes)
}
fn from_bytes_borrow(module: &Module<B>, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
pub fn from_bytes(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize, mut bytes: Vec<u8>) -> Self {
let mut mat: MatZnxDft<B> = Self::from_bytes_borrow(module, rows, cols_in, cols_out, size, &mut bytes);
mat.znx_mut().data = bytes;
mat
}
pub fn from_bytes_borrow(
module: &Module<B>,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
bytes: &mut [u8],
) -> Self {
debug_assert_eq!(
bytes.len(),
Self::bytes_of(module, rows, cols_in, cols_out, size)
);
Self {
inner: ZnxBase::from_bytes_borrow(module.n(), rows, cols, size, bytes),
inner: ZnxBase::from_bytes_borrow(module.n(), rows, cols_out, size, bytes),
cols_in: cols_in,
cols_out: cols_out,
_marker: PhantomData,
}
}
fn bytes_of(module: &Module<B>, rows: usize, cols: usize, size: usize) -> usize {
unsafe { crate::ffi::vmp::bytes_of_vmp_pmat(module.ptr, rows as u64, size as u64) as usize * cols }
pub fn bytes_of(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
unsafe {
crate::ffi::vmp::bytes_of_vmp_pmat(
module.ptr,
(rows * cols_in) as u64,
(size * cols_out) as u64,
) as usize
}
}
pub fn cols_in(&self) -> usize {
self.cols_in
}
pub fn cols_out(&self) -> usize {
self.cols_out
}
}