use std::marker::PhantomData; use rand_distr::num_traits::Zero; use crate::{ alloc_aligned, hal::{ api::{DataView, DataViewMut, ZnxInfos, ZnxView, ZnxViewMut, ZnxZero}, layouts::{Backend, Data, DataMut, DataRef}, }, }; #[derive(PartialEq, Eq)] pub struct VecZnxBig { pub(crate) data: D, pub(crate) n: usize, pub(crate) cols: usize, pub(crate) size: usize, pub(crate) max_size: usize, pub(crate) _phantom: PhantomData, } impl ZnxInfos for VecZnxBig { fn cols(&self) -> usize { self.cols } fn rows(&self) -> usize { 1 } fn n(&self) -> usize { self.n } fn size(&self) -> usize { self.size } } impl DataView for VecZnxBig { type D = D; fn data(&self) -> &Self::D { &self.data } } impl DataViewMut for VecZnxBig { fn data_mut(&mut self) -> &mut Self::D { &mut self.data } } pub trait VecZnxBigBytesOf { fn bytes_of(n: usize, cols: usize, size: usize) -> usize; } impl ZnxZero for VecZnxBig where Self: ZnxViewMut, ::Scalar: Zero + Copy, { fn zero(&mut self) { self.raw_mut().fill(::Scalar::zero()) } fn zero_at(&mut self, i: usize, j: usize) { self.at_mut(i, j).fill(::Scalar::zero()); } } impl>, B: Backend> VecZnxBig where VecZnxBig: VecZnxBigBytesOf, { pub(crate) fn new(n: usize, cols: usize, size: usize) -> Self { let data = alloc_aligned::(Self::bytes_of(n, cols, size)); Self { data: data.into(), n, cols, size, max_size: size, _phantom: PhantomData, } } pub(crate) fn new_from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); assert!(data.len() == Self::bytes_of(n, cols, size)); Self { data: data.into(), n, cols, size, max_size: size, _phantom: PhantomData, } } } impl VecZnxBig { pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self { Self { data, n, cols, size, max_size: size, _phantom: PhantomData, } } } pub type VecZnxBigOwned = VecZnxBig, B>; pub trait VecZnxBigToRef { fn to_ref(&self) -> VecZnxBig<&[u8], B>; } impl VecZnxBigToRef for VecZnxBig { fn to_ref(&self) -> VecZnxBig<&[u8], B> { VecZnxBig { data: self.data.as_ref(), n: self.n, cols: self.cols, size: self.size, max_size: self.max_size, _phantom: std::marker::PhantomData, } } } pub trait VecZnxBigToMut { fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B>; } impl VecZnxBigToMut for VecZnxBig { fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> { VecZnxBig { data: self.data.as_mut(), n: self.n, cols: self.cols, size: self.size, max_size: self.max_size, _phantom: std::marker::PhantomData, } } }