mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
everything compiles. Scratchpad not yet implemented
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
use crate::znx_base::{GetZnxBase, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize};
|
||||
use crate::{Backend, FFT64, Module, alloc_aligned};
|
||||
use crate::znx_base::{GetZnxBase, ZnxBase, ZnxInfos};
|
||||
use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxView, alloc_aligned};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
|
||||
@@ -8,68 +8,67 @@ use std::marker::PhantomData;
|
||||
///
|
||||
/// [MatZnxDft] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [MatZnxDft].
|
||||
/// See the trait [MatZnxDftOps] for additional information.
|
||||
pub struct MatZnxDft<B: Backend> {
|
||||
pub inner: ZnxBase,
|
||||
pub cols_in: usize,
|
||||
pub cols_out: usize,
|
||||
pub struct MatZnxDft<D, B> {
|
||||
data: D,
|
||||
n: usize,
|
||||
size: usize,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
_marker: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> GetZnxBase for MatZnxDft<B> {
|
||||
fn znx(&self) -> &ZnxBase {
|
||||
&self.inner
|
||||
impl<D, B> ZnxInfos for MatZnxDft<D, B> {
|
||||
fn cols(&self) -> usize {
|
||||
self.cols_in
|
||||
}
|
||||
|
||||
fn znx_mut(&mut self) -> &mut ZnxBase {
|
||||
&mut self.inner
|
||||
fn rows(&self) -> usize {
|
||||
self.rows
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ZnxInfos for MatZnxDft<B> {}
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
|
||||
impl ZnxSliceSize for MatZnxDft<FFT64> {
|
||||
fn sl(&self) -> usize {
|
||||
self.n()
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxLayout for MatZnxDft<FFT64> {
|
||||
impl<D, B> DataView for MatZnxDft<D, B> {
|
||||
type D = D;
|
||||
fn data(&self) -> &Self::D {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, B> DataViewMut for MatZnxDft<D, B> {
|
||||
fn data_mut(&mut self) -> &mut Self::D {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]>> ZnxView for MatZnxDft<D, FFT64> {
|
||||
type Scalar = f64;
|
||||
}
|
||||
|
||||
impl<B: Backend> MatZnxDft<B> {
|
||||
pub fn new(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
|
||||
let bytes: Vec<u8> = alloc_aligned(Self::bytes_of(module, rows, cols_in, cols_out, size));
|
||||
Self::from_bytes(module, rows, cols_in, cols_out, size, bytes)
|
||||
impl<D, B> MatZnxDft<D, B> {
|
||||
pub(crate) fn cols_in(&self) -> usize {
|
||||
self.cols_in
|
||||
}
|
||||
|
||||
pub fn from_bytes(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize, mut bytes: Vec<u8>) -> Self {
|
||||
let mut mat: MatZnxDft<B> = Self::from_bytes_borrow(module, rows, cols_in, cols_out, size, &mut bytes);
|
||||
mat.znx_mut().data = bytes;
|
||||
mat
|
||||
pub(crate) fn cols_out(&self) -> usize {
|
||||
self.cols_out
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_bytes_borrow(
|
||||
module: &Module<B>,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
bytes: &mut [u8],
|
||||
) -> Self {
|
||||
debug_assert_eq!(
|
||||
bytes.len(),
|
||||
Self::bytes_of(module, rows, cols_in, cols_out, size)
|
||||
);
|
||||
Self {
|
||||
inner: ZnxBase::from_bytes_borrow(module.n(), rows, cols_out, size, bytes),
|
||||
cols_in: cols_in,
|
||||
cols_out: cols_out,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bytes_of(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
impl<D: From<Vec<u8>>, B: Backend> MatZnxDft<D, B> {
|
||||
pub(crate) fn bytes_of(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
unsafe {
|
||||
crate::ffi::vmp::bytes_of_vmp_pmat(
|
||||
module.ptr,
|
||||
@@ -79,16 +78,62 @@ impl<B: Backend> MatZnxDft<B> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cols_in(&self) -> usize {
|
||||
self.cols_in
|
||||
pub(crate) fn new(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
|
||||
let data: Vec<u8> = alloc_aligned(Self::bytes_of(module, rows, cols_in, cols_out, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n: module.n(),
|
||||
size,
|
||||
rows,
|
||||
cols_in,
|
||||
cols_out,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cols_out(&self) -> usize {
|
||||
self.cols_out
|
||||
pub(crate) fn new_from_bytes(
|
||||
module: &Module<B>,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
bytes: impl Into<Vec<u8>>,
|
||||
) -> Self {
|
||||
let data: Vec<u8> = bytes.into();
|
||||
assert!(data.len() == Self::bytes_of(module, rows, cols_in, cols_out, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n: module.n(),
|
||||
size,
|
||||
rows,
|
||||
cols_in,
|
||||
cols_out,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
// pub fn from_bytes_borrow(
|
||||
// module: &Module<B>,
|
||||
// rows: usize,
|
||||
// cols_in: usize,
|
||||
// cols_out: usize,
|
||||
// size: usize,
|
||||
// bytes: &mut [u8],
|
||||
// ) -> Self {
|
||||
// debug_assert_eq!(
|
||||
// bytes.len(),
|
||||
// Self::bytes_of(module, rows, cols_in, cols_out, size)
|
||||
// );
|
||||
// Self {
|
||||
// inner: ZnxBase::from_bytes_borrow(module.n(), rows, cols_out, size, bytes),
|
||||
// cols_in: cols_in,
|
||||
// cols_out: cols_out,
|
||||
// _marker: PhantomData,
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
impl MatZnxDft<FFT64> {
|
||||
impl<D: AsRef<[u8]>> MatZnxDft<D, FFT64> {
|
||||
/// Returns a copy of the backend array at index (i, j) of the [MatZnxDft].
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -123,3 +168,5 @@ impl MatZnxDft<FFT64> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type MatZnxDftAllocOwned<B> = MatZnxDft<Vec<u8>, B>;
|
||||
|
||||
Reference in New Issue
Block a user