use std::fmt; use crate::{ alloc_aligned, hal::{ api::{DataView, DataViewMut, FillUniform, Reset, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero}, layouts::{Data, DataMut, DataRef, ReaderFrom, WriterTo}, }, }; #[derive(PartialEq, Eq, Clone)] pub struct VecZnx { pub(crate) data: D, pub(crate) n: usize, pub(crate) cols: usize, pub(crate) size: usize, pub(crate) max_size: usize, } impl fmt::Debug for VecZnx { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", self) } } impl ZnxInfos for VecZnx { fn cols(&self) -> usize { self.cols } fn rows(&self) -> usize { 1 } fn n(&self) -> usize { self.n } fn size(&self) -> usize { self.size } } impl ZnxSliceSize for VecZnx { fn sl(&self) -> usize { self.n() * self.cols() } } impl DataView for VecZnx { type D = D; fn data(&self) -> &Self::D { &self.data } } impl DataViewMut for VecZnx { fn data_mut(&mut self) -> &mut Self::D { &mut self.data } } impl ZnxView for VecZnx { type Scalar = i64; } impl VecZnx> { pub fn rsh_scratch_space(n: usize) -> usize { n * std::mem::size_of::() } } impl ZnxZero for VecZnx { fn zero(&mut self) { self.raw_mut().fill(0) } fn zero_at(&mut self, i: usize, j: usize) { self.at_mut(i, j).fill(0); } } impl VecZnx> { pub fn alloc_bytes(n: usize, cols: usize, size: usize) -> usize { n * cols * size * size_of::() } pub fn alloc(n: usize, cols: usize, size: usize) -> Self { let data: Vec = alloc_aligned::(Self::alloc_bytes(n, cols, size)); Self { data: data.into(), n, cols, size, max_size: size, } } pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); assert!(data.len() == Self::alloc_bytes(n, cols, size)); Self { data: data.into(), n, cols, size, max_size: size, } } } impl VecZnx { pub fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self { Self { data, n, cols, size, max_size: size, } } } impl fmt::Display for VecZnx { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!( f, "VecZnx(n={}, cols={}, size={})", self.n, self.cols, self.size )?; for col in 0..self.cols { writeln!(f, "Column {}:", col)?; for size in 0..self.size { let coeffs = self.at(col, size); write!(f, " Size {}: [", size)?; let max_show = 100; let show_count = coeffs.len().min(max_show); for (i, &coeff) in coeffs.iter().take(show_count).enumerate() { if i > 0 { write!(f, ", ")?; } write!(f, "{}", coeff)?; } if coeffs.len() > max_show { write!(f, ", ... ({} more)", coeffs.len() - max_show)?; } writeln!(f, "]")?; } } Ok(()) } } impl FillUniform for VecZnx { fn fill_uniform(&mut self, source: &mut Source) { source.fill_bytes(self.data.as_mut()); } } impl Reset for VecZnx { fn reset(&mut self) { self.zero(); self.n = 0; self.cols = 0; self.size = 0; self.max_size = 0; } } pub type VecZnxOwned = VecZnx>; pub type VecZnxMut<'a> = VecZnx<&'a mut [u8]>; pub type VecZnxRef<'a> = VecZnx<&'a [u8]>; pub trait VecZnxToRef { fn to_ref(&self) -> VecZnx<&[u8]>; } impl VecZnxToRef for VecZnx { fn to_ref(&self) -> VecZnx<&[u8]> { VecZnx { data: self.data.as_ref(), n: self.n, cols: self.cols, size: self.size, max_size: self.max_size, } } } pub trait VecZnxToMut { fn to_mut(&mut self) -> VecZnx<&mut [u8]>; } impl VecZnxToMut for VecZnx { fn to_mut(&mut self) -> VecZnx<&mut [u8]> { VecZnx { data: self.data.as_mut(), n: self.n, cols: self.cols, size: self.size, max_size: self.max_size, } } } impl VecZnx { pub fn clone(&self) -> VecZnx> { let self_ref: VecZnx<&[u8]> = self.to_ref(); VecZnx { data: self_ref.data.to_vec(), n: self_ref.n, cols: self_ref.cols, size: self_ref.size, max_size: self_ref.max_size, } } } use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use rand::RngCore; use sampling::source::Source; impl ReaderFrom for VecZnx { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.n = reader.read_u64::()? as usize; self.cols = reader.read_u64::()? as usize; self.size = reader.read_u64::()? as usize; self.max_size = reader.read_u64::()? as usize; let len: usize = reader.read_u64::()? as usize; let buf: &mut [u8] = self.data.as_mut(); if buf.len() != len { return Err(std::io::Error::new( std::io::ErrorKind::UnexpectedEof, format!("self.data.len()={} != read len={}", buf.len(), len), )); } reader.read_exact(&mut buf[..len])?; Ok(()) } } impl WriterTo for VecZnx { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { writer.write_u64::(self.n as u64)?; writer.write_u64::(self.cols as u64)?; writer.write_u64::(self.size as u64)?; writer.write_u64::(self.max_size as u64)?; let buf: &[u8] = self.data.as_ref(); writer.write_u64::(buf.len() as u64)?; writer.write_all(buf)?; Ok(()) } }