use std::marker::PhantomData; use crate::ffi::svp; use crate::znx_base::ZnxInfos; use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned}; pub struct ScalarZnxDft { data: D, n: usize, cols: usize, _phantom: PhantomData, } impl ZnxInfos for ScalarZnxDft { fn cols(&self) -> usize { self.cols } fn rows(&self) -> usize { 1 } fn n(&self) -> usize { self.n } fn size(&self) -> usize { 1 } } impl ZnxSliceSize for ScalarZnxDft { fn sl(&self) -> usize { self.n() } } impl DataView for ScalarZnxDft { type D = D; fn data(&self) -> &Self::D { &self.data } } impl DataViewMut for ScalarZnxDft { fn data_mut(&mut self) -> &mut Self::D { &mut self.data } } impl> ZnxView for ScalarZnxDft { type Scalar = f64; } pub(crate) fn bytes_of_scalar_znx_dft(module: &Module, cols: usize) -> usize { ScalarZnxDftOwned::bytes_of(module, cols) } impl>, B: Backend> ScalarZnxDft { pub(crate) fn bytes_of(module: &Module, cols: usize) -> usize { unsafe { svp::bytes_of_svp_ppol(module.ptr) as usize * cols } } pub(crate) fn new(module: &Module, cols: usize) -> Self { let data = alloc_aligned::(Self::bytes_of(module, cols)); Self { data: data.into(), n: module.n(), cols, _phantom: PhantomData, } } pub(crate) fn new_from_bytes(module: &Module, cols: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); assert!(data.len() == Self::bytes_of(module, cols)); Self { data: data.into(), n: module.n(), cols, _phantom: PhantomData, } } } impl ScalarZnxDft { pub(crate) fn from_data(data: D, n: usize, cols: usize) -> Self { Self { data, n, cols, _phantom: PhantomData, } } } pub type ScalarZnxDftOwned = ScalarZnxDft, B>; pub trait ScalarZnxDftToRef { fn to_ref(&self) -> ScalarZnxDft<&[u8], B>; } pub trait ScalarZnxDftToMut { fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B>; } impl ScalarZnxDftToMut for ScalarZnxDft, B> { fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> { ScalarZnxDft { data: self.data.as_mut_slice(), n: self.n, cols: self.cols, _phantom: PhantomData, } } } impl ScalarZnxDftToRef for ScalarZnxDft, B> { fn to_ref(&self) -> ScalarZnxDft<&[u8], B> { ScalarZnxDft { data: self.data.as_slice(), n: self.n, cols: self.cols, _phantom: PhantomData, } } } impl ScalarZnxDftToMut for ScalarZnxDft<&mut [u8], B> { fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> { ScalarZnxDft { data: self.data, n: self.n, cols: self.cols, _phantom: PhantomData, } } } impl ScalarZnxDftToRef for ScalarZnxDft<&mut [u8], B> { fn to_ref(&self) -> ScalarZnxDft<&[u8], B> { ScalarZnxDft { data: self.data, n: self.n, cols: self.cols, _phantom: PhantomData, } } } impl ScalarZnxDftToRef for ScalarZnxDft<&[u8], B> { fn to_ref(&self) -> ScalarZnxDft<&[u8], B> { ScalarZnxDft { data: self.data, n: self.n, cols: self.cols, _phantom: PhantomData, } } }