renamed vmp API closer to spqlios

This commit is contained in:
Pro7ech
2025-08-25 11:58:51 +02:00
parent 1551f7a6f0
commit a1b865709d
38 changed files with 431 additions and 393 deletions

View File

@@ -1,12 +1,12 @@
use poulpy_hal::{
api::{TakeSlice, VmpApplyTmpBytes, VmpPrepareTmpBytes},
api::{TakeSlice, VmpApplyDftToDftTmpBytes, VmpPrepareTmpBytes},
layouts::{
Backend, MatZnx, MatZnxToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VmpPMat, VmpPMatOwned,
VmpPMatToMut, VmpPMatToRef, ZnxInfos, ZnxView, ZnxViewMut,
},
oep::{
VmpApplyAddImpl, VmpApplyAddTmpBytesImpl, VmpApplyImpl, VmpApplyTmpBytesImpl, VmpPMatAllocBytesImpl, VmpPMatAllocImpl,
VmpPMatFromBytesImpl, VmpPMatPrepareImpl, VmpPrepareTmpBytesImpl,
VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl, VmpApplyDftToDftTmpBytesImpl,
VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPMatPrepareImpl, VmpPrepareTmpBytesImpl,
},
};
@@ -109,8 +109,8 @@ unsafe impl VmpPMatPrepareImpl<FFT64> for FFT64 {
}
}
unsafe impl VmpApplyTmpBytesImpl<FFT64> for FFT64 {
fn vmp_apply_tmp_bytes_impl(
unsafe impl VmpApplyDftToDftTmpBytesImpl<FFT64> for FFT64 {
fn vmp_apply_dft_to_dft_tmp_bytes_impl(
module: &Module<FFT64>,
res_size: usize,
a_size: usize,
@@ -131,8 +131,8 @@ unsafe impl VmpApplyTmpBytesImpl<FFT64> for FFT64 {
}
}
unsafe impl VmpApplyImpl<FFT64> for FFT64 {
fn vmp_apply_impl<R, A, C>(module: &Module<FFT64>, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<FFT64>)
unsafe impl VmpApplyDftToDftImpl<FFT64> for FFT64 {
fn vmp_apply_dft_to_dft_impl<R, A, C>(module: &Module<FFT64>, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<FFT64>)
where
R: VecZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
@@ -162,7 +162,7 @@ unsafe impl VmpApplyImpl<FFT64> for FFT64 {
);
}
let (tmp_bytes, _) = scratch.take_slice(module.vmp_apply_tmp_bytes(
let (tmp_bytes, _) = scratch.take_slice(module.vmp_apply_dft_to_dft_tmp_bytes(
res.size(),
a.size(),
b.rows(),
@@ -186,8 +186,8 @@ unsafe impl VmpApplyImpl<FFT64> for FFT64 {
}
}
unsafe impl VmpApplyAddTmpBytesImpl<FFT64> for FFT64 {
fn vmp_apply_add_tmp_bytes_impl(
unsafe impl VmpApplyDftToDftAddTmpBytesImpl<FFT64> for FFT64 {
fn vmp_apply_dft_to_dft_add_tmp_bytes_impl(
module: &Module<FFT64>,
res_size: usize,
a_size: usize,
@@ -208,9 +208,15 @@ unsafe impl VmpApplyAddTmpBytesImpl<FFT64> for FFT64 {
}
}
unsafe impl VmpApplyAddImpl<FFT64> for FFT64 {
fn vmp_apply_add_impl<R, A, C>(module: &Module<FFT64>, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch<FFT64>)
where
unsafe impl VmpApplyDftToDftAddImpl<FFT64> for FFT64 {
fn vmp_apply_dft_to_dft_add_impl<R, A, C>(
module: &Module<FFT64>,
res: &mut R,
a: &A,
b: &C,
scale: usize,
scratch: &mut Scratch<FFT64>,
) where
R: VecZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
C: VmpPMatToRef<FFT64>,
@@ -239,7 +245,7 @@ unsafe impl VmpApplyAddImpl<FFT64> for FFT64 {
);
}
let (tmp_bytes, _) = scratch.take_slice(module.vmp_apply_tmp_bytes(
let (tmp_bytes, _) = scratch.take_slice(module.vmp_apply_dft_to_dft_tmp_bytes(
res.size(),
a.size(),
b.rows(),