use std::{ hash::{DefaultHasher, Hasher}, marker::PhantomData, }; use crate::{ alloc_aligned, layouts::{Backend, Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, ZnxInfos, ZnxView}, oep::VmpPMatAllocBytesImpl, }; #[repr(C)] #[derive(PartialEq, Eq, Hash)] pub struct VmpPMat { data: D, n: usize, size: usize, rows: usize, cols_in: usize, cols_out: usize, _phantom: PhantomData, } impl DigestU64 for VmpPMat { fn digest_u64(&self) -> u64 { let mut h: DefaultHasher = DefaultHasher::new(); h.write(self.data.as_ref()); h.write_usize(self.n); h.write_usize(self.size); h.write_usize(self.rows); h.write_usize(self.cols_in); h.write_usize(self.cols_out); h.finish() } } impl ZnxView for VmpPMat { type Scalar = B::ScalarPrep; } impl ZnxInfos for VmpPMat { fn cols(&self) -> usize { self.cols_in } fn rows(&self) -> usize { self.rows } fn n(&self) -> usize { self.n } fn size(&self) -> usize { self.size } fn poly_count(&self) -> usize { self.rows() * self.cols_in() * self.size() * self.cols_out() } } impl DataView for VmpPMat { type D = D; fn data(&self) -> &Self::D { &self.data } } impl DataViewMut for VmpPMat { fn data_mut(&mut self) -> &mut Self::D { &mut self.data } } impl VmpPMat { pub fn cols_in(&self) -> usize { self.cols_in } pub fn cols_out(&self) -> usize { self.cols_out } } impl>, B: Backend> VmpPMat where B: VmpPMatAllocBytesImpl, { pub fn alloc(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { let data: Vec = alloc_aligned(B::vmp_pmat_bytes_of_impl(n, rows, cols_in, cols_out, size)); Self { data: data.into(), n, size, rows, cols_in, cols_out, _phantom: PhantomData, } } pub fn from_bytes(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); assert!(data.len() == B::vmp_pmat_bytes_of_impl(n, rows, cols_in, cols_out, size)); Self { data: data.into(), n, size, rows, cols_in, cols_out, _phantom: PhantomData, } } } pub type VmpPMatOwned = VmpPMat, B>; pub type VmpPMatRef<'a, B> = VmpPMat<&'a [u8], B>; pub trait VmpPMatToRef { fn to_ref(&self) -> VmpPMat<&[u8], B>; } impl VmpPMatToRef for VmpPMat { fn to_ref(&self) -> VmpPMat<&[u8], B> { VmpPMat { data: self.data.as_ref(), n: self.n, rows: self.rows, cols_in: self.cols_in, cols_out: self.cols_out, size: self.size, _phantom: std::marker::PhantomData, } } } pub trait VmpPMatToMut { fn to_mut(&mut self) -> VmpPMat<&mut [u8], B>; } impl VmpPMatToMut for VmpPMat { fn to_mut(&mut self) -> VmpPMat<&mut [u8], B> { VmpPMat { data: self.data.as_mut(), n: self.n, rows: self.rows, cols_in: self.cols_in, cols_out: self.cols_out, size: self.size, _phantom: std::marker::PhantomData, } } } impl VmpPMat { pub fn from_data(data: D, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { Self { data, n, rows, cols_in, cols_out, size, _phantom: PhantomData, } } }