mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
Applied discussed changes, everything working, but still to discuss
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize};
|
||||
use crate::znx_base::{GetZnxBase, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize};
|
||||
use crate::{Backend, FFT64, Module, alloc_aligned};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
@@ -10,6 +10,8 @@ use std::marker::PhantomData;
|
||||
/// See the trait [MatZnxDftOps] for additional information.
|
||||
pub struct MatZnxDft<B: Backend> {
|
||||
pub inner: ZnxBase,
|
||||
pub cols_in: usize,
|
||||
pub cols_out: usize,
|
||||
_marker: PhantomData<B>,
|
||||
}
|
||||
|
||||
@@ -35,18 +37,54 @@ impl ZnxLayout for MatZnxDft<FFT64> {
|
||||
type Scalar = f64;
|
||||
}
|
||||
|
||||
impl<B: Backend> ZnxAlloc<B> for MatZnxDft<B> {
|
||||
type Scalar = u8;
|
||||
impl<B: Backend> MatZnxDft<B> {
|
||||
pub fn new(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
|
||||
let bytes: Vec<u8> = alloc_aligned(Self::bytes_of(module, rows, cols_in, cols_out, size));
|
||||
Self::from_bytes(module, rows, cols_in, cols_out, size, bytes)
|
||||
}
|
||||
|
||||
fn from_bytes_borrow(module: &Module<B>, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
|
||||
pub fn from_bytes(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize, mut bytes: Vec<u8>) -> Self {
|
||||
let mut mat: MatZnxDft<B> = Self::from_bytes_borrow(module, rows, cols_in, cols_out, size, &mut bytes);
|
||||
mat.znx_mut().data = bytes;
|
||||
mat
|
||||
}
|
||||
|
||||
pub fn from_bytes_borrow(
|
||||
module: &Module<B>,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
bytes: &mut [u8],
|
||||
) -> Self {
|
||||
debug_assert_eq!(
|
||||
bytes.len(),
|
||||
Self::bytes_of(module, rows, cols_in, cols_out, size)
|
||||
);
|
||||
Self {
|
||||
inner: ZnxBase::from_bytes_borrow(module.n(), rows, cols, size, bytes),
|
||||
inner: ZnxBase::from_bytes_borrow(module.n(), rows, cols_out, size, bytes),
|
||||
cols_in: cols_in,
|
||||
cols_out: cols_out,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
fn bytes_of(module: &Module<B>, rows: usize, cols: usize, size: usize) -> usize {
|
||||
unsafe { crate::ffi::vmp::bytes_of_vmp_pmat(module.ptr, rows as u64, size as u64) as usize * cols }
|
||||
pub fn bytes_of(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
unsafe {
|
||||
crate::ffi::vmp::bytes_of_vmp_pmat(
|
||||
module.ptr,
|
||||
(rows * cols_in) as u64,
|
||||
(size * cols_out) as u64,
|
||||
) as usize
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cols_in(&self) -> usize {
|
||||
self.cols_in
|
||||
}
|
||||
|
||||
pub fn cols_out(&self) -> usize {
|
||||
self.cols_out
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user