use std::marker::PhantomData; use rand_distr::num_traits::Zero; use crate::{ alloc_aligned, hal::{ api::{DataView, DataViewMut, ZnxInfos, ZnxView, ZnxViewMut, ZnxZero}, layouts::{Backend, Data, DataMut, DataRef, VecZnxBig}, }, }; #[derive(PartialEq, Eq)] pub struct VecZnxDft { pub(crate) data: D, pub(crate) n: usize, pub(crate) cols: usize, pub(crate) size: usize, pub(crate) max_size: usize, pub(crate) _phantom: PhantomData, } impl VecZnxDft { pub fn into_big(self) -> VecZnxBig { VecZnxBig::::from_data(self.data, self.n, self.cols, self.size) } } impl ZnxInfos for VecZnxDft { fn cols(&self) -> usize { self.cols } fn rows(&self) -> usize { 1 } fn n(&self) -> usize { self.n } fn size(&self) -> usize { self.size } } impl DataView for VecZnxDft { type D = D; fn data(&self) -> &Self::D { &self.data } } impl DataViewMut for VecZnxDft { fn data_mut(&mut self) -> &mut Self::D { &mut self.data } } impl VecZnxDft { pub fn max_size(&self) -> usize { self.max_size } } impl VecZnxDft { pub fn set_size(&mut self, size: usize) { assert!(size <= self.max_size); self.size = size } } impl ZnxZero for VecZnxDft where Self: ZnxViewMut, ::Scalar: Zero + Copy, { fn zero(&mut self) { self.raw_mut().fill(::Scalar::zero()) } fn zero_at(&mut self, i: usize, j: usize) { self.at_mut(i, j).fill(::Scalar::zero()); } } pub trait VecZnxDftBytesOf { fn bytes_of(n: usize, cols: usize, size: usize) -> usize; } impl>, B: Backend> VecZnxDft where VecZnxDft: VecZnxDftBytesOf, { pub(crate) fn alloc(n: usize, cols: usize, size: usize) -> Self { let data: Vec = alloc_aligned::(Self::bytes_of(n, cols, size)); Self { data: data.into(), n, cols, size, max_size: size, _phantom: PhantomData, } } pub(crate) fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); assert!(data.len() == Self::bytes_of(n, cols, size)); Self { data: data.into(), n, cols, size, max_size: size, _phantom: PhantomData, } } } pub type VecZnxDftOwned = VecZnxDft, B>; impl VecZnxDft { pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self { Self { data, n, cols, size, max_size: size, _phantom: PhantomData, } } } pub trait VecZnxDftToRef { fn to_ref(&self) -> VecZnxDft<&[u8], B>; } impl VecZnxDftToRef for VecZnxDft { fn to_ref(&self) -> VecZnxDft<&[u8], B> { VecZnxDft { data: self.data.as_ref(), n: self.n, cols: self.cols, size: self.size, max_size: self.max_size, _phantom: std::marker::PhantomData, } } } pub trait VecZnxDftToMut { fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B>; } impl VecZnxDftToMut for VecZnxDft { fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { VecZnxDft { data: self.data.as_mut(), n: self.n, cols: self.cols, size: self.size, max_size: self.max_size, _phantom: std::marker::PhantomData, } } }