use crate::ffi::vec_znx_big; use crate::znx_base::{ZnxInfos, ZnxView}; use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxViewMut, ZnxZero, alloc_aligned}; use std::fmt; use std::marker::PhantomData; pub struct VecZnxBig { data: D, n: usize, cols: usize, size: usize, _phantom: PhantomData, } impl ZnxInfos for VecZnxBig { fn cols(&self) -> usize { self.cols } fn rows(&self) -> usize { 1 } fn n(&self) -> usize { self.n } fn size(&self) -> usize { self.size } } impl ZnxSliceSize for VecZnxBig { fn sl(&self) -> usize { self.n() * self.cols() } } impl DataView for VecZnxBig { type D = D; fn data(&self) -> &Self::D { &self.data } } impl DataViewMut for VecZnxBig { fn data_mut(&mut self) -> &mut Self::D { &mut self.data } } impl> ZnxView for VecZnxBig { type Scalar = i64; } pub(crate) fn bytes_of_vec_znx_big(module: &Module, cols: usize, size: usize) -> usize { unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols } } impl>, B: Backend> VecZnxBig { pub(crate) fn new(module: &Module, cols: usize, size: usize) -> Self { let data = alloc_aligned::(bytes_of_vec_znx_big(module, cols, size)); Self { data: data.into(), n: module.n(), cols, size, _phantom: PhantomData, } } pub(crate) fn new_from_bytes(module: &Module, cols: usize, size: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); assert!(data.len() == bytes_of_vec_znx_big(module, cols, size)); Self { data: data.into(), n: module.n(), cols, size, _phantom: PhantomData, } } } impl VecZnxBig { pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self { Self { data, n, cols, size, _phantom: PhantomData, } } } impl VecZnxBig where VecZnxBig: VecZnxBigToMut + ZnxInfos, { /// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self]. pub fn extract_column(&mut self, self_col: usize, a: &VecZnxBig, a_col: usize) where VecZnxBig: VecZnxBigToRef + ZnxInfos, { #[cfg(debug_assertions)] { assert!(self_col < self.cols()); assert!(a_col < a.cols()); } let min_size: usize = self.size.min(a.size()); let max_size: usize = self.size; let mut self_mut: VecZnxBig<&mut [u8], FFT64> = self.to_mut(); let a_ref: VecZnxBig<&[u8], FFT64> = a.to_ref(); (0..min_size).for_each(|i: usize| { self_mut .at_mut(self_col, i) .copy_from_slice(a_ref.at(a_col, i)); }); (min_size..max_size).for_each(|i| { self_mut.zero_at(self_col, i); }); } } pub type VecZnxBigOwned = VecZnxBig, B>; pub trait VecZnxBigToRef { fn to_ref(&self) -> VecZnxBig<&[u8], B>; } pub trait VecZnxBigToMut { fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B>; } impl VecZnxBigToMut for VecZnxBig, B> { fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> { VecZnxBig { data: self.data.as_mut_slice(), n: self.n, cols: self.cols, size: self.size, _phantom: PhantomData, } } } impl VecZnxBigToRef for VecZnxBig, B> { fn to_ref(&self) -> VecZnxBig<&[u8], B> { VecZnxBig { data: self.data.as_slice(), n: self.n, cols: self.cols, size: self.size, _phantom: PhantomData, } } } impl VecZnxBigToMut for VecZnxBig<&mut [u8], B> { fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> { VecZnxBig { data: self.data, n: self.n, cols: self.cols, size: self.size, _phantom: PhantomData, } } } impl VecZnxBigToRef for VecZnxBig<&mut [u8], B> { fn to_ref(&self) -> VecZnxBig<&[u8], B> { VecZnxBig { data: self.data, n: self.n, cols: self.cols, size: self.size, _phantom: PhantomData, } } } impl VecZnxBigToRef for VecZnxBig<&[u8], B> { fn to_ref(&self) -> VecZnxBig<&[u8], B> { VecZnxBig { data: self.data, n: self.n, cols: self.cols, size: self.size, _phantom: PhantomData, } } } impl> fmt::Display for VecZnxBig { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!( f, "VecZnxBig(n={}, cols={}, size={})", self.n, self.cols, self.size )?; for col in 0..self.cols { writeln!(f, "Column {}:", col)?; for size in 0..self.size { let coeffs = self.at(col, size); write!(f, " Size {}: [", size)?; let max_show = 100; let show_count = coeffs.len().min(max_show); for (i, &coeff) in coeffs.iter().take(show_count).enumerate() { if i > 0 { write!(f, ", ")?; } write!(f, "{}", coeff)?; } if coeffs.len() > max_show { write!(f, ", ... ({} more)", coeffs.len() - max_show)?; } writeln!(f, "]")?; } } Ok(()) } }