use std::{ fmt, hash::{DefaultHasher, Hasher}, marker::PhantomData, }; use rand_distr::num_traits::Zero; use crate::{ alloc_aligned, layouts::{ Backend, Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, VecZnxBig, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, }, oep::VecZnxDftAllocBytesImpl, }; #[repr(C)] #[derive(PartialEq, Eq)] pub struct VecZnxDft { pub data: D, pub n: usize, pub cols: usize, pub size: usize, pub max_size: usize, pub _phantom: PhantomData, } impl DigestU64 for VecZnxDft { fn digest_u64(&self) -> u64 { let mut h: DefaultHasher = DefaultHasher::new(); h.write(self.data.as_ref()); h.write_usize(self.n); h.write_usize(self.cols); h.write_usize(self.size); h.write_usize(self.max_size); h.finish() } } impl ZnxSliceSize for VecZnxDft { fn sl(&self) -> usize { B::layout_prep_word_count() * self.n() * self.cols() } } impl ZnxView for VecZnxDft { type Scalar = B::ScalarPrep; } impl VecZnxDft { pub fn into_big(self) -> VecZnxBig { VecZnxBig::::from_data(self.data, self.n, self.cols, self.size) } } impl ZnxInfos for VecZnxDft { fn cols(&self) -> usize { self.cols } fn rows(&self) -> usize { 1 } fn n(&self) -> usize { self.n } fn size(&self) -> usize { self.size } } impl DataView for VecZnxDft { type D = D; fn data(&self) -> &Self::D { &self.data } } impl DataViewMut for VecZnxDft { fn data_mut(&mut self) -> &mut Self::D { &mut self.data } } impl VecZnxDft { pub fn max_size(&self) -> usize { self.max_size } } impl VecZnxDft { pub fn set_size(&mut self, size: usize) { assert!(size <= self.max_size); self.size = size } } impl ZnxZero for VecZnxDft where Self: ZnxViewMut, ::Scalar: Zero + Copy, { fn zero(&mut self) { self.raw_mut().fill(::Scalar::zero()) } fn zero_at(&mut self, i: usize, j: usize) { self.at_mut(i, j).fill(::Scalar::zero()); } } impl>, B: Backend> VecZnxDft where B: VecZnxDftAllocBytesImpl, { pub fn alloc(n: usize, cols: usize, size: usize) -> Self { let data: Vec = alloc_aligned::(B::vec_znx_dft_alloc_bytes_impl(n, cols, size)); Self { data: data.into(), n, cols, size, max_size: size, _phantom: PhantomData, } } pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); assert!(data.len() == B::vec_znx_dft_alloc_bytes_impl(n, cols, size)); Self { data: data.into(), n, cols, size, max_size: size, _phantom: PhantomData, } } } pub type VecZnxDftOwned = VecZnxDft, B>; impl VecZnxDft { pub fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self { Self { data, n, cols, size, max_size: size, _phantom: PhantomData, } } } pub trait VecZnxDftToRef { fn to_ref(&self) -> VecZnxDft<&[u8], B>; } impl VecZnxDftToRef for VecZnxDft { fn to_ref(&self) -> VecZnxDft<&[u8], B> { VecZnxDft { data: self.data.as_ref(), n: self.n, cols: self.cols, size: self.size, max_size: self.max_size, _phantom: std::marker::PhantomData, } } } pub trait VecZnxDftToMut { fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B>; } impl VecZnxDftToMut for VecZnxDft { fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { VecZnxDft { data: self.data.as_mut(), n: self.n, cols: self.cols, size: self.size, max_size: self.max_size, _phantom: std::marker::PhantomData, } } } impl fmt::Display for VecZnxDft { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!( f, "VecZnxDft(n={}, cols={}, size={})", self.n, self.cols, self.size )?; for col in 0..self.cols { writeln!(f, "Column {}:", col)?; for size in 0..self.size { let coeffs = self.at(col, size); write!(f, " Size {}: [", size)?; let max_show = 100; let show_count = coeffs.len().min(max_show); for (i, &coeff) in coeffs.iter().take(show_count).enumerate() { if i > 0 { write!(f, ", ")?; } write!(f, "{}", coeff)?; } if coeffs.len() > max_show { write!(f, ", ... ({} more)", coeffs.len() - max_show)?; } writeln!(f, "]")?; } } Ok(()) } }