use std::marker::PhantomData; use poulpy_hal::{ DEFAULTALIGN, alloc_aligned, api::ScratchFromBytes, layouts::{Backend, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}, oep::{ ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, SvpPPolAllocBytesImpl, TakeMatZnxImpl, TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl, VmpPMatAllocBytesImpl, }, }; use crate::cpu_fft64_ref::FFT64Ref; unsafe impl ScratchOwnedAllocImpl for FFT64Ref { fn scratch_owned_alloc_impl(size: usize) -> ScratchOwned { let data: Vec = alloc_aligned(size); ScratchOwned { data, _phantom: PhantomData, } } } unsafe impl ScratchOwnedBorrowImpl for FFT64Ref where B: ScratchFromBytesImpl, { fn scratch_owned_borrow_impl(scratch: &mut ScratchOwned) -> &mut Scratch { Scratch::from_bytes(&mut scratch.data) } } unsafe impl ScratchFromBytesImpl for FFT64Ref { fn scratch_from_bytes_impl(data: &mut [u8]) -> &mut Scratch { unsafe { &mut *(data as *mut [u8] as *mut Scratch) } } } unsafe impl ScratchAvailableImpl for FFT64Ref { fn scratch_available_impl(scratch: &Scratch) -> usize { let ptr: *const u8 = scratch.data.as_ptr(); let self_len: usize = scratch.data.len(); let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN); self_len.saturating_sub(aligned_offset) } } unsafe impl TakeSliceImpl for FFT64Ref where B: ScratchFromBytesImpl, { fn take_slice_impl(scratch: &mut Scratch, len: usize) -> (&mut [T], &mut Scratch) { let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, len * std::mem::size_of::()); unsafe { ( &mut *(std::ptr::slice_from_raw_parts_mut(take_slice.as_mut_ptr() as *mut T, len)), Scratch::from_bytes(rem_slice), ) } } } unsafe impl TakeScalarZnxImpl for FFT64Ref where B: ScratchFromBytesImpl, { fn take_scalar_znx_impl(scratch: &mut Scratch, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Scratch) { let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, ScalarZnx::alloc_bytes(n, cols)); ( ScalarZnx::from_data(take_slice, n, cols), Scratch::from_bytes(rem_slice), ) } } unsafe impl TakeSvpPPolImpl for FFT64Ref where B: SvpPPolAllocBytesImpl + ScratchFromBytesImpl, { fn take_svp_ppol_impl(scratch: &mut Scratch, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Scratch) { let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, B::svp_ppol_alloc_bytes_impl(n, cols)); ( SvpPPol::from_data(take_slice, n, cols), Scratch::from_bytes(rem_slice), ) } } unsafe impl TakeVecZnxImpl for FFT64Ref where B: ScratchFromBytesImpl, { fn take_vec_znx_impl(scratch: &mut Scratch, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Scratch) { let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, VecZnx::alloc_bytes(n, cols, size)); ( VecZnx::from_data(take_slice, n, cols, size), Scratch::from_bytes(rem_slice), ) } } unsafe impl TakeVecZnxBigImpl for FFT64Ref where B: VecZnxBigAllocBytesImpl + ScratchFromBytesImpl, { fn take_vec_znx_big_impl( scratch: &mut Scratch, n: usize, cols: usize, size: usize, ) -> (VecZnxBig<&mut [u8], B>, &mut Scratch) { let (take_slice, rem_slice) = take_slice_aligned( &mut scratch.data, B::vec_znx_big_alloc_bytes_impl(n, cols, size), ); ( VecZnxBig::from_data(take_slice, n, cols, size), Scratch::from_bytes(rem_slice), ) } } unsafe impl TakeVecZnxDftImpl for FFT64Ref where B: VecZnxDftAllocBytesImpl + ScratchFromBytesImpl, { fn take_vec_znx_dft_impl( scratch: &mut Scratch, n: usize, cols: usize, size: usize, ) -> (VecZnxDft<&mut [u8], B>, &mut Scratch) { let (take_slice, rem_slice) = take_slice_aligned( &mut scratch.data, B::vec_znx_dft_alloc_bytes_impl(n, cols, size), ); ( VecZnxDft::from_data(take_slice, n, cols, size), Scratch::from_bytes(rem_slice), ) } } unsafe impl TakeVecZnxDftSliceImpl for FFT64Ref where B: VecZnxDftAllocBytesImpl + ScratchFromBytesImpl + TakeVecZnxDftImpl, { fn take_vec_znx_dft_slice_impl( scratch: &mut Scratch, len: usize, n: usize, cols: usize, size: usize, ) -> (Vec>, &mut Scratch) { let mut scratch: &mut Scratch = scratch; let mut slice: Vec> = Vec::with_capacity(len); for _ in 0..len { let (znx, new_scratch) = B::take_vec_znx_dft_impl(scratch, n, cols, size); scratch = new_scratch; slice.push(znx); } (slice, scratch) } } unsafe impl TakeVecZnxSliceImpl for FFT64Ref where B: ScratchFromBytesImpl + TakeVecZnxImpl, { fn take_vec_znx_slice_impl( scratch: &mut Scratch, len: usize, n: usize, cols: usize, size: usize, ) -> (Vec>, &mut Scratch) { let mut scratch: &mut Scratch = scratch; let mut slice: Vec> = Vec::with_capacity(len); for _ in 0..len { let (znx, new_scratch) = B::take_vec_znx_impl(scratch, n, cols, size); scratch = new_scratch; slice.push(znx); } (slice, scratch) } } unsafe impl TakeVmpPMatImpl for FFT64Ref where B: VmpPMatAllocBytesImpl + ScratchFromBytesImpl, { fn take_vmp_pmat_impl( scratch: &mut Scratch, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize, ) -> (VmpPMat<&mut [u8], B>, &mut Scratch) { let (take_slice, rem_slice) = take_slice_aligned( &mut scratch.data, B::vmp_pmat_alloc_bytes_impl(n, rows, cols_in, cols_out, size), ); ( VmpPMat::from_data(take_slice, n, rows, cols_in, cols_out, size), Scratch::from_bytes(rem_slice), ) } } unsafe impl TakeMatZnxImpl for FFT64Ref where B: ScratchFromBytesImpl, { fn take_mat_znx_impl( scratch: &mut Scratch, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize, ) -> (MatZnx<&mut [u8]>, &mut Scratch) { let (take_slice, rem_slice) = take_slice_aligned( &mut scratch.data, MatZnx::alloc_bytes(n, rows, cols_in, cols_out, size), ); ( MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size), Scratch::from_bytes(rem_slice), ) } } fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) { let ptr: *mut u8 = data.as_mut_ptr(); let self_len: usize = data.len(); let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN); let aligned_len: usize = self_len.saturating_sub(aligned_offset); if let Some(rem_len) = aligned_len.checked_sub(take_len) { unsafe { let rem_ptr: *mut u8 = ptr.add(aligned_offset).add(take_len); let rem_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len); let take_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset), take_len); (take_slice, rem_slice) } } else { panic!("Attempted to take {take_len} from scratch with {aligned_len} aligned bytes left"); } }