use crate::{
    api::{
        VmpApplyDft, VmpApplyDftTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftAddTmpBytes,
        VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPMatBytesOf, VmpPMatFromBytes, VmpPrepare, VmpPrepareTmpBytes, VmpZero,
    },
    layouts::{
        Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, VmpPMatOwned, VmpPMatToMut,
        VmpPMatToRef,
    },
    oep::{
        VmpApplyDftImpl, VmpApplyDftTmpBytesImpl, VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl,
        VmpApplyDftToDftImpl, VmpApplyDftToDftTmpBytesImpl, VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPMatFromBytesImpl,
        VmpPrepareImpl, VmpPrepareTmpBytesImpl, VmpZeroImpl,
    },
};

// Delegation layer: each `impl` below implements one public `api` trait for `Module` by calling
// the matching `*_impl` associated function on the backend type `B` (bounded by the corresponding
// `oep` trait). Every method body is a single forwarding call; there is no validation, bookkeeping,
// or other logic here, so backend behavior is entirely defined by the `oep` implementations.
//
// NOTE(review): the generic parameter lists appear to be missing from this text. `B` (and `R`, `A`,
// `C` in the method `where` clauses) are used but never declared, and `Module`, `VmpPMatOwned`,
// `Scratch` and `Vec` carry no type arguments. As shown this would not compile. Presumably the real
// source reads e.g. `impl<B> … for Module<B>` and `fn vmp_prepare<R, A>(…)`; confirm against the
// repository before editing.

/// Allocates an owned `VmpPMatOwned` by delegating to `B::vmp_pmat_alloc_impl`.
///
/// Unlike the scratch-size and apply delegates further down, this passes `self.n()` to the backend
/// rather than the module itself. `rows`, `cols_in`, `cols_out` and `size` are forwarded unchanged.
impl VmpPMatAlloc for Module
where
    B: Backend + VmpPMatAllocImpl,
{
    fn vmp_pmat_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned {
        B::vmp_pmat_alloc_impl(self.n(), rows, cols_in, cols_out, size)
    }
}

/// Returns the `usize` reported by `B::vmp_pmat_bytes_of_impl` for the given dimensions.
///
/// The dimensions are presumably the byte footprint of a matrix allocated with the same arguments
/// (TODO confirm). Like `vmp_pmat_alloc`, it forwards `self.n()` rather than the module.
///
/// NOTE(review): the bound is `VmpPMatAllocBytesImpl` but the called function is named
/// `vmp_pmat_bytes_of_impl`. Presumably that function is declared on `VmpPMatAllocBytesImpl`;
/// verify that the trait/method naming mismatch is intentional.
impl VmpPMatBytesOf for Module
where
    B: Backend + VmpPMatAllocBytesImpl,
{
    fn bytes_of_vmp_pmat(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
        B::vmp_pmat_bytes_of_impl(self.n(), rows, cols_in, cols_out, size)
    }
}

/// Builds a `VmpPMatOwned` from a caller-supplied buffer by delegating to
/// `B::vmp_pmat_from_bytes_impl`.
///
/// `bytes` is taken by value, so ownership of the buffer moves into the backend call.
/// The module contributes only `self.n()`.
impl VmpPMatFromBytes for Module
where
    B: Backend + VmpPMatFromBytesImpl,
{
    fn vmp_pmat_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec) -> VmpPMatOwned {
        B::vmp_pmat_from_bytes_impl(self.n(), rows, cols_in, cols_out, size, bytes)
    }
}

/// Returns the `usize` reported by `B::vmp_prepare_tmp_bytes_impl`.
///
/// This is presumably the scratch size that `vmp_prepare` needs for these dimensions
/// (TODO confirm). The whole module (`self`) is passed to the backend.
impl VmpPrepareTmpBytes for Module
where
    B: Backend + VmpPrepareTmpBytesImpl,
{
    fn vmp_prepare_tmp_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
        B::vmp_prepare_tmp_bytes_impl(self, rows, cols_in, cols_out, size)
    }
}

/// Delegates to `B::vmp_prepare_impl(self, res, a, scratch)`.
///
/// `res` is passed mutably (`VmpPMatToMut`) and `a` immutably (`MatZnxToRef`). Presumably `res`
/// receives a prepared form of matrix `a`, using `scratch` as temporary space (confirm against the
/// backend).
impl VmpPrepare for Module
where
    B: Backend + VmpPrepareImpl,
{
    fn vmp_prepare(&self, res: &mut R, a: &A, scratch: &mut Scratch)
    where
        R: VmpPMatToMut,
        A: MatZnxToRef,
    {
        B::vmp_prepare_impl(self, res, a, scratch)
    }
}

/// Returns the `usize` reported by `B::vmp_apply_dft_tmp_bytes_impl`.
///
/// This is presumably the scratch size needed by `vmp_apply_dft` (TODO confirm). `res_size` and
/// `a_size` describe the vector operands, and the `b_*` arguments appear to describe the matrix
/// operand `b`. All are forwarded unchanged.
impl VmpApplyDftTmpBytes for Module
where
    B: Backend + VmpApplyDftTmpBytesImpl,
{
    fn vmp_apply_dft_tmp_bytes(
        &self,
        res_size: usize,
        a_size: usize,
        b_rows: usize,
        b_cols_in: usize,
        b_cols_out: usize,
        b_size: usize,
    ) -> usize {
        B::vmp_apply_dft_tmp_bytes_impl(
            self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
        )
    }
}

/// Delegates to `B::vmp_apply_dft_impl(self, res, a, b, scratch)`.
///
/// The input `a` is a plain `VecZnxToRef` (not DFT-domain), `b` is a `VmpPMatToRef`, and the
/// output `res` is a DFT-domain `VecZnxDftToMut`. Compare `vmp_apply_dft_to_dft` below, whose `a`
/// is already DFT-domain.
impl VmpApplyDft for Module
where
    B: Backend + VmpApplyDftImpl,
{
    fn vmp_apply_dft(&self, res: &mut R, a: &A, b: &C, scratch: &mut Scratch)
    where
        R: VecZnxDftToMut,
        A: VecZnxToRef,
        C: VmpPMatToRef,
    {
        B::vmp_apply_dft_impl(self, res, a, b, scratch);
    }
}

/// Returns the `usize` reported by `B::vmp_apply_dft_to_dft_tmp_bytes_impl`.
///
/// This is presumably the scratch size needed by `vmp_apply_dft_to_dft` (TODO confirm).
/// The arguments are forwarded unchanged, in the same order as `vmp_apply_dft_tmp_bytes`.
impl VmpApplyDftToDftTmpBytes for Module
where
    B: Backend + VmpApplyDftToDftTmpBytesImpl,
{
    fn vmp_apply_dft_to_dft_tmp_bytes(
        &self,
        res_size: usize,
        a_size: usize,
        b_rows: usize,
        b_cols_in: usize,
        b_cols_out: usize,
        b_size: usize,
    ) -> usize {
        B::vmp_apply_dft_to_dft_tmp_bytes_impl(
            self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
        )
    }
}

/// Delegates to `B::vmp_apply_dft_to_dft_impl(self, res, a, b, scratch)`.
///
/// Both the input `a` (`VecZnxDftToRef`) and the output `res` (`VecZnxDftToMut`) are DFT-domain
/// vectors. `b` is a `VmpPMatToRef`.
impl VmpApplyDftToDft for Module
where
    B: Backend + VmpApplyDftToDftImpl,
{
    fn vmp_apply_dft_to_dft(&self, res: &mut R, a: &A, b: &C, scratch: &mut Scratch)
    where
        R: VecZnxDftToMut,
        A: VecZnxDftToRef,
        C: VmpPMatToRef,
    {
        B::vmp_apply_dft_to_dft_impl(self, res, a, b, scratch);
    }
}

/// Returns the `usize` reported by `B::vmp_apply_dft_to_dft_add_tmp_bytes_impl`.
///
/// This is presumably the scratch size needed by `vmp_apply_dft_to_dft_add` (TODO confirm).
/// The arguments are forwarded unchanged.
impl VmpApplyDftToDftAddTmpBytes for Module
where
    B: Backend + VmpApplyDftToDftAddTmpBytesImpl,
{
    fn vmp_apply_dft_to_dft_add_tmp_bytes(
        &self,
        res_size: usize,
        a_size: usize,
        b_rows: usize,
        b_cols_in: usize,
        b_cols_out: usize,
        b_size: usize,
    ) -> usize {
        B::vmp_apply_dft_to_dft_add_tmp_bytes_impl(
            self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
        )
    }
}

/// Delegates to `B::vmp_apply_dft_to_dft_add_impl(self, res, a, b, scale, scratch)`.
///
/// It has the same operand bounds as `vmp_apply_dft_to_dft`, plus an extra `scale: usize` that is
/// forwarded unchanged. NOTE(review): the meaning of `scale`, and whether the result is accumulated
/// into `res` (as the name "add" suggests), is defined by the backend and not visible here.
/// Confirm before relying on it.
impl VmpApplyDftToDftAdd for Module
where
    B: Backend + VmpApplyDftToDftAddImpl,
{
    fn vmp_apply_dft_to_dft_add(&self, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch)
    where
        R: VecZnxDftToMut,
        A: VecZnxDftToRef,
        C: VmpPMatToRef,
    {
        B::vmp_apply_dft_to_dft_add_impl(self, res, a, b, scale, scratch);
    }
}

/// Delegates to `B::vmp_zero_impl(self, res)`.
///
/// `res` is passed mutably (`VmpPMatToMut`). Presumably the backend zeroes its contents
/// (TODO confirm).
impl VmpZero for Module
where
    B: Backend + VmpZeroImpl,
{
    fn vmp_zero(&self, res: &mut R)
    where
        R: VmpPMatToMut,
    {
        B::vmp_zero_impl(self, res);
    }
}