use crate::{ alloc_aligned, hal::{ api::{DataView, DataViewMut, FillUniform, Reset, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero}, layouts::{Data, DataMut, DataRef, ReaderFrom, ToOwnedDeep, VecZnx, WriterTo}, }, }; use std::fmt; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use rand::RngCore; use sampling::source::Source; #[derive(PartialEq, Eq, Clone)] pub struct MatZnx { data: D, n: usize, size: usize, rows: usize, cols_in: usize, cols_out: usize, } impl ToOwnedDeep for MatZnx { type Owned = MatZnx>; fn to_owned_deep(&self) -> Self::Owned { MatZnx { data: self.data.as_ref().to_vec(), n: self.n, size: self.size, rows: self.rows, cols_in: self.cols_in, cols_out: self.cols_out, } } } impl fmt::Debug for MatZnx { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", self) } } impl ZnxInfos for MatZnx { fn cols(&self) -> usize { self.cols_in } fn rows(&self) -> usize { self.rows } fn n(&self) -> usize { self.n } fn size(&self) -> usize { self.size } } impl ZnxSliceSize for MatZnx { fn sl(&self) -> usize { self.n() * self.cols_out() } } impl DataView for MatZnx { type D = D; fn data(&self) -> &Self::D { &self.data } } impl DataViewMut for MatZnx { fn data_mut(&mut self) -> &mut Self::D { &mut self.data } } impl ZnxView for MatZnx { type Scalar = i64; } impl MatZnx { pub fn cols_in(&self) -> usize { self.cols_in } pub fn cols_out(&self) -> usize { self.cols_out } } impl MatZnx> { pub fn alloc_bytes(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { rows * cols_in * VecZnx::>::alloc_bytes(n, cols_out, size) } pub fn alloc(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { let data: Vec = alloc_aligned(Self::alloc_bytes(n, rows, cols_in, cols_out, size)); Self { data, n, size, rows, cols_in, cols_out, } } pub fn from_bytes(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); assert!(data.len() == Self::alloc_bytes(n, rows, cols_in, cols_out, size)); Self { data, n, size, rows, cols_in, cols_out, } } } impl MatZnx { pub fn at(&self, row: usize, col: usize) -> VecZnx<&[u8]> { #[cfg(debug_assertions)] { assert!(row < self.rows(), "rows: {} >= {}", row, self.rows()); assert!(col < self.cols_in(), "cols: {} >= {}", col, self.cols_in()); } let self_ref: MatZnx<&[u8]> = self.to_ref(); let nb_bytes: usize = VecZnx::>::alloc_bytes(self.n, self.cols_out, self.size); let start: usize = nb_bytes * self.cols() * row + col * nb_bytes; let end: usize = start + nb_bytes; VecZnx { data: &self_ref.data[start..end], n: self.n, cols: self.cols_out, size: self.size, max_size: self.size, } } } impl MatZnx { pub fn at_mut(&mut self, row: usize, col: usize) -> VecZnx<&mut [u8]> { #[cfg(debug_assertions)] { assert!(row < self.rows(), "rows: {} >= {}", row, self.rows()); assert!(col < self.cols_in(), "cols: {} >= {}", col, self.cols_in()); } let n: usize = self.n(); let cols_out: usize = self.cols_out(); let cols_in: usize = self.cols_in(); let size: usize = self.size(); let self_ref: MatZnx<&mut [u8]> = self.to_mut(); let nb_bytes: usize = VecZnx::>::alloc_bytes(n, cols_out, size); let start: usize = nb_bytes * cols_in * row + col * nb_bytes; let end: usize = start + nb_bytes; VecZnx { data: &mut self_ref.data[start..end], n, cols: cols_out, size, max_size: size, } } } impl FillUniform for MatZnx { fn fill_uniform(&mut self, source: &mut Source) { source.fill_bytes(self.data.as_mut()); } } impl Reset for MatZnx { fn reset(&mut self) { self.zero(); self.n = 0; self.size = 0; self.rows = 0; self.cols_in = 0; self.cols_out = 0; } } pub type MatZnxOwned = MatZnx>; pub type MatZnxMut<'a> = MatZnx<&'a mut [u8]>; pub type MatZnxRef<'a> = MatZnx<&'a [u8]>; pub trait MatZnxToRef { fn to_ref(&self) -> MatZnx<&[u8]>; } impl MatZnxToRef for MatZnx { fn to_ref(&self) -> MatZnx<&[u8]> { MatZnx { data: self.data.as_ref(), n: self.n, rows: self.rows, cols_in: self.cols_in, cols_out: self.cols_out, size: self.size, } } } pub trait MatZnxToMut { fn to_mut(&mut self) -> MatZnx<&mut [u8]>; } impl MatZnxToMut for MatZnx { fn to_mut(&mut self) -> MatZnx<&mut [u8]> { MatZnx { data: self.data.as_mut(), n: self.n, rows: self.rows, cols_in: self.cols_in, cols_out: self.cols_out, size: self.size, } } } impl MatZnx { pub(crate) fn from_data(data: D, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { Self { data, n, rows, cols_in, cols_out, size, } } } impl ReaderFrom for MatZnx { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.n = reader.read_u64::()? as usize; self.size = reader.read_u64::()? as usize; self.rows = reader.read_u64::()? as usize; self.cols_in = reader.read_u64::()? as usize; self.cols_out = reader.read_u64::()? as usize; let len: usize = reader.read_u64::()? as usize; let buf: &mut [u8] = self.data.as_mut(); if buf.len() != len { return Err(std::io::Error::new( std::io::ErrorKind::UnexpectedEof, format!("self.data.len()={} != read len={}", buf.len(), len), )); } reader.read_exact(&mut buf[..len])?; Ok(()) } } impl WriterTo for MatZnx { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { writer.write_u64::(self.n as u64)?; writer.write_u64::(self.size as u64)?; writer.write_u64::(self.rows as u64)?; writer.write_u64::(self.cols_in as u64)?; writer.write_u64::(self.cols_out as u64)?; let buf: &[u8] = self.data.as_ref(); writer.write_u64::(buf.len() as u64)?; writer.write_all(buf)?; Ok(()) } } impl fmt::Display for MatZnx { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!( f, "MatZnx(n={}, rows={}, cols_in={}, cols_out={}, size={})", self.n, self.rows, self.cols_in, self.cols_out, self.size )?; for row_i in 0..self.rows { writeln!(f, "Row {}:", row_i)?; for col_i in 0..self.cols_in { writeln!(f, "cols_in {}:", col_i)?; writeln!(f, "{}:", self.at(row_i, col_i))?; } } Ok(()) } } impl ZnxZero for MatZnx { fn zero(&mut self) { self.raw_mut().fill(0) } fn zero_at(&mut self, i: usize, j: usize) { self.at_mut(i, j).zero(); } }