Added more serialization tests + generalize methods to any n

This commit is contained in:
Pro7ech
2025-08-13 15:28:52 +02:00
parent 068470783e
commit 940742ce6c
117 changed files with 3658 additions and 2577 deletions

View File

@@ -1,7 +1,7 @@
use crate::{
alloc_aligned,
hal::{
api::{DataView, DataViewMut, FillUniform, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
api::{DataView, DataViewMut, FillUniform, Reset, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
layouts::{Data, DataMut, DataRef, ReaderFrom, VecZnx, WriterTo},
},
};
@@ -78,15 +78,13 @@ impl<D: Data> MatZnx<D> {
}
}
impl<D: DataRef> MatZnx<D> {
pub fn bytes_of(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
rows * cols_in * VecZnx::<Vec<u8>>::alloc_bytes::<i64>(n, cols_out, size)
impl MatZnx<Vec<u8>> {
pub fn alloc_bytes(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
rows * cols_in * VecZnx::<Vec<u8>>::alloc_bytes(n, cols_out, size)
}
}
impl<D: DataRef + From<Vec<u8>>> MatZnx<D> {
pub(crate) fn alloc(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
let data: Vec<u8> = alloc_aligned(Self::bytes_of(n, rows, cols_in, cols_out, size));
pub fn alloc(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
let data: Vec<u8> = alloc_aligned(Self::alloc_bytes(n, rows, cols_in, cols_out, size));
Self {
data: data.into(),
n,
@@ -97,16 +95,9 @@ impl<D: DataRef + From<Vec<u8>>> MatZnx<D> {
}
}
pub(crate) fn from_bytes(
n: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
bytes: impl Into<Vec<u8>>,
) -> Self {
pub fn from_bytes(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
let data: Vec<u8> = bytes.into();
assert!(data.len() == Self::bytes_of(n, rows, cols_in, cols_out, size));
assert!(data.len() == Self::alloc_bytes(n, rows, cols_in, cols_out, size));
Self {
data: data.into(),
n,
@@ -127,7 +118,7 @@ impl<D: DataRef> MatZnx<D> {
}
let self_ref: MatZnx<&[u8]> = self.to_ref();
let nb_bytes: usize = VecZnx::<Vec<u8>>::alloc_bytes::<i64>(self.n, self.cols_out, self.size);
let nb_bytes: usize = VecZnx::<Vec<u8>>::alloc_bytes(self.n, self.cols_out, self.size);
let start: usize = nb_bytes * self.cols() * row + col * nb_bytes;
let end: usize = start + nb_bytes;
@@ -155,7 +146,7 @@ impl<D: DataMut> MatZnx<D> {
let size: usize = self.size();
let self_ref: MatZnx<&mut [u8]> = self.to_mut();
let nb_bytes: usize = VecZnx::<Vec<u8>>::alloc_bytes::<i64>(n, cols_out, size);
let nb_bytes: usize = VecZnx::<Vec<u8>>::alloc_bytes(n, cols_out, size);
let start: usize = nb_bytes * cols_in * row + col * nb_bytes;
let end: usize = start + nb_bytes;
@@ -175,6 +166,17 @@ impl<D: DataMut> FillUniform for MatZnx<D> {
}
}
impl<D: DataMut> Reset for MatZnx<D> {
fn reset(&mut self) {
self.zero();
self.n = 0;
self.size = 0;
self.rows = 0;
self.cols_in = 0;
self.cols_out = 0;
}
}
pub type MatZnxOwned = MatZnx<Vec<u8>>;
pub type MatZnxMut<'a> = MatZnx<&'a mut [u8]>;
pub type MatZnxRef<'a> = MatZnx<&'a [u8]>;