everything compiles. Scratchpad not yet implemented

This commit is contained in:
Janmajaya Mall
2025-05-03 16:37:20 +05:30
parent 3ed6fa8ab5
commit ff8370e023
19 changed files with 919 additions and 504 deletions

View File

@@ -1,5 +1,5 @@
use crate::znx_base::{GetZnxBase, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize};
use crate::{Backend, FFT64, Module, alloc_aligned};
use crate::znx_base::{GetZnxBase, ZnxBase, ZnxInfos};
use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxView, alloc_aligned};
use std::marker::PhantomData;
/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
@@ -8,68 +8,67 @@ use std::marker::PhantomData;
///
/// [MatZnxDft] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [MatZnxDft].
/// See the trait [MatZnxDftOps] for additional information.
pub struct MatZnxDft<B: Backend> {
pub inner: ZnxBase,
pub cols_in: usize,
pub cols_out: usize,
pub struct MatZnxDft<D, B> {
data: D,
n: usize,
size: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
_marker: PhantomData<B>,
}
impl<B: Backend> GetZnxBase for MatZnxDft<B> {
fn znx(&self) -> &ZnxBase {
&self.inner
impl<D, B> ZnxInfos for MatZnxDft<D, B> {
fn cols(&self) -> usize {
self.cols_in
}
fn znx_mut(&mut self) -> &mut ZnxBase {
&mut self.inner
fn rows(&self) -> usize {
self.rows
}
}
impl<B: Backend> ZnxInfos for MatZnxDft<B> {}
fn n(&self) -> usize {
self.n
}
fn size(&self) -> usize {
self.size
}
impl ZnxSliceSize for MatZnxDft<FFT64> {
fn sl(&self) -> usize {
self.n()
}
}
impl ZnxLayout for MatZnxDft<FFT64> {
impl<D, B> DataView for MatZnxDft<D, B> {
type D = D;
fn data(&self) -> &Self::D {
&self.data
}
}
impl<D, B> DataViewMut for MatZnxDft<D, B> {
fn data_mut(&mut self) -> &mut Self::D {
&mut self.data
}
}
impl<D: AsRef<[u8]>> ZnxView for MatZnxDft<D, FFT64> {
type Scalar = f64;
}
impl<B: Backend> MatZnxDft<B> {
pub fn new(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
let bytes: Vec<u8> = alloc_aligned(Self::bytes_of(module, rows, cols_in, cols_out, size));
Self::from_bytes(module, rows, cols_in, cols_out, size, bytes)
impl<D, B> MatZnxDft<D, B> {
pub(crate) fn cols_in(&self) -> usize {
self.cols_in
}
pub fn from_bytes(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize, mut bytes: Vec<u8>) -> Self {
let mut mat: MatZnxDft<B> = Self::from_bytes_borrow(module, rows, cols_in, cols_out, size, &mut bytes);
mat.znx_mut().data = bytes;
mat
pub(crate) fn cols_out(&self) -> usize {
self.cols_out
}
}
pub fn from_bytes_borrow(
module: &Module<B>,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
bytes: &mut [u8],
) -> Self {
debug_assert_eq!(
bytes.len(),
Self::bytes_of(module, rows, cols_in, cols_out, size)
);
Self {
inner: ZnxBase::from_bytes_borrow(module.n(), rows, cols_out, size, bytes),
cols_in: cols_in,
cols_out: cols_out,
_marker: PhantomData,
}
}
pub fn bytes_of(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
impl<D: From<Vec<u8>>, B: Backend> MatZnxDft<D, B> {
pub(crate) fn bytes_of(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
unsafe {
crate::ffi::vmp::bytes_of_vmp_pmat(
module.ptr,
@@ -79,16 +78,62 @@ impl<B: Backend> MatZnxDft<B> {
}
}
pub fn cols_in(&self) -> usize {
self.cols_in
pub(crate) fn new(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
let data: Vec<u8> = alloc_aligned(Self::bytes_of(module, rows, cols_in, cols_out, size));
Self {
data: data.into(),
n: module.n(),
size,
rows,
cols_in,
cols_out,
_marker: PhantomData,
}
}
pub fn cols_out(&self) -> usize {
self.cols_out
pub(crate) fn new_from_bytes(
module: &Module<B>,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
bytes: impl Into<Vec<u8>>,
) -> Self {
let data: Vec<u8> = bytes.into();
assert!(data.len() == Self::bytes_of(module, rows, cols_in, cols_out, size));
Self {
data: data.into(),
n: module.n(),
size,
rows,
cols_in,
cols_out,
_marker: PhantomData,
}
}
// pub fn from_bytes_borrow(
// module: &Module<B>,
// rows: usize,
// cols_in: usize,
// cols_out: usize,
// size: usize,
// bytes: &mut [u8],
// ) -> Self {
// debug_assert_eq!(
// bytes.len(),
// Self::bytes_of(module, rows, cols_in, cols_out, size)
// );
// Self {
// inner: ZnxBase::from_bytes_borrow(module.n(), rows, cols_out, size, bytes),
// cols_in: cols_in,
// cols_out: cols_out,
// _marker: PhantomData,
// }
// }
}
impl MatZnxDft<FFT64> {
impl<D: AsRef<[u8]>> MatZnxDft<D, FFT64> {
/// Returns a copy of the backend array at index (i, j) of the [MatZnxDft].
///
/// # Arguments
@@ -123,3 +168,5 @@ impl MatZnxDft<FFT64> {
}
}
}
pub type MatZnxDftAllocOwned<B> = MatZnxDft<Vec<u8>, B>;